diff --git a/fastapi_code_generator/__main__.py b/fastapi_code_generator/__main__.py index 4d3e661..64c1191 100644 --- a/fastapi_code_generator/__main__.py +++ b/fastapi_code_generator/__main__.py @@ -7,7 +7,6 @@ import typer from datamodel_code_generator import DataModelType, LiteralType, PythonVersion, chdir from datamodel_code_generator.format import CodeFormatter -from datamodel_code_generator.imports import Import, Imports from datamodel_code_generator.model import get_data_model_types from datamodel_code_generator.reference import Reference from datamodel_code_generator.types import DataType @@ -65,6 +64,11 @@ def main( python_version: PythonVersion = typer.Option( PythonVersion.PY_39.value, "--python-version", "-p" ), + strict_nullable: bool = typer.Option( + False, + "--strict-nullable", + help="Strictly follow nullable attribute in OpenAPI spec", + ), ) -> None: input_name: str = input_file input_text: str @@ -89,6 +93,7 @@ def main( specify_tags=specify_tags, output_model_type=output_model_type, python_version=python_version, + strict_nullable=strict_nullable, ) @@ -117,6 +122,7 @@ def generate_code( specify_tags: Optional[str] = None, output_model_type: DataModelType = DataModelType.PydanticBaseModel, python_version: PythonVersion = PythonVersion.PY_39, + strict_nullable: bool = False, ) -> None: if not model_path: model_path = MODEL_PATH @@ -142,6 +148,7 @@ def generate_code( dump_resolve_reference_action=data_model_types.dump_resolve_reference_action, custom_template_dir=model_template_dir, target_python_version=python_version, + strict_nullable=strict_nullable, ) with chdir(output_dir): diff --git a/tests/data/expected/openapi/default_template/nullable_test/main.py b/tests/data/expected/openapi/default_template/nullable_test/main.py new file mode 100644 index 0000000..6cb589a --- /dev/null +++ b/tests/data/expected/openapi/default_template/nullable_test/main.py @@ -0,0 +1,24 @@ +# generated by fastapi-codegen: +# filename: nullable_test.yaml +# timestamp: 2020-06-19T00:00:00+00:00 + +from __future__ import annotations + +from fastapi import FastAPI + +from .models import User + +app = FastAPI( + version='1.0.0', + title='Nullable Test API', + description='API for testing nullable field behavior', + servers=[{'url': 'http://api.example.com/v1'}], +) + + +@app.get('/users', response_model=User) +def get_user_details() -> User: + """ + Get user details + """ + pass diff --git a/tests/data/expected/openapi/default_template/nullable_test/models.py b/tests/data/expected/openapi/default_template/nullable_test/models.py new file mode 100644 index 0000000..bd232f2 --- /dev/null +++ b/tests/data/expected/openapi/default_template/nullable_test/models.py @@ -0,0 +1,21 @@ +# generated by fastapi-codegen: +# filename: nullable_test.yaml +# timestamp: 2020-06-19T00:00:00+00:00 + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel, Field + + +class User(BaseModel): + id: int + username: str + email: Optional[str] = Field( + None, description="User's email address (explicitly nullable)" + ) + phone: str = Field(..., description="User's phone number (explicitly non-nullable)") + nickname: Optional[str] = Field( + None, description="User's nickname (implicitly nullable)" + ) diff --git a/tests/data/expected/openapi/default_template/nullable_test_strict/main.py b/tests/data/expected/openapi/default_template/nullable_test_strict/main.py new file mode 100644 index 0000000..6cb589a --- /dev/null +++ b/tests/data/expected/openapi/default_template/nullable_test_strict/main.py @@ -0,0 +1,24 @@ +# generated by fastapi-codegen: +# filename: nullable_test.yaml +# timestamp: 2020-06-19T00:00:00+00:00 + +from __future__ import annotations + +from fastapi import FastAPI + +from .models import User + +app = FastAPI( + version='1.0.0', + title='Nullable Test API', + description='API for testing nullable field behavior', + servers=[{'url': 'http://api.example.com/v1'}], +) + + +@app.get('/users', response_model=User) +def get_user_details() -> User: + """ + Get user details + """ + pass diff --git a/tests/data/expected/openapi/default_template/nullable_test_strict/models.py b/tests/data/expected/openapi/default_template/nullable_test_strict/models.py new file mode 100644 index 0000000..bd232f2 --- /dev/null +++ b/tests/data/expected/openapi/default_template/nullable_test_strict/models.py @@ -0,0 +1,21 @@ +# generated by fastapi-codegen: +# filename: nullable_test.yaml +# timestamp: 2020-06-19T00:00:00+00:00 + +from __future__ import annotations + +from typing import Optional + +from pydantic import BaseModel, Field + + +class User(BaseModel): + id: int + username: str + email: Optional[str] = Field( + None, description="User's email address (explicitly nullable)" + ) + phone: str = Field(..., description="User's phone number (explicitly non-nullable)") + nickname: Optional[str] = Field( + None, description="User's nickname (implicitly nullable)" + ) diff --git a/tests/data/openapi/default_template/nullable_test.yaml b/tests/data/openapi/default_template/nullable_test.yaml new file mode 100644 index 0000000..2b24fc8 --- /dev/null +++ b/tests/data/openapi/default_template/nullable_test.yaml @@ -0,0 +1,44 @@ +openapi: "3.0.0" +info: + version: 1.0.0 + title: Nullable Test API + description: API for testing nullable field behavior +servers: + - url: http://api.example.com/v1 +paths: + /users: + get: + summary: Get user details + operationId: getUserDetails + responses: + "200": + description: User details + content: + application/json: + schema: + $ref: "#/components/schemas/User" +components: + schemas: + User: + type: object + required: + - id + - phone + - username + properties: + id: + type: integer + format: int64 + username: + type: string + email: + type: string + nullable: true + description: User's email address (explicitly nullable) + phone: + type: string + nullable: false + description: User's phone number (explicitly non-nullable) + nickname: + type: string + description: User's nickname (implicitly nullable) diff --git a/tests/test_generate.py b/tests/test_generate.py index 0ef13f3..6699cc9 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -203,3 +203,26 @@ def test_generate_modify_specific_routers(oas_file): assert output_inner.read_text() == expected_inner.read_text() else: assert output_file.read_text() == expected_file.read_text(), oas_file + + +@freeze_time("2020-06-19") +def test_generate_nullable_strict(): + oas_file = DATA_DIR / OPEN_API_DEFAULT_TEMPLATE_DIR_NAME / 'nullable_test.yaml' + with TemporaryDirectory() as tmp_dir: + output_dir = Path(tmp_dir) / (oas_file.stem + '_strict') + generate_code( + input_name=oas_file.name, + input_text=oas_file.read_text(), + encoding=ENCODING, + output_dir=output_dir, + template_dir=None, + strict_nullable=True, + ) + expected_dir = ( + EXPECTED_DIR / OPEN_API_DEFAULT_TEMPLATE_DIR_NAME / "nullable_test_strict" + ) + output_files = sorted(list(output_dir.glob("**/*.py"))) + expected_files = sorted(list(expected_dir.glob("**/*.py"))) + assert [f.name for f in output_files] == [f.name for f in expected_files] + for output_file, expected_file in zip(output_files, expected_files): + assert output_file.read_text() == expected_file.read_text()