Skip to content

Commit 470ba12

Browse files
authored
Merge pull request #95 from J-Bu/main
Fix ResourceType::from_resource for resources with multiple extensions
2 parents b64d53c + 3f5b4f3 commit 470ba12

File tree

4 files changed

+54
-6
lines changed

4 files changed

+54
-6
lines changed

scim2_models/rfc7643/resource.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Generic
55
from typing import Optional
66
from typing import TypeVar
7-
from typing import Union
87
from typing import get_args
98
from typing import get_origin
109

@@ -24,6 +23,7 @@
2423
from ..base import Uniqueness
2524
from ..base import URIReference
2625
from ..base import is_complex_attribute
26+
from ..utils import UNION_TYPES
2727
from ..utils import normalize_attribute_name
2828

2929

@@ -117,7 +117,7 @@ def __new__(cls, name, bases, attrs, **kwargs):
117117
extensions = kwargs["__pydantic_generic_metadata__"]["args"][0]
118118
extensions = (
119119
get_args(extensions)
120-
if get_origin(extensions) == Union
120+
if get_origin(extensions) in UNION_TYPES
121121
else [extensions]
122122
)
123123
for extension in extensions:
@@ -183,7 +183,8 @@ def get_extension_models(cls) -> dict[str, type[Extension]]:
183183
extension_models = cls.__pydantic_generic_metadata__.get("args", [])
184184
extension_models = (
185185
get_args(extension_models[0])
186-
if len(extension_models) == 1 and get_origin(extension_models[0]) == Union
186+
if len(extension_models) == 1
187+
and get_origin(extension_models[0]) in UNION_TYPES
187188
else extension_models
188189
)
189190

@@ -301,7 +302,7 @@ def model_to_schema(model: type[BaseModel]):
301302

302303
def get_reference_types(type) -> list[str]:
303304
first_arg = get_args(type)[0]
304-
types = get_args(first_arg) if get_origin(first_arg) == Union else [first_arg]
305+
types = get_args(first_arg) if get_origin(first_arg) in UNION_TYPES else [first_arg]
305306

306307
def serialize_ref_type(ref_type):
307308
if ref_type == URIReference:

scim2_models/rfc7643/resource_type.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Annotated
22
from typing import Optional
3+
from typing import get_args
4+
from typing import get_origin
35

46
from pydantic import Field
57
from typing_extensions import Self
@@ -11,6 +13,7 @@
1113
from ..base import Required
1214
from ..base import Returned
1315
from ..base import URIReference
16+
from ..utils import UNION_TYPES
1417
from .resource import Resource
1518

1619

@@ -82,7 +85,16 @@ def from_resource(cls, resource_model: type[Resource]) -> Self:
8285
"""Build a naive ResourceType from a resource model."""
8386
schema = resource_model.model_fields["schemas"].default[0]
8487
name = schema.split(":")[-1]
85-
extensions = resource_model.__pydantic_generic_metadata__["args"]
88+
if resource_model.__pydantic_generic_metadata__["args"]:
89+
extensions = resource_model.__pydantic_generic_metadata__["args"][0]
90+
extensions = (
91+
get_args(extensions)
92+
if get_origin(extensions) in UNION_TYPES
93+
else [extensions]
94+
)
95+
else:
96+
extensions = []
97+
8698
return ResourceType(
8799
id=name,
88100
name=name,

scim2_models/rfc7644/list_response.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from ..base import Context
2121
from ..base import Required
2222
from ..rfc7643.resource import AnyResource
23+
from ..utils import UNION_TYPES
2324
from .message import Message
2425

2526

@@ -29,7 +30,7 @@ def tagged_resource_union(resource_union):
2930
3031
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
3132
"""
32-
if not get_origin(resource_union) == Union:
33+
if get_origin(resource_union) not in UNION_TYPES:
3334
return resource_union
3435

3536
resource_types = get_args(resource_union)

tests/test_resource_type.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
from typing import Annotated
2+
from typing import Union
3+
14
from scim2_models import EnterpriseUser
5+
from scim2_models import Extension
26
from scim2_models import Reference
7+
from scim2_models import Required
38
from scim2_models import ResourceType
49
from scim2_models import User
510

@@ -61,3 +66,32 @@ def test_from_resource_with_extensions():
6166
== "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
6267
)
6368
assert not enterprise_user_rt.schema_extensions[0].required
69+
70+
71+
def test_from_resource_with_mulitple_extensions():
72+
class TestExtension(Extension):
73+
schemas: Annotated[list[str], Required.true] = [
74+
"urn:ietf:params:scim:schemas:extension:Test:1.0:User"
75+
]
76+
77+
test: Union[str, None] = None
78+
test2: Union[list[str], None] = None
79+
80+
enterprise_user_rt = ResourceType.from_resource(
81+
User[Union[EnterpriseUser, TestExtension]]
82+
)
83+
assert enterprise_user_rt.id == "User"
84+
assert enterprise_user_rt.name == "User"
85+
assert enterprise_user_rt.description == "User"
86+
assert enterprise_user_rt.endpoint == "/Users"
87+
assert enterprise_user_rt.schema_ == "urn:ietf:params:scim:schemas:core:2.0:User"
88+
assert (
89+
enterprise_user_rt.schema_extensions[0].schema_
90+
== "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
91+
)
92+
assert not enterprise_user_rt.schema_extensions[0].required
93+
assert (
94+
enterprise_user_rt.schema_extensions[1].schema_
95+
== "urn:ietf:params:scim:schemas:extension:Test:1.0:User"
96+
)
97+
assert not enterprise_user_rt.schema_extensions[1].required

0 commit comments

Comments
 (0)