|
7 | 7 | from dataclasses import dataclass, field
|
8 | 8 | from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
|
9 | 9 |
|
10 |
| -from pydantic import TypeAdapter, ValidationError |
| 10 | +from pydantic import Json, TypeAdapter, ValidationError |
11 | 11 | from pydantic_core import SchemaValidator, to_json
|
12 | 12 | from typing_extensions import Self, TypedDict, TypeVar, assert_never
|
13 | 13 |
|
@@ -624,21 +624,33 @@ def __init__(
|
624 | 624 | json_schema = self._function_schema.json_schema
|
625 | 625 | json_schema['description'] = self._function_schema.description
|
626 | 626 | else:
|
627 |
| - type_adapter: TypeAdapter[Any] |
| 627 | + json_schema_type_adapter: TypeAdapter[Any] |
| 628 | + validation_type_adapter: TypeAdapter[Any] |
628 | 629 | if _utils.is_model_like(output):
|
629 |
| - type_adapter = TypeAdapter(output) |
| 630 | + json_schema_type_adapter = validation_type_adapter = TypeAdapter(output) |
630 | 631 | else:
|
631 | 632 | self.outer_typed_dict_key = 'response'
|
| 633 | + output_type: type[OutputDataT] = cast(type[OutputDataT], output) |
| 634 | + |
632 | 635 | response_data_typed_dict = TypedDict( # noqa: UP013
|
633 | 636 | 'response_data_typed_dict',
|
634 |
| - {'response': cast(type[OutputDataT], output)}, # pyright: ignore[reportInvalidTypeForm] |
| 637 | + {'response': output_type}, # pyright: ignore[reportInvalidTypeForm] |
| 638 | + ) |
| 639 | + json_schema_type_adapter = TypeAdapter(response_data_typed_dict) |
| 640 | + |
| 641 | + # More lenient validator: allow either the native type or a JSON string containing it |
| 642 | + # i.e. `response: OutputDataT | Json[OutputDataT]`, as some models don't follow the schema correctly, |
| 643 | + # e.g. `BedrockConverseModel('us.meta.llama3-2-11b-instruct-v1:0')` |
| 644 | + response_validation_typed_dict = TypedDict( # noqa: UP013 |
| 645 | + 'response_validation_typed_dict', |
| 646 | + {'response': output_type | Json[output_type]}, # pyright: ignore[reportInvalidTypeForm] |
635 | 647 | )
|
636 |
| - type_adapter = TypeAdapter(response_data_typed_dict) |
| 648 | + validation_type_adapter = TypeAdapter(response_validation_typed_dict) |
637 | 649 |
|
638 | 650 | # Really a PluggableSchemaValidator, but it's API-compatible
|
639 |
| - self.validator = cast(SchemaValidator, type_adapter.validator) |
| 651 | + self.validator = cast(SchemaValidator, validation_type_adapter.validator) |
640 | 652 | json_schema = _utils.check_object_json_schema(
|
641 |
| - type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) |
| 653 | + json_schema_type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) |
642 | 654 | )
|
643 | 655 |
|
644 | 656 | if self.outer_typed_dict_key:
|
|
0 commit comments