Skip to content

Commit c7e82ad

Browse files
authored
core: Use parametrized test in test_correct_get_tracer_project (#31513)
1 parent 8a0782c commit c7e82ad

File tree

1 file changed

+51
-57
lines changed

1 file changed

+51
-57
lines changed

libs/core/tests/unit_tests/tracers/test_langchain.py

Lines changed: 51 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -94,60 +94,54 @@ def test_log_lock() -> None:
9494
tracer.wait_for_futures()
9595

9696

97-
class LangChainProjectNameTest(unittest.TestCase):
98-
"""Test that the project name is set correctly for runs."""
99-
100-
class SetProperTracerProjectTestCase:
101-
def __init__(
102-
self, test_name: str, envvars: dict[str, str], expected_project_name: str
103-
):
104-
self.test_name = test_name
105-
self.envvars = envvars
106-
self.expected_project_name = expected_project_name
107-
108-
def test_correct_get_tracer_project(self) -> None:
109-
cases = [
110-
self.SetProperTracerProjectTestCase(
111-
test_name="default to 'default' when no project provided",
112-
envvars={},
113-
expected_project_name="default",
114-
),
115-
self.SetProperTracerProjectTestCase(
116-
test_name="use session_name for legacy tracers",
117-
envvars={"LANGCHAIN_SESSION": "old_timey_session"},
118-
expected_project_name="old_timey_session",
119-
),
120-
self.SetProperTracerProjectTestCase(
121-
test_name="use LANGCHAIN_PROJECT over SESSION_NAME",
122-
envvars={
123-
"LANGCHAIN_SESSION": "old_timey_session",
124-
"LANGCHAIN_PROJECT": "modern_session",
125-
},
126-
expected_project_name="modern_session",
127-
),
128-
]
129-
130-
for case in cases:
131-
get_env_var.cache_clear()
132-
get_tracer_project.cache_clear()
133-
with self.subTest(msg=case.test_name), pytest.MonkeyPatch.context() as mp:
134-
for k, v in case.envvars.items():
135-
mp.setenv(k, v)
136-
137-
client = unittest.mock.MagicMock(spec=Client)
138-
tracer = LangChainTracer(client=client)
139-
projects = []
140-
141-
def mock_create_run(**kwargs: Any) -> Any:
142-
projects.append(kwargs.get("session_name")) # noqa: B023
143-
return unittest.mock.MagicMock()
144-
145-
client.create_run = mock_create_run
146-
147-
tracer.on_llm_start(
148-
{"name": "example_1"},
149-
["foo"],
150-
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
151-
)
152-
tracer.wait_for_futures()
153-
assert projects == [case.expected_project_name]
97+
@pytest.mark.parametrize(
98+
("envvars", "expected_project_name"),
99+
[
100+
(
101+
{},
102+
"default",
103+
),
104+
(
105+
{"LANGCHAIN_SESSION": "old_timey_session"},
106+
"old_timey_session",
107+
),
108+
(
109+
{
110+
"LANGCHAIN_SESSION": "old_timey_session",
111+
"LANGCHAIN_PROJECT": "modern_session",
112+
},
113+
"modern_session",
114+
),
115+
],
116+
ids=[
117+
"default to 'default' when no project provided",
118+
"use session_name for legacy tracers",
119+
"use LANGCHAIN_PROJECT over SESSION_NAME",
120+
],
121+
)
122+
def test_correct_get_tracer_project(
123+
envvars: dict[str, str], expected_project_name: str
124+
) -> None:
125+
get_env_var.cache_clear()
126+
get_tracer_project.cache_clear()
127+
with pytest.MonkeyPatch.context() as mp:
128+
for k, v in envvars.items():
129+
mp.setenv(k, v)
130+
131+
client = unittest.mock.MagicMock(spec=Client)
132+
tracer = LangChainTracer(client=client)
133+
projects = []
134+
135+
def mock_create_run(**kwargs: Any) -> Any:
136+
projects.append(kwargs.get("session_name"))
137+
return unittest.mock.MagicMock()
138+
139+
client.create_run = mock_create_run
140+
141+
tracer.on_llm_start(
142+
{"name": "example_1"},
143+
["foo"],
144+
run_id=UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a"),
145+
)
146+
tracer.wait_for_futures()
147+
assert projects == [expected_project_name]

0 commit comments

Comments
 (0)