55from typing import Any , Dict , List , Optional
66
77import typer
8- from datamodel_code_generator import LiteralType , PythonVersion , chdir
8+ from datamodel_code_generator import LiteralType , PythonVersion , chdir , DataModelType
9+ from datamodel_code_generator .model import get_data_model_types
910from datamodel_code_generator .format import CodeFormatter
1011from datamodel_code_generator .imports import Import , Imports
1112from datamodel_code_generator .reference import Reference
@@ -57,6 +58,10 @@ def main(
5758 None , "--custom-visitor" , "-c"
5859 ),
5960 disable_timestamp : bool = typer .Option (False , "--disable-timestamp" ),
61+ output_model_type : DataModelType = typer .Option (
62+ DataModelType .PydanticBaseModel , "--data-model-type" , "-d" ),
63+ python_version : PythonVersion = typer .Option (
64+ PythonVersion .PY_38 , "--python-version" , "-p" ),
6065) -> None :
6166 input_name : str = input_file
6267 input_text : str
@@ -69,31 +74,20 @@ def main(
6974 else :
7075 model_path = MODEL_PATH
7176
72- if enum_field_as_literal :
73- return generate_code (
74- input_name ,
75- input_text ,
76- encoding ,
77- output_dir ,
78- template_dir ,
79- model_path ,
80- enum_field_as_literal , # type: ignore[arg-type]
81- custom_visitors = custom_visitors ,
82- disable_timestamp = disable_timestamp ,
83- generate_routers = generate_routers ,
84- specify_tags = specify_tags ,
85- )
8677 return generate_code (
8778 input_name ,
8879 input_text ,
8980 encoding ,
9081 output_dir ,
9182 template_dir ,
9283 model_path ,
84+ enum_field_as_literal = enum_field_as_literal or None ,
9385 custom_visitors = custom_visitors ,
9486 disable_timestamp = disable_timestamp ,
9587 generate_routers = generate_routers ,
9688 specify_tags = specify_tags ,
89+ output_model_type = output_model_type ,
90+ python_version = python_version ,
9791 )
9892
9993
@@ -119,6 +113,8 @@ def generate_code(
119113 disable_timestamp : bool = False ,
120114 generate_routers : Optional [bool ] = None ,
121115 specify_tags : Optional [str ] = None ,
116+ output_model_type : DataModelType = DataModelType .PydanticBaseModel ,
117+ python_version : PythonVersion = PythonVersion .PY_38 ,
122118) -> None :
123119 if not model_path :
124120 model_path = MODEL_PATH
@@ -130,10 +126,18 @@ def generate_code(
130126 template_dir = (
131127 BUILTIN_MODULAR_TEMPLATE_DIR if generate_routers else BUILTIN_TEMPLATE_DIR
132128 )
133- if enum_field_as_literal :
134- parser = OpenAPIParser (input_text , enum_field_as_literal = enum_field_as_literal ) # type: ignore[arg-type]
135- else :
136- parser = OpenAPIParser (input_text )
129+
130+ data_model_types = get_data_model_types (output_model_type , python_version )
131+
132+ parser = OpenAPIParser (input_text ,
133+ enum_field_as_literal = enum_field_as_literal ,
134+ data_model_type = data_model_types .data_model ,
135+ data_model_root_type = data_model_types .root_model ,
136+ data_model_field_type = data_model_types .field_model ,
137+ data_type_manager_type = data_model_types .data_type_manager ,
138+ dump_resolve_reference_action = data_model_types .dump_resolve_reference_action ,
139+ )
140+
137141 with chdir (output_dir ):
138142 models = parser .parse ()
139143 output = output_dir / model_path
@@ -153,7 +157,7 @@ def generate_code(
153157 )
154158
155159 results : Dict [Path , str ] = {}
156- code_formatter = CodeFormatter (PythonVersion . PY_38 , Path ().resolve ())
160+ code_formatter = CodeFormatter (python_version , Path ().resolve ())
157161
158162 template_vars : Dict [str , object ] = {"info" : parser .parse_info ()}
159163 visitors : List [Visitor ] = []
0 commit comments