Skip to content

Commit 757003a

Browse files
committed
fix: additional type fixes
1 parent f6d34ca commit 757003a

File tree

7 files changed

+45
-36
lines changed

7 files changed

+45
-36
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ plugins = [
110110
"pydantic.mypy"
111111
]
112112
warn_unused_ignores = true
113+
exclude = [
114+
"tests/",
115+
"conftest.py",
116+
]
113117

114118
[tool.pydantic-mypy]
115119
init_forbid_extra = true

scim2_models/base.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -161,27 +161,30 @@ def normalize_value(value: Any) -> Any:
161161
return value
162162

163163
normalized_value = normalize_value(value)
164-
return handler(normalized_value)
164+
obj = handler(normalized_value)
165+
assert isinstance(obj, cls)
166+
return obj
165167

166168
@model_validator(mode="wrap")
167169
@classmethod
168170
def check_response_attributes_returnability(
169171
cls, value: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
170172
) -> Self:
171173
"""Check that the fields returnability is expected according to the responses validation context, as defined in :rfc:`RFC7643 §7 <7653#section-7>`."""
172-
value = handler(value)
174+
obj = handler(value)
175+
assert isinstance(obj, cls)
173176

174177
if (
175178
not info.context
176179
or not info.context.get("scim")
177180
or not Context.is_response(info.context["scim"])
178181
):
179-
return value
182+
return obj
180183

181184
for field_name in cls.model_fields:
182185
returnability = cls.get_field_annotation(field_name, Returned)
183186

184-
if returnability == Returned.always and getattr(value, field_name) is None:
187+
if returnability == Returned.always and getattr(obj, field_name) is None:
185188
raise PydanticCustomError(
186189
"returned_error",
187190
"Field '{field_name}' has returnability 'always' but value is missing or null",
@@ -190,10 +193,7 @@ def check_response_attributes_returnability(
190193
},
191194
)
192195

193-
if (
194-
returnability == Returned.never
195-
and getattr(value, field_name) is not None
196-
):
196+
if returnability == Returned.never and getattr(obj, field_name) is not None:
197197
raise PydanticCustomError(
198198
"returned_error",
199199
"Field '{field_name}' has returnability 'never' but value is set",
@@ -202,15 +202,16 @@ def check_response_attributes_returnability(
202202
},
203203
)
204204

205-
return value
205+
return obj
206206

207207
@model_validator(mode="wrap")
208208
@classmethod
209209
def check_response_attributes_necessity(
210210
cls, value: Any, handler: ValidatorFunctionWrapHandler, info: ValidationInfo
211211
) -> Self:
212212
"""Check that the required attributes are present in creations and replacement requests."""
213-
value = handler(value)
213+
obj = handler(value)
214+
assert isinstance(obj, cls)
214215

215216
if (
216217
not info.context
@@ -221,12 +222,12 @@ def check_response_attributes_necessity(
221222
Context.RESOURCE_REPLACEMENT_REQUEST,
222223
)
223224
):
224-
return value
225+
return obj
225226

226227
for field_name in cls.model_fields:
227228
necessity = cls.get_field_annotation(field_name, Required)
228229

229-
if necessity == Required.true and getattr(value, field_name) is None:
230+
if necessity == Required.true and getattr(obj, field_name) is None:
230231
raise PydanticCustomError(
231232
"required_error",
232233
"Field '{field_name}' is required but value is missing or null",
@@ -235,7 +236,7 @@ def check_response_attributes_necessity(
235236
},
236237
)
237238

238-
return value
239+
return obj
239240

240241
@model_validator(mode="wrap")
241242
@classmethod
@@ -245,7 +246,8 @@ def check_replacement_request_mutability(
245246
"""Check if 'immutable' attributes have been mutated in replacement requests."""
246247
from scim2_models.rfc7643.resource import Resource
247248

248-
value = handler(value)
249+
obj = handler(value)
250+
assert isinstance(obj, cls)
249251

250252
context = info.context.get("scim") if info.context else None
251253
original = info.context.get("original") if info.context else None
@@ -254,8 +256,8 @@ def check_replacement_request_mutability(
254256
and issubclass(cls, Resource)
255257
and original is not None
256258
):
257-
cls.check_mutability_issues(original, value)
258-
return value
259+
cls.check_mutability_issues(original, obj)
260+
return obj
259261

260262
@classmethod
261263
def check_mutability_issues(
@@ -403,7 +405,7 @@ def model_serializer_exclude_none(
403405
@classmethod
404406
def model_validate(
405407
cls,
406-
*args,
408+
*args: Any,
407409
scim_ctx: Optional[Context] = Context.DEFAULT,
408410
original: Optional["BaseModel"] = None,
409411
**kwargs: Any,
@@ -443,6 +445,3 @@ def get_attribute_urn(self, field_name: str) -> str:
443445
else f"{main_schema}:{alias}"
444446
)
445447
return full_urn
446-
447-
448-
BaseModelType: type = type(BaseModel)

scim2_models/reference.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,7 @@ def __get_pydantic_core_schema__(
5454

5555
@classmethod
5656
def _validate(cls, input_value: Any, /) -> str:
57-
if isinstance(input_value, cls):
58-
return str(input_value)
59-
return input_value
57+
return str(input_value)
6058

6159
@classmethod
6260
def get_types(cls, type_annotation: Any) -> list[str]:
@@ -79,6 +77,6 @@ def serialize_ref_type(ref_type: Any) -> str:
7977
elif ref_type == ExternalReference:
8078
return "external"
8179

82-
return get_args(ref_type)[0]
80+
return str(get_args(ref_type)[0])
8381

8482
return list(map(serialize_ref_type, types))

scim2_models/rfc7643/resource.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Optional
77
from typing import TypeVar
88
from typing import Union
9+
from typing import cast
910
from typing import get_args
1011
from typing import get_origin
1112

@@ -103,7 +104,7 @@ def from_schema(cls, schema: "Schema") -> type["Extension"]:
103104

104105
AnyExtension = TypeVar("AnyExtension", bound="Extension")
105106

106-
_PARAMETERIZED_CLASSES: dict[tuple[type, tuple], type] = {}
107+
_PARAMETERIZED_CLASSES: dict[tuple[type, tuple[Any, ...]], type] = {}
107108

108109

109110
def extension_serializer(
@@ -154,9 +155,11 @@ def __class_getitem__(cls, item: Any) -> type["Resource"]:
154155

155156
extensions = get_args(item) if get_origin(item) in UNION_TYPES else [item]
156157

157-
# Skip TypeVar parameters (used for generic class definitions)
158+
# Skip TypeVar parameters and Any (used for generic class definitions)
158159
valid_extensions = [
159-
extension for extension in extensions if not isinstance(extension, TypeVar)
160+
extension
161+
for extension in extensions
162+
if not isinstance(extension, TypeVar) and extension is not Any
160163
]
161164

162165
if not valid_extensions:
@@ -209,9 +212,9 @@ def __getitem__(self, item: Any) -> Optional[Extension]:
209212
if not isinstance(item, type) or not issubclass(item, Extension):
210213
raise KeyError(f"{item} is not a valid extension type")
211214

212-
return getattr(self, item.__name__)
215+
return cast(Optional[Extension], getattr(self, item.__name__))
213216

214-
def __setitem__(self, item: Any, value: "Resource") -> None:
217+
def __setitem__(self, item: Any, value: "Extension") -> None:
215218
if not isinstance(item, type) or not issubclass(item, Extension):
216219
raise KeyError(f"{item} is not a valid extension type")
217220

@@ -260,7 +263,9 @@ def get_by_schema(
260263

261264
@staticmethod
262265
def get_by_payload(
263-
resource_types: list[type["Resource"]], payload: dict, **kwargs: Any
266+
resource_types: list[type["Resource"]],
267+
payload: dict[str, Any],
268+
**kwargs: Any,
264269
) -> Optional[type]:
265270
"""Given a resource type list and a payload, find the matching resource type."""
266271
if not payload or not payload.get("schemas"):
@@ -317,7 +322,7 @@ def model_dump(
317322
attributes: Optional[list[str]] = None,
318323
excluded_attributes: Optional[list[str]] = None,
319324
**kwargs: Any,
320-
) -> dict:
325+
) -> dict[str, Any]:
321326
"""Create a model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump`.
322327
323328
:param scim_ctx: If a SCIM context is passed, some default values of
@@ -366,7 +371,7 @@ def model_dump_json(
366371

367372
def dedicated_attributes(
368373
model: type[BaseModel], excluded_models: list[type[BaseModel]]
369-
) -> dict:
374+
) -> dict[str, Any]:
370375
"""Return attributes that are not members the parent 'excluded_models'."""
371376

372377
def compare_field_infos(fi1: Any, fi2: Any) -> bool:

scim2_models/rfc7644/list_response.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def check_results_number(
5151
- 'resources' must be set if 'totalResults' is non-zero.
5252
"""
5353
obj = handler(value)
54+
assert isinstance(obj, cls)
5455

5556
if (
5657
not info.context

scim2_models/rfc7644/message.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
from pydantic import Discriminator
1010
from pydantic import Tag
11+
from pydantic._internal._model_construction import ModelMetaclass
1112

1213
from ..base import BaseModel
13-
from ..base import BaseModelType
1414
from ..scim_object import ScimObject
1515
from ..utils import UNION_TYPES
1616

@@ -91,10 +91,12 @@ def create_tagged_resource_union(resource_union: Any) -> Any:
9191
return Annotated[union, discriminator]
9292

9393

94-
class GenericMessageMetaclass(BaseModelType):
94+
class GenericMessageMetaclass(ModelMetaclass):
9595
"""Metaclass for SCIM generic types with discriminated unions."""
9696

97-
def __new__(cls, name: str, bases: tuple, attrs: dict, **kwargs: Any) -> type:
97+
def __new__(
98+
cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any], **kwargs: Any
99+
) -> type:
98100
"""Create class with tagged resource unions for generic parameters."""
99101
if kwargs.get("__pydantic_generic_metadata__") and kwargs[
100102
"__pydantic_generic_metadata__"

scim2_models/scim_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def model_dump(
115115
*args: Any,
116116
scim_ctx: Optional[Context] = Context.DEFAULT,
117117
**kwargs: Any,
118-
) -> dict:
118+
) -> dict[str, Any]:
119119
"""Create a model representation that can be included in SCIM messages by using Pydantic :code:`BaseModel.model_dump`.
120120
121121
:param scim_ctx: If a SCIM context is passed, some default values of

0 commit comments

Comments
 (0)