|
5 | 5 | import sys |
6 | 6 | from collections.abc import Callable |
7 | 7 | from pathlib import Path |
8 | | -from typing import Any, ClassVar, NamedTuple |
| 8 | +from typing import ( |
| 9 | + Annotated, |
| 10 | + Any, |
| 11 | + ClassVar, |
| 12 | + NamedTuple, |
| 13 | + Optional, |
| 14 | + Union, |
| 15 | + get_args, |
| 16 | + get_origin, |
| 17 | +) |
9 | 18 |
|
10 | 19 | import yaml |
11 | 20 | from common_library.json_serialization import json_dumps |
@@ -36,33 +45,65 @@ def __modify_schema__(cls, field_schema: dict[str, Any]) -> None: |
36 | 45 | return _Json |
37 | 46 |
|
38 | 47 |
|
| 48 | +def replace_basemodel_in_annotation(annotation, new_type): |
| 49 | + origin = get_origin(annotation) |
| 50 | + |
| 51 | + # Handle Annotated |
| 52 | + if origin is Annotated: |
| 53 | + args = get_args(annotation) |
| 54 | + base_type = args[0] |
| 55 | + metadata = args[1:] |
| 56 | + if isinstance(base_type, type) and issubclass(base_type, BaseModel): |
| 57 | + # Replace the BaseModel subclass |
| 58 | + base_type = new_type |
| 59 | + |
| 60 | + return Annotated[(base_type, *metadata)] |
| 61 | + |
| 62 | + # Handle Optionals, Unions, or other generic types |
| 63 | + if origin in (Optional, Union, list, dict, tuple): # Extendable for other generics |
| 64 | + new_args = tuple( |
| 65 | + replace_basemodel_in_annotation(arg, new_type) |
| 66 | + for arg in get_args(annotation) |
| 67 | + ) |
| 68 | + return origin[new_args] |
| 69 | + |
| 70 | + # Replace BaseModel subclass directly |
| 71 | + if isinstance(annotation, type) and issubclass(annotation, BaseModel): |
| 72 | + return new_type |
| 73 | + |
| 74 | + # Return as-is if no changes |
| 75 | + return annotation |
| 76 | + |
| 77 | + |
39 | 78 | def as_query(model_class: type[BaseModel]) -> type[BaseModel]: |
40 | 79 | fields = {} |
41 | 80 | for field_name, field_info in model_class.model_fields.items(): |
42 | 81 |
|
43 | | - field_type = get_type(field_info) |
44 | | - default_value = field_info.default |
45 | | - |
46 | | - kwargs = { |
| 82 | + query_kwargs = { |
| 83 | + "default": field_info.default, |
47 | 84 | "alias": field_info.alias, |
48 | 85 | "title": field_info.title, |
49 | 86 | "description": field_info.description, |
50 | 87 | "metadata": field_info.metadata, |
51 | | - "json_schema_extra": field_info.json_schema_extra, |
| 88 | + "json_schema_extra": field_info.json_schema_extra or {}, |
52 | 89 | } |
53 | 90 |
|
54 | | - if issubclass(field_type, BaseModel): |
55 | | - # Complex fields |
56 | | - assert "json_schema_extra" in kwargs # nosec |
57 | | - assert kwargs["json_schema_extra"] # nosec |
58 | | - field_type = _create_json_type( |
59 | | - description=kwargs["description"], |
60 | | - example=kwargs.get("json_schema_extra", {}).get("example_json"), |
61 | | - ) |
| 91 | + json_field_type = _create_json_type( |
| 92 | + description=query_kwargs["description"], |
| 93 | + example=query_kwargs.get("json_schema_extra", {}).get("example_json"), |
| 94 | + ) |
62 | 95 |
|
63 | | - default_value = json_dumps(default_value) if default_value else None |
| 96 | + annotation = replace_basemodel_in_annotation( |
| 97 | + field_info.annotation, new_type=json_field_type |
| 98 | + ) |
| 99 | + |
| 100 | + if annotation != field_info.annotation: |
| 101 | + # Complex fields are transformed to Json |
| 102 | + query_kwargs["default"] = ( |
| 103 | + json_dumps(query_kwargs["default"]) if query_kwargs["default"] else None |
| 104 | + ) |
64 | 105 |
|
65 | | - fields[field_name] = (field_type, Query(default=default_value, **kwargs)) |
| 106 | + fields[field_name] = (annotation, Query(**query_kwargs)) |
66 | 107 |
|
67 | 108 | new_model_name = f"{model_class.__name__}Query" |
68 | 109 | return create_model(new_model_name, **fields) |
|
0 commit comments