Skip to content

Commit e097def

Browse files
committed
support for pydantic2
1 parent 7bf89ad commit e097def

File tree

1 file changed

+24
-20
lines changed

1 file changed

+24
-20
lines changed

fastapi_code_generator/__main__.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from typing import Any, Dict, List, Optional
66

77
import typer
8-
from datamodel_code_generator import LiteralType, PythonVersion, chdir
8+
from datamodel_code_generator import LiteralType, PythonVersion, chdir, DataModelType
9+
from datamodel_code_generator.model import get_data_model_types
910
from datamodel_code_generator.format import CodeFormatter
1011
from datamodel_code_generator.imports import Import, Imports
1112
from datamodel_code_generator.reference import Reference
@@ -57,6 +58,10 @@ def main(
5758
None, "--custom-visitor", "-c"
5859
),
5960
disable_timestamp: bool = typer.Option(False, "--disable-timestamp"),
61+
output_model_type: DataModelType = typer.Option(
62+
DataModelType.PydanticBaseModel, "--data-model-type", "-d"),
63+
python_version: PythonVersion = typer.Option(
64+
PythonVersion.PY_38, "--python-version", "-p"),
6065
) -> None:
6166
input_name: str = input_file
6267
input_text: str
@@ -69,31 +74,20 @@ def main(
6974
else:
7075
model_path = MODEL_PATH
7176

72-
if enum_field_as_literal:
73-
return generate_code(
74-
input_name,
75-
input_text,
76-
encoding,
77-
output_dir,
78-
template_dir,
79-
model_path,
80-
enum_field_as_literal, # type: ignore[arg-type]
81-
custom_visitors=custom_visitors,
82-
disable_timestamp=disable_timestamp,
83-
generate_routers=generate_routers,
84-
specify_tags=specify_tags,
85-
)
8677
return generate_code(
8778
input_name,
8879
input_text,
8980
encoding,
9081
output_dir,
9182
template_dir,
9283
model_path,
84+
enum_field_as_literal=enum_field_as_literal or None,
9385
custom_visitors=custom_visitors,
9486
disable_timestamp=disable_timestamp,
9587
generate_routers=generate_routers,
9688
specify_tags=specify_tags,
89+
output_model_type=output_model_type,
90+
python_version=python_version,
9791
)
9892

9993

@@ -119,6 +113,8 @@ def generate_code(
119113
disable_timestamp: bool = False,
120114
generate_routers: Optional[bool] = None,
121115
specify_tags: Optional[str] = None,
116+
output_model_type: DataModelType = DataModelType.PydanticBaseModel,
117+
python_version: PythonVersion = PythonVersion.PY_38,
122118
) -> None:
123119
if not model_path:
124120
model_path = MODEL_PATH
@@ -130,10 +126,18 @@ def generate_code(
130126
template_dir = (
131127
BUILTIN_MODULAR_TEMPLATE_DIR if generate_routers else BUILTIN_TEMPLATE_DIR
132128
)
133-
if enum_field_as_literal:
134-
parser = OpenAPIParser(input_text, enum_field_as_literal=enum_field_as_literal) # type: ignore[arg-type]
135-
else:
136-
parser = OpenAPIParser(input_text)
129+
130+
data_model_types = get_data_model_types(output_model_type, python_version)
131+
132+
parser = OpenAPIParser(input_text,
133+
enum_field_as_literal=enum_field_as_literal,
134+
data_model_type=data_model_types.data_model,
135+
data_model_root_type=data_model_types.root_model,
136+
data_model_field_type=data_model_types.field_model,
137+
data_type_manager_type=data_model_types.data_type_manager,
138+
dump_resolve_reference_action=data_model_types.dump_resolve_reference_action,
139+
)
140+
137141
with chdir(output_dir):
138142
models = parser.parse()
139143
output = output_dir / model_path
@@ -153,7 +157,7 @@ def generate_code(
153157
)
154158

155159
results: Dict[Path, str] = {}
156-
code_formatter = CodeFormatter(PythonVersion.PY_38, Path().resolve())
160+
code_formatter = CodeFormatter(python_version, Path().resolve())
157161

158162
template_vars: Dict[str, object] = {"info": parser.parse_info()}
159163
visitors: List[Visitor] = []

0 commit comments

Comments
 (0)