Skip to content

Commit 96a4167

Browse files
author
Andrei Neagu
committed
fixed ytping
1 parent 0ab5c37 commit 96a4167

File tree

5 files changed

+58
-15
lines changed

5 files changed

+58
-15
lines changed

packages/models-library/src/models_library/osparc_variable_identifier.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pydantic import BaseModel, Discriminator, PositiveInt, Tag
88

99
from .utils.string_substitution import OSPARC_IDENTIFIER_PREFIX
10+
from .utils.types import get_types_from_annotated_union
1011

1112
T = TypeVar("T")
1213

@@ -89,7 +90,7 @@ def example_func(par: OsparcVariableIdentifier | int) -> None:
8990
Raises:
9091
TypeError: if the the OsparcVariableIdentifier was unresolved
9192
"""
92-
if isinstance(var, OsparcVariableIdentifier):
93+
if isinstance(var, get_types_from_annotated_union(OsparcVariableIdentifier)):
9394
raise UnresolvedOsparcVariableIdentifierError(value=var)
9495
return var
9596

@@ -116,7 +117,7 @@ def replace_osparc_variable_identifier( # noqa: C901
116117
```
117118
"""
118119

119-
if isinstance(obj, OsparcVariableIdentifier):
120+
if isinstance(obj, get_types_from_annotated_union(OsparcVariableIdentifier)):
120121
if obj.name in osparc_variables:
121122
return deepcopy(osparc_variables[obj.name]) # type: ignore
122123
if obj.default_value is not None:
@@ -154,7 +155,7 @@ def raise_if_unresolved_osparc_variable_identifier_found(obj: Any) -> None:
154155
UnresolvedOsparcVariableIdentifierError: if not all instances of
155156
`OsparcVariableIdentifier` were replaced
156157
"""
157-
if isinstance(obj, OsparcVariableIdentifier):
158+
if isinstance(obj, get_types_from_annotated_union(OsparcVariableIdentifier)):
158159
raise_if_unresolved(obj)
159160
elif isinstance(obj, dict):
160161
for key, value in obj.items():

packages/models-library/src/models_library/service_settings_nat_rule.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
from collections.abc import Generator
22
from typing import Final
33

4-
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator
4+
from pydantic import (
5+
BaseModel,
6+
ConfigDict,
7+
Field,
8+
TypeAdapter,
9+
ValidationInfo,
10+
field_validator,
11+
)
512

613
from .basic_types import PortInt
7-
from .osparc_variable_identifier import OsparcVariableIdentifier, raise_if_unresolved
14+
from .osparc_variable_identifier import (
15+
OsparcVariableIdentifier,
16+
raise_if_unresolved,
17+
)
18+
from .utils.types import get_types_from_annotated_union
819

920
# Cloudflare DNS server address
1021
DEFAULT_DNS_SERVER_ADDRESS: Final[str] = "1.1.1.1" # NOSONAR
@@ -20,13 +31,15 @@ class _PortRange(BaseModel):
2031
@field_validator("upper")
2132
@classmethod
2233
def lower_less_than_upper(cls, v, info: ValidationInfo) -> PortInt:
23-
if isinstance(v, OsparcVariableIdentifier):
34+
if isinstance(v, get_types_from_annotated_union(OsparcVariableIdentifier)):
2435
return v # type: ignore # bypass validation if unresolved
2536

2637
upper = v
2738
lower: PortInt | OsparcVariableIdentifier | None = info.data.get("lower")
2839

29-
if lower and isinstance(lower, OsparcVariableIdentifier):
40+
if lower and isinstance(
41+
lower, get_types_from_annotated_union(OsparcVariableIdentifier)
42+
):
3043
return v # type: ignore # bypass validation if unresolved
3144

3245
if lower is None or lower >= upper:
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Annotated, Any, Union, get_args, get_origin
2+
3+
4+
def get_types_from_annotated_union(annotated_alias: Any) -> tuple[type, ...]:
5+
"""
6+
Introspects a complex Annotated alias to extract the base types from its inner Union.
7+
"""
8+
if get_origin(annotated_alias) is not Annotated:
9+
msg = "Expected an Annotated type."
10+
raise TypeError(msg)
11+
12+
# Get the contents of Annotated, e.g., (Union[...], Discriminator(...))
13+
annotated_args = get_args(annotated_alias)
14+
union_type = annotated_args[0]
15+
16+
# The Union can be from typing.Union or the | operator
17+
if get_origin(union_type) is not Union:
18+
msg = "Expected a Union inside the Annotated type."
19+
raise TypeError(msg)
20+
21+
# Get the members of the Union, e.g., (Annotated[TypeA, ...], Annotated[TypeB, ...])
22+
union_members = get_args(union_type)
23+
24+
extracted_types = []
25+
for member in union_members:
26+
# Each member is also Annotated, so we extract its base type
27+
if get_origin(member) is Annotated:
28+
extracted_types.append(get_args(member)[0])
29+
else:
30+
extracted_types.append(member) # Handle non-annotated members in the union
31+
32+
return tuple(extracted_types)

packages/models-library/tests/test_service_settings_nat_rule.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
replace_osparc_variable_identifier,
1010
)
1111
from models_library.service_settings_nat_rule import NATRule
12+
from models_library.utils.types import get_types_from_annotated_union
1213
from pydantic import TypeAdapter
1314

1415
SUPPORTED_TEMPLATES: set[str] = {
@@ -111,13 +112,13 @@ def test_______(replace_with_value: Any):
111112
a_var = TypeAdapter(OsparcVariableIdentifier).validate_python(
112113
"$OSPARC_VARIABLE_some_var"
113114
)
114-
assert isinstance(a_var, OsparcVariableIdentifier)
115+
assert isinstance(a_var, get_types_from_annotated_union(OsparcVariableIdentifier))
115116

116117
replaced_var = replace_osparc_variable_identifier(
117118
a_var, {"OSPARC_VARIABLE_some_var": replace_with_value}
118119
)
119120
# NOTE: after replacement the original reference still points
120-
assert isinstance(a_var, OsparcVariableIdentifier)
121+
assert isinstance(a_var, get_types_from_annotated_union(OsparcVariableIdentifier))
121122
assert replaced_var == replace_with_value
122123

123124

@@ -154,15 +155,15 @@ def test_replace_an_instance_of_osparc_variable_identifier(
154155
formatted_template = var_template
155156

156157
a_var = TypeAdapter(OsparcVariableIdentifier).validate_python(formatted_template)
157-
assert isinstance(a_var, OsparcVariableIdentifier)
158+
assert isinstance(a_var, get_types_from_annotated_union(OsparcVariableIdentifier))
158159

159160
replace_with_identifier_default = identifier_has_default and replace_with_default
160161
replacement_content = (
161162
{} if replace_with_identifier_default else {a_var.name: replace_with_value}
162163
)
163164
replaced_var = replace_osparc_variable_identifier(a_var, replacement_content)
164165
# NOTE: after replacement the original reference still points
165-
assert isinstance(a_var, OsparcVariableIdentifier)
166+
assert isinstance(a_var, get_types_from_annotated_union(OsparcVariableIdentifier))
166167
if replace_with_identifier_default:
167168
assert replaced_var == default_value
168169
else:

packages/service-integration/tests/data/runtime.yml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@ settings:
1111
type: string
1212
value:
1313
- node.platform.os == linux
14-
# # https://docs.docker.com/compose/compose-file/compose-file-v3/#environment
15-
# - name: environment
16-
# type: string
17-
# -
1814
paths-mapping:
1915
inputs_path: "/config/workspace/inputs"
2016
outputs_path: "/config/workspace/outputs"

0 commit comments

Comments
 (0)