Skip to content

Commit 111014b

Browse files
committed
Satisfy mypy with either version of Pydantic
1 parent e4a5e14 commit 111014b

26 files changed

+335
-195
lines changed

openapi_pydantic/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,6 @@
3333
from .v3 import ServerVariable as ServerVariable
3434
from .v3 import Tag as Tag
3535
from .v3 import parse_obj as parse_obj
36+
from .v3 import schema_validate as schema_validate
3637

3738
logging.getLogger(__name__).addHandler(logging.NullHandler())

openapi_pydantic/compat.py

Lines changed: 84 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Compatibility layer to make this package usable with Pydantic 1 or 2"""
22

3+
from typing import TYPE_CHECKING
4+
35
from pydantic.version import VERSION as PYDANTIC_VERSION
46

57
PYDANTIC_MAJOR_VERSION = int(PYDANTIC_VERSION.split(".", 1)[0])
@@ -9,45 +11,107 @@
911
else:
1012
PYDANTIC_V2 = False
1113

12-
if PYDANTIC_V2:
13-
from typing import Literal
14+
if TYPE_CHECKING:
15+
# Provide stubs for either version of Pydantic
16+
17+
from enum import Enum
18+
from typing import Any, Literal, Type, TypedDict
19+
20+
from pydantic import BaseModel
21+
from pydantic import ConfigDict as PydanticConfigDict
22+
23+
def ConfigDict(
24+
extra: Literal["allow", "ignore", "forbid"] = "allow",
25+
json_schema_extra: dict[str, Any] | None = None,
26+
populate_by_name: bool = True,
27+
) -> PydanticConfigDict:
28+
...
29+
30+
class Extra(Enum):
31+
allow = "allow"
32+
ignore = "ignore"
33+
forbid = "forbid"
34+
35+
class RootModel(BaseModel):
36+
...
37+
38+
JsonSchemaMode = Literal["validation", "serialization"]
39+
40+
def models_json_schema(
41+
models: list[tuple[Type[BaseModel], JsonSchemaMode]],
42+
*,
43+
by_alias: bool = True,
44+
ref_template: str = "#/$defs/{model}",
45+
) -> tuple[dict, dict[str, Any]]:
46+
...
47+
48+
def v1_schema(
49+
models: list[Type[BaseModel]],
50+
*,
51+
by_alias: bool = True,
52+
ref_prefix: str = "#/$defs",
53+
) -> dict[str, Any]:
54+
...
55+
56+
DEFS_KEY = "$defs"
1457

15-
from pydantic import ConfigDict
16-
from pydantic.json_schema import JsonSchemaMode, models_json_schema # type: ignore
58+
class MinLengthArg(TypedDict):
59+
pass
60+
61+
def min_length_arg(min_length: int) -> MinLengthArg:
62+
...
63+
64+
elif PYDANTIC_V2:
65+
from typing import Literal, TypedDict
66+
67+
from pydantic import ConfigDict, RootModel
68+
from pydantic.json_schema import JsonSchemaMode, models_json_schema
1769

1870
# Pydantic 2 renders JSON schemas using the keyword "$defs"
1971
DEFS_KEY = "$defs"
2072

21-
# Add V1 stubs to this module, but hide them from typing
22-
globals().update(
23-
{
24-
"Extra": None,
25-
"v1_schema": None,
26-
}
27-
)
73+
class MinLengthArg(TypedDict):
74+
min_length: int
75+
76+
def min_length_arg(min_length: int) -> MinLengthArg:
77+
return {"min_length": min_length}
78+
79+
# Create V1 stubs
80+
Extra = None
81+
v1_schema = None
82+
2883

2984
else:
85+
from typing import TypedDict
86+
3087
from pydantic import Extra
3188
from pydantic.schema import schema as v1_schema
3289

3390
# Pydantic 1 renders JSON schemas using the keyword "definitions"
3491
DEFS_KEY = "definitions"
3592

36-
# Add V2 stubs to this module, but hide them from typing
37-
globals().update(
38-
{
39-
"ConfigDict": None,
40-
"Literal": None,
41-
"models_json_schema": None,
42-
"JsonSchemaMode": None,
43-
}
44-
)
93+
class MinLengthArg(TypedDict):
94+
min_items: int
95+
96+
def min_length_arg(min_length: int) -> MinLengthArg:
97+
return {"min_items": min_length}
98+
99+
# Create V2 stubs
100+
ConfigDict = None
101+
Literal = None
102+
models_json_schema = None
103+
JsonSchemaMode = None
104+
RootModel = None
105+
45106

46107
__all__ = [
47108
"Literal",
48109
"ConfigDict",
49110
"JsonSchemaMode",
50111
"models_json_schema",
112+
"RootModel",
51113
"Extra",
52114
"v1_schema",
115+
"DEFS_KEY",
116+
"min_length_arg",
53117
]

openapi_pydantic/util.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Any, Generic, List, Optional, Set, Type, TypeVar
2+
from typing import Any, Generic, List, Optional, Set, Type, TypeVar, cast
33

44
from pydantic import BaseModel
55

@@ -11,7 +11,7 @@
1111
v1_schema,
1212
)
1313

14-
from . import Components, OpenAPI, Reference, Schema
14+
from . import Components, OpenAPI, Reference, Schema, schema_validate
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -32,13 +32,16 @@ def get_mode(
3232
) -> JsonSchemaMode:
3333
"""Get the JSON schema mode for a model class.
3434
35-
The mode can be either "serialization" or "validation". In validation mode,
35+
The mode can be either "validation" or "serialization". In validation mode,
3636
computed fields are dropped and optional fields remain optional. In
3737
serialization mode, computed and optional fields are required.
3838
"""
3939
if not hasattr(cls, "model_config"):
4040
return default
41-
return cls.model_config.get("json_schema_mode", default)
41+
mode = cls.model_config.get("json_schema_mode", default)
42+
if mode not in ("validation", "serialization"):
43+
raise ValueError(f"invalid json_schema_mode: {mode}")
44+
return cast(JsonSchemaMode, mode)
4245

4346

4447
def construct_open_api_with_schema_class(
@@ -62,10 +65,8 @@ def construct_open_api_with_schema_class(
6265
If there is no update in "#/components/schemas" values, the original
6366
`open_api` will be returned.
6467
"""
65-
if PYDANTIC_V2:
66-
new_open_api = open_api.model_copy(deep=True)
67-
else:
68-
new_open_api = open_api.copy(deep=True)
68+
copy_func = getattr(open_api, "model_copy" if PYDANTIC_V2 else "copy")
69+
new_open_api: OpenAPI = copy_func(deep=True)
6970

7071
if scan_for_pydantic_schema_reference:
7172
extracted_schema_classes = _handle_pydantic_schema(new_open_api)
@@ -80,7 +81,7 @@ def construct_open_api_with_schema_class(
8081
return open_api
8182

8283
schema_classes.sort(key=lambda x: x.__name__)
83-
logger.debug(f"schema_classes{schema_classes}")
84+
logger.debug("schema_classes: %s", schema_classes)
8485

8586
# update new_open_api with new #/components/schemas
8687
if PYDANTIC_V2:
@@ -94,7 +95,6 @@ def construct_open_api_with_schema_class(
9495
schema_classes, by_alias=by_alias, ref_prefix=ref_prefix
9596
)
9697

97-
schema_validate = Schema.model_validate if PYDANTIC_V2 else Schema.parse_obj
9898
if not new_open_api.components:
9999
new_open_api.components = Components()
100100
if new_open_api.components.schemas:
@@ -111,6 +111,8 @@ def construct_open_api_with_schema_class(
111111
}
112112
)
113113
else:
114+
for key, schema_dict in schema_definitions[DEFS_KEY].items():
115+
schema_validate(schema_dict)
114116
new_open_api.components.schemas = {
115117
key: schema_validate(schema_dict)
116118
for key, schema_dict in schema_definitions[DEFS_KEY].items()
@@ -136,13 +138,13 @@ def _handle_pydantic_schema(open_api: OpenAPI) -> List[Type[BaseModel]]:
136138

137139
def _traverse(obj: Any) -> None:
138140
if isinstance(obj, BaseModel):
139-
fields = obj.model_fields_set if PYDANTIC_V2 else obj.__fields_set__
141+
fields = getattr(
142+
obj, "model_fields_set" if PYDANTIC_V2 else "__fields_set__"
143+
)
140144
for field in fields:
141145
child_obj = obj.__getattribute__(field)
142146
if isinstance(child_obj, PydanticSchema):
143-
logger.debug(
144-
f"PydanticSchema found in {obj.__repr_name__()}: {child_obj}"
145-
)
147+
logger.debug("PydanticSchema found in %s: %s", obj, child_obj)
146148
obj.__setattr__(field, _construct_ref_obj(child_obj))
147149
pydantic_types.add(child_obj.schema_class)
148150
else:
@@ -169,6 +171,6 @@ def _traverse(obj: Any) -> None:
169171

170172

171173
def _construct_ref_obj(pydantic_schema: PydanticSchema[PydanticType]) -> Reference:
172-
ref_obj = Reference(ref=ref_prefix + pydantic_schema.schema_class.__name__)
174+
ref_obj = Reference(**{"$ref": ref_prefix + pydantic_schema.schema_class.__name__})
173175
logger.debug(f"ref_obj={ref_obj}")
174176
return ref_obj

openapi_pydantic/v3/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@
3131
from .v3_1_0 import Server as Server
3232
from .v3_1_0 import ServerVariable as ServerVariable
3333
from .v3_1_0 import Tag as Tag
34+
from .v3_1_0 import schema_validate as schema_validate

openapi_pydantic/v3/parser.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Union
1+
from typing import TYPE_CHECKING, Any, Union
22

33
from pydantic import BaseModel, Field
44

@@ -9,21 +9,25 @@
99

1010
OpenAPIv3 = Union[OpenAPIv3_1, OpenAPIv3_0]
1111

12-
if PYDANTIC_V2:
12+
if TYPE_CHECKING:
13+
14+
def parse_obj(data: Any) -> OpenAPIv3:
15+
"""Parse a raw object into an OpenAPI model with version inference."""
16+
...
17+
18+
elif PYDANTIC_V2:
1319
from pydantic import RootModel
1420

15-
class _OpenAPIV2(RootModel):
21+
class _OpenAPI(RootModel):
1622
root: OpenAPIv3 = Field(discriminator="openapi")
1723

24+
def parse_obj(data: Any) -> OpenAPIv3:
25+
return _OpenAPI.model_validate(data).root
26+
1827
else:
1928

20-
class _OpenAPIV1(BaseModel):
29+
class _OpenAPI(BaseModel):
2130
__root__: OpenAPIv3 = Field(discriminator="openapi")
2231

23-
24-
def parse_obj(data: Any) -> OpenAPIv3:
25-
"""Parse a raw object into an OpenAPI model with version inference."""
26-
if PYDANTIC_V2:
27-
return _OpenAPIV2.model_validate(data).root
28-
else:
29-
return _OpenAPIV1.parse_obj(data).__root__
32+
def parse_obj(data: Any) -> OpenAPIv3:
33+
return _OpenAPI.parse_obj(data).__root__

openapi_pydantic/v3/v3_0_3/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#table-of-contents
77
"""
88

9+
from typing import TYPE_CHECKING
10+
911
from openapi_pydantic.compat import PYDANTIC_V2
1012

1113
from .callback import Callback as Callback
@@ -34,20 +36,24 @@
3436
from .response import Response as Response
3537
from .responses import Responses as Responses
3638
from .schema import Schema as Schema
39+
from .schema import schema_validate as schema_validate
3740
from .security_requirement import SecurityRequirement as SecurityRequirement
3841
from .security_scheme import SecurityScheme as SecurityScheme
3942
from .server import Server as Server
4043
from .server_variable import ServerVariable as ServerVariable
4144
from .tag import Tag as Tag
4245
from .xml import XML as XML
4346

44-
# resolve forward references
45-
if PYDANTIC_V2:
47+
if TYPE_CHECKING:
48+
pass
49+
elif PYDANTIC_V2:
50+
# resolve forward references
4651
Encoding.model_rebuild()
4752
OpenAPI.model_rebuild()
4853
Components.model_rebuild()
4954
Operation.model_rebuild()
5055
else:
56+
# resolve forward references
5157
Encoding.update_forward_refs(Header=Header)
5258
Schema.update_forward_refs()
5359
Operation.update_forward_refs(PathItem=PathItem)
Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import TYPE_CHECKING
2+
13
from pydantic import Field
24

35
from openapi_pydantic.compat import PYDANTIC_V2, ConfigDict, Extra, Literal
@@ -12,10 +14,9 @@
1214
]
1315

1416

15-
if PYDANTIC_V2:
16-
LiteralEmptyString = Literal[""]
17+
if TYPE_CHECKING:
1718

18-
class HeaderV2(Parameter):
19+
class Header(Parameter):
1920
"""
2021
The Header Object follows the structure of the
2122
[Parameter Object](#parameterObject) with the following changes:
@@ -27,6 +28,15 @@ class HeaderV2(Parameter):
2728
to a location of `header` (for example, [`style`](#parameterStyle)).
2829
"""
2930

31+
name: str = Field(default="")
32+
param_in: ParameterLocation = Field(
33+
default=ParameterLocation.HEADER, alias="in"
34+
)
35+
36+
elif PYDANTIC_V2:
37+
LiteralEmptyString = Literal[""]
38+
39+
class Header(Parameter):
3040
name: LiteralEmptyString = Field(default="")
3141
param_in: Literal[ParameterLocation.HEADER] = Field(
3242
default=ParameterLocation.HEADER, alias="in"
@@ -38,23 +48,9 @@ class HeaderV2(Parameter):
3848
json_schema_extra={"examples": _examples},
3949
)
4050

41-
Header = HeaderV2
42-
Header.__name__ = "Header"
43-
4451
else:
4552

46-
class HeaderV1(Parameter):
47-
"""
48-
The Header Object follows the structure of the
49-
[Parameter Object](#parameterObject) with the following changes:
50-
51-
1. `name` MUST NOT be specified, it is given in the corresponding
52-
`headers` map.
53-
2. `in` MUST NOT be specified, it is implicitly in `header`.
54-
3. All traits that are affected by the location MUST be applicable
55-
to a location of `header` (for example, [`style`](#parameterStyle)).
56-
"""
57-
53+
class Header(Parameter):
5854
name: str = Field(default="", const=True)
5955
param_in: ParameterLocation = Field(
6056
default=ParameterLocation.HEADER, const=True, alias="in"
@@ -64,6 +60,3 @@ class Config:
6460
extra = Extra.allow
6561
allow_population_by_field_name = True
6662
schema_extra = {"examples": _examples}
67-
68-
Header = HeaderV1
69-
Header.__name__ = "Header"

0 commit comments

Comments
 (0)