Skip to content

Commit 46745f9

Browse files
authored
core: Use parametric tests in test_openai_tools (#31839)
1 parent 181c22c commit 46745f9

File tree

1 file changed

+65
-64
lines changed

1 file changed

+65
-64
lines changed

libs/core/tests/unit_tests/output_parsers/test_openai_tools.py

Lines changed: 65 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -386,92 +386,93 @@ async def input_iter(_: Any) -> AsyncIterator[BaseMessage]:
386386
return input_iter
387387

388388

389-
def test_partial_json_output_parser() -> None:
390-
for use_tool_calls in [False, True]:
391-
input_iter = _get_iter(use_tool_calls=use_tool_calls)
392-
chain = input_iter | JsonOutputToolsParser()
393-
394-
actual = list(chain.stream(None))
395-
expected: list = [[]] + [
396-
[{"type": "NameCollector", "args": chunk}]
397-
for chunk in EXPECTED_STREAMED_JSON
398-
]
399-
assert actual == expected
389+
@pytest.mark.parametrize("use_tool_calls", [False, True])
390+
def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
391+
input_iter = _get_iter(use_tool_calls=use_tool_calls)
392+
chain = input_iter | JsonOutputToolsParser()
393+
394+
actual = list(chain.stream(None))
395+
expected: list = [[]] + [
396+
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
397+
]
398+
assert actual == expected
400399

401400

402-
async def test_partial_json_output_parser_async() -> None:
403-
for use_tool_calls in [False, True]:
404-
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
405-
chain = input_iter | JsonOutputToolsParser()
401+
@pytest.mark.parametrize("use_tool_calls", [False, True])
402+
async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None:
403+
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
404+
chain = input_iter | JsonOutputToolsParser()
406405

407-
actual = [p async for p in chain.astream(None)]
408-
expected: list = [[]] + [
409-
[{"type": "NameCollector", "args": chunk}]
410-
for chunk in EXPECTED_STREAMED_JSON
411-
]
412-
assert actual == expected
406+
actual = [p async for p in chain.astream(None)]
407+
expected: list = [[]] + [
408+
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
409+
]
410+
assert actual == expected
413411

414412

415-
def test_partial_json_output_parser_return_id() -> None:
416-
for use_tool_calls in [False, True]:
417-
input_iter = _get_iter(use_tool_calls=use_tool_calls)
418-
chain = input_iter | JsonOutputToolsParser(return_id=True)
413+
@pytest.mark.parametrize("use_tool_calls", [False, True])
414+
def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
415+
input_iter = _get_iter(use_tool_calls=use_tool_calls)
416+
chain = input_iter | JsonOutputToolsParser(return_id=True)
419417

420-
actual = list(chain.stream(None))
421-
expected: list = [[]] + [
422-
[
423-
{
424-
"type": "NameCollector",
425-
"args": chunk,
426-
"id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
427-
}
428-
]
429-
for chunk in EXPECTED_STREAMED_JSON
418+
actual = list(chain.stream(None))
419+
expected: list = [[]] + [
420+
[
421+
{
422+
"type": "NameCollector",
423+
"args": chunk,
424+
"id": "call_OwL7f5PEPJTYzw9sQlNJtCZl",
425+
}
430426
]
431-
assert actual == expected
427+
for chunk in EXPECTED_STREAMED_JSON
428+
]
429+
assert actual == expected
432430

433431

434-
def test_partial_json_output_key_parser() -> None:
435-
for use_tool_calls in [False, True]:
436-
input_iter = _get_iter(use_tool_calls=use_tool_calls)
437-
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
432+
@pytest.mark.parametrize("use_tool_calls", [False, True])
433+
def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
434+
input_iter = _get_iter(use_tool_calls=use_tool_calls)
435+
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
438436

439-
actual = list(chain.stream(None))
440-
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
441-
assert actual == expected
437+
actual = list(chain.stream(None))
438+
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
439+
assert actual == expected
442440

443441

444-
async def test_partial_json_output_parser_key_async() -> None:
445-
for use_tool_calls in [False, True]:
446-
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
442+
@pytest.mark.parametrize("use_tool_calls", [False, True])
443+
async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> None:
444+
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
447445

448-
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
446+
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
449447

450-
actual = [p async for p in chain.astream(None)]
451-
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
452-
assert actual == expected
448+
actual = [p async for p in chain.astream(None)]
449+
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
450+
assert actual == expected
453451

454452

455-
def test_partial_json_output_key_parser_first_only() -> None:
456-
for use_tool_calls in [False, True]:
457-
input_iter = _get_iter(use_tool_calls=use_tool_calls)
453+
@pytest.mark.parametrize("use_tool_calls", [False, True])
454+
def test_partial_json_output_key_parser_first_only(*, use_tool_calls: bool) -> None:
455+
input_iter = _get_iter(use_tool_calls=use_tool_calls)
458456

459-
chain = input_iter | JsonOutputKeyToolsParser(
460-
key_name="NameCollector", first_tool_only=True
461-
)
457+
chain = input_iter | JsonOutputKeyToolsParser(
458+
key_name="NameCollector", first_tool_only=True
459+
)
462460

463-
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
461+
assert list(chain.stream(None)) == EXPECTED_STREAMED_JSON
464462

465463

466-
async def test_partial_json_output_parser_key_async_first_only() -> None:
467-
for use_tool_calls in [False, True]:
468-
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
464+
@pytest.mark.parametrize("use_tool_calls", [False, True])
465+
async def test_partial_json_output_parser_key_async_first_only(
466+
*,
467+
use_tool_calls: bool,
468+
) -> None:
469+
input_iter = _get_aiter(use_tool_calls=use_tool_calls)
469470

470-
chain = input_iter | JsonOutputKeyToolsParser(
471-
key_name="NameCollector", first_tool_only=True
472-
)
471+
chain = input_iter | JsonOutputKeyToolsParser(
472+
key_name="NameCollector", first_tool_only=True
473+
)
473474

474-
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
475+
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
475476

476477

477478
class Person(BaseModel):

0 commit comments

Comments
 (0)