Skip to content

Commit 69c1c2b

Browse files
committed
feat: fully support pydantic v2 type handling
1 parent 5a8211e commit 69c1c2b

File tree

10 files changed

+487
-31
lines changed

10 files changed

+487
-31
lines changed

src/openapi_python_generator/language_converters/python/common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import keyword
22
import re
33
from typing import Optional
4+
from openapi_python_generator.common import PydanticVersion
45

56

67
_use_orjson: bool = False
8+
_pydantic_version: PydanticVersion = PydanticVersion.V2
79
_custom_template_path: str = None
810
_symbol_ascii_strip_re = re.compile(r"[^A-Za-z0-9_]")
911

@@ -16,6 +18,13 @@ def set_use_orjson(value: bool) -> None:
1618
global _use_orjson
1719
_use_orjson = value
1820

21+
def set_pydantic_version(value: PydanticVersion) -> None:
22+
"""
23+
Set the value of the global variable
24+
:param value: value of the variable
25+
"""
26+
global _pydantic_version
27+
_pydantic_version = value
1928

2029
def get_use_orjson() -> bool:
2130
"""
@@ -25,6 +34,13 @@ def get_use_orjson() -> bool:
2534
global _use_orjson
2635
return _use_orjson
2736

37+
def get_pydantic_version() -> PydanticVersion:
38+
"""
39+
Get the value of the global variable _pydantic_version.
40+
:return: value of the variable
41+
"""
42+
global _pydantic_version
43+
return _pydantic_version
2844

2945
def set_custom_template_path(value: Optional[str]) -> None:
3046
"""

src/openapi_python_generator/language_converters/python/generator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def generator(
3131

3232
common.set_use_orjson(use_orjson)
3333
common.set_custom_template_path(custom_template_path)
34+
common.set_pydantic_version(pydantic_version)
3435

3536
if data.components is not None:
3637
models = generate_models(data.components, pydantic_version)

src/openapi_python_generator/language_converters/python/model_generator.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,11 @@ def type_converter( # noqa: C901
121121
# With custom string format fields, in order to cast these to strict types (e.g. date, datetime, UUID)
122122
# orjson is required for JSON serialiation.
123123
elif (
124-
schema.type == "string"
125-
and schema.schema_format is not None
126-
and schema.schema_format.startswith("uuid")
127-
and common.get_use_orjson()
124+
schema.type == "string"
125+
and schema.schema_format is not None
126+
and schema.schema_format.startswith("uuid")
127+
# orjson and pydantic v2 both support UUID
128+
and (common.get_use_orjson() or common.get_pydantic_version() == PydanticVersion.V2)
128129
):
129130
if len(schema.schema_format) > 4 and schema.schema_format[4].isnumeric():
130131
uuid_type = schema.schema_format.upper()
@@ -133,13 +134,31 @@ def type_converter( # noqa: C901
133134
else:
134135
converted_type = pre_type + "UUID" + post_type
135136
import_types = ["from uuid import UUID"]
136-
elif schema.type == "string" and schema.schema_format == "date-time" and common.get_use_orjson():
137+
elif (
138+
schema.type == "string"
139+
and schema.schema_format == "date-time"
140+
# orjson and pydantic v2 both support datetime
141+
and (common.get_use_orjson() or common.get_pydantic_version() == PydanticVersion.V2)
142+
):
137143
converted_type = pre_type + "datetime" + post_type
138144
import_types = ["from datetime import datetime"]
139-
elif schema.type == "string" and schema.schema_format == "date" and common.get_use_orjson():
145+
elif (
146+
schema.type == "string"
147+
and schema.schema_format == "date"
148+
# orjson and pydantic v2 both support date
149+
and (common.get_use_orjson() or common.get_pydantic_version() == PydanticVersion.V2)
150+
):
140151
converted_type = pre_type + "date" + post_type
141152
import_types = ["from datetime import date"]
142-
elif schema.type == "string" and schema.schema_format == "decimal" and common.get_use_orjson():
153+
elif (
154+
schema.type == "string"
155+
and schema.schema_format == "decimal"
156+
# orjson does not support Decimal
157+
# See https://github.com/ijl/orjson/issues/444
158+
and not common.get_use_orjson()
159+
# pydantic v2 supports Decimal
160+
and common.get_pydantic_version() == PydanticVersion.V2
161+
):
143162
converted_type = pre_type + "Decimal" + post_type
144163
import_types = ["from decimal import Decimal"]
145164
elif schema.type == "string":

src/openapi_python_generator/language_converters/python/service_generator.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,29 @@
2626

2727
HTTP_OPERATIONS = ["get", "post", "put", "delete", "options", "head", "patch", "trace"]
2828

29+
def _generate_body_dump_expression(data = "data") -> str:
30+
"""
31+
Generate expression for dumping abstract body as a dictionary.
32+
"""
33+
34+
# Use old v1 method for pydantic
35+
if common.get_pydantic_version() == common.PydanticVersion.V1:
36+
return f"{data}.dict()"
37+
38+
# Dump model but allow orjson to serialise (fastest)
39+
if common.get_use_orjson():
40+
return f"{data}.model_dump()"
41+
42+
# rely on pydantic v2 to serialise (slowest, but best compatibility)
43+
return f"{data}.model_dump_json()"
44+
2945

3046
def generate_body_param(operation: Operation) -> Union[str, None]:
3147
if operation.requestBody is None:
3248
return None
3349
else:
3450
if isinstance(operation.requestBody, Reference):
35-
return "data.dict()"
51+
return _generate_body_dump_expression("data")
3652

3753
if operation.requestBody.content is None:
3854
return None # pragma: no cover
@@ -46,11 +62,12 @@ def generate_body_param(operation: Operation) -> Union[str, None]:
4662
return None # pragma: no cover
4763

4864
if isinstance(media_type.media_type_schema, Reference):
49-
return "data.dict()"
65+
return _generate_body_dump_expression("data")
5066
elif isinstance(media_type.media_type_schema, Schema):
5167
schema = media_type.media_type_schema
5268
if schema.type == "array":
53-
return "[i.dict() for i in data]"
69+
expression = _generate_body_dump_expression("i")
70+
return f"[{expression} for i in data]"
5471
elif schema.type == "object":
5572
return "data"
5673
else:

tests/build_test_api/api.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from datetime import datetime
1+
from datetime import datetime, date
22
from typing import List
33
from typing import Optional
44

@@ -20,17 +20,19 @@ class User(BaseModel):
2020
username: str
2121
email: str
2222
password: str
23-
is_active: Optional[bool]
23+
is_active: Optional[bool] = None
24+
created_at: Optional[datetime] = None
25+
birthdate: Optional[date] = None
2426

2527

2628
class Team(BaseModel):
2729
id: int
2830
name: str
2931
description: str
30-
is_active: Optional[bool]
31-
created_at: Optional[datetime]
32-
updated_at: Optional[datetime]
33-
users: Optional[List[User]]
32+
is_active: Optional[bool] = None
33+
created_at: Optional[datetime] = None
34+
updated_at: Optional[datetime] = None
35+
users: Optional[List[User]] = None
3436

3537

3638
@app.get("/", response_model=RootResponse, tags=["general"])

tests/conftest.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
from typing import Dict
55
from typing import Generator
6+
from openapi_python_generator.language_converters.python import common
67

78
import pytest
89
from openapi_pydantic.v3.v3_0 import OpenAPI
@@ -32,3 +33,52 @@ def model_data_with_cleanup_fixture(model_data) -> OpenAPI: # type: ignore
3233
if test_result_path.exists():
3334
# delete folder and all subfolders
3435
shutil.rmtree(test_result_path)
36+
37+
38+
@pytest.fixture
39+
def with_orjson_enabled():
40+
"""
41+
Fixture to enable orjson for the duration of the test
42+
"""
43+
orjson_usage = common.get_use_orjson()
44+
common.set_use_orjson(True)
45+
try:
46+
yield
47+
finally:
48+
common.set_use_orjson(orjson_usage)
49+
50+
@pytest.fixture
51+
def with_orjson_disabled():
52+
"""
53+
Fixture to enable orjson for the duration of the test
54+
"""
55+
orjson_usage = common.get_use_orjson()
56+
common.set_use_orjson(False)
57+
try:
58+
yield
59+
finally:
60+
common.set_use_orjson(orjson_usage)
61+
62+
@pytest.fixture
63+
def with_pydantic_v1():
64+
"""
65+
Fixture to set pydantic to v1 for the duration of the test
66+
"""
67+
pydantic_version = common.get_pydantic_version()
68+
common.set_pydantic_version(common.PydanticVersion.V1)
69+
try:
70+
yield
71+
finally:
72+
common.set_pydantic_version(pydantic_version)
73+
74+
@pytest.fixture
75+
def with_pydantic_v2():
76+
"""
77+
Fixture to set pydantic to v2 for the duration of the test
78+
"""
79+
pydantic_version = common.get_pydantic_version()
80+
common.set_pydantic_version(common.PydanticVersion.V2)
81+
try:
82+
yield
83+
finally:
84+
common.set_pydantic_version(pydantic_version)

tests/test_data/test_api.json

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -532,11 +532,6 @@
532532
"title": "Birthdate",
533533
"type": "string",
534534
"format": "date"
535-
},
536-
"position": {
537-
"title": "Position",
538-
"type": "string",
539-
"format": "decimal"
540535
}
541536
}
542537
},

tests/test_generated_code.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def test_set_auth_token():
6464
],
6565
)
6666
@respx.mock
67-
def test_generate_code(model_data_with_cleanup, library, use_orjson, custom_ip):
67+
def test_generate_code(model_data_with_cleanup, library, use_orjson, custom_ip, with_pydantic_v2):
6868
generate_data(test_data_path, test_result_path, library, use_orjson=use_orjson)
6969
result = generator(model_data_with_cleanup, library_config_dict[library])
7070

@@ -356,8 +356,8 @@ def test_generate_code(model_data_with_cleanup, library, use_orjson, custom_ip):
356356
name="team1",
357357
description="team1",
358358
is_active=True,
359-
created_at="",
360-
updated_at="",
359+
created_at=None,
360+
updated_at=None,
361361
)
362362

363363
exec_code_base = f"from .test_result.services.general_service import *\nfrom datetime import datetime\nresp_result = create_team_teams_post(Team(**{data}), passed_api_config)"

0 commit comments

Comments
 (0)