Skip to content

Commit 1fc369b

Browse files
committed
Cleanup and restructure Pydantic models for Schema
1 parent 08f2762 commit 1fc369b

File tree

16 files changed

+664
-527
lines changed

16 files changed

+664
-527
lines changed

infrahub_sdk/checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pathlib import Path
1818

1919
from . import InfrahubClient
20-
from .schema import InfrahubCheckDefinitionConfig
20+
from .schema.repository import InfrahubCheckDefinitionConfig
2121

2222
INFRAHUB_CHECK_VARIABLE_TO_IMPORT = "INFRAHUB_CHECKS"
2323

infrahub_sdk/code_generator.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from collections.abc import Mapping
2-
from typing import Any, Optional
2+
from typing import Any, Optional, Union
33

44
import jinja2
55

66
from . import protocols as sdk_protocols
77
from .ctl.constants import PROTOCOLS_TEMPLATE
88
from .schema import (
9-
AttributeSchema,
9+
AttributeSchemaAPI,
1010
GenericSchema,
11-
MainSchemaTypes,
11+
GenericSchemaAPI,
12+
MainSchemaTypesAll,
1213
NodeSchema,
13-
ProfileSchema,
14-
RelationshipSchema,
14+
NodeSchemaAPI,
15+
ProfileSchemaAPI,
16+
RelationshipSchemaAPI,
1517
)
1618

1719
ATTRIBUTE_KIND_MAP = {
@@ -40,17 +42,17 @@
4042

4143

4244
class CodeGenerator:
43-
def __init__(self, schema: dict[str, MainSchemaTypes]):
44-
self.generics: dict[str, GenericSchema] = {}
45-
self.nodes: dict[str, NodeSchema] = {}
46-
self.profiles: dict[str, ProfileSchema] = {}
45+
def __init__(self, schema: dict[str, MainSchemaTypesAll]):
46+
self.generics: dict[str, Union[GenericSchemaAPI, GenericSchema]] = {}
47+
self.nodes: dict[str, Union[NodeSchemaAPI, NodeSchema]] = {}
48+
self.profiles: dict[str, ProfileSchemaAPI] = {}
4749

4850
for name, schema_type in schema.items():
49-
if isinstance(schema_type, GenericSchema):
51+
if isinstance(schema_type, (GenericSchemaAPI, GenericSchema)):
5052
self.generics[name] = schema_type
51-
if isinstance(schema_type, NodeSchema):
53+
if isinstance(schema_type, (NodeSchemaAPI, NodeSchema)):
5254
self.nodes[name] = schema_type
53-
if isinstance(schema_type, ProfileSchema):
55+
if isinstance(schema_type, ProfileSchemaAPI):
5456
self.profiles[name] = schema_type
5557

5658
self.base_protocols = [
@@ -92,7 +94,7 @@ def _jinja2_filter_inheritance(value: dict[str, Any]) -> str:
9294
return ", ".join(inherit_from)
9395

9496
@staticmethod
95-
def _jinja2_filter_render_attribute(value: AttributeSchema) -> str:
97+
def _jinja2_filter_render_attribute(value: AttributeSchemaAPI) -> str:
9698
attribute_kind: str = ATTRIBUTE_KIND_MAP[value.kind]
9799

98100
if value.optional:
@@ -101,7 +103,7 @@ def _jinja2_filter_render_attribute(value: AttributeSchema) -> str:
101103
return f"{value.name}: {attribute_kind}"
102104

103105
@staticmethod
104-
def _jinja2_filter_render_relationship(value: RelationshipSchema, sync: bool = False) -> str:
106+
def _jinja2_filter_render_relationship(value: RelationshipSchemaAPI, sync: bool = False) -> str:
105107
name = value.name
106108
cardinality = value.cardinality
107109

@@ -116,12 +118,12 @@ def _jinja2_filter_render_relationship(value: RelationshipSchema, sync: bool = F
116118

117119
@staticmethod
118120
def _sort_and_filter_models(
119-
models: Mapping[str, MainSchemaTypes], filters: Optional[list[str]] = None
120-
) -> list[MainSchemaTypes]:
121+
models: Mapping[str, MainSchemaTypesAll], filters: Optional[list[str]] = None
122+
) -> list[MainSchemaTypesAll]:
121123
if filters is None:
122124
filters = ["CoreNode"]
123125

124-
filtered: list[MainSchemaTypes] = []
126+
filtered: list[MainSchemaTypesAll] = []
125127
for name, model in models.items():
126128
if name in filters:
127129
continue

infrahub_sdk/ctl/check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ..ctl.repository import get_repository_config
1818
from ..ctl.utils import catch_exception, execute_graphql_query
1919
from ..exceptions import ModuleImportError
20-
from ..schema import InfrahubCheckDefinitionConfig, InfrahubRepositoryConfig
20+
from ..schema.repository import InfrahubCheckDefinitionConfig, InfrahubRepositoryConfig
2121

2222
app = typer.Typer()
2323
console = Console()

infrahub_sdk/ctl/cli_commands.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,8 @@
3838
from ..ctl.validate import app as validate_app
3939
from ..exceptions import GraphQLError, ModuleImportError
4040
from ..jinja2 import identify_faulty_jinja_code
41-
from ..schema import (
42-
InfrahubRepositoryConfig,
43-
MainSchemaTypes,
44-
SchemaRoot,
45-
)
41+
from ..schema import MainSchemaTypesAll, SchemaRoot
42+
from ..schema.repository import InfrahubRepositoryConfig
4643
from ..utils import get_branch, write_to_file
4744
from ..yaml import SchemaFile
4845
from .exporter import dump
@@ -364,7 +361,7 @@ def protocols(
364361
) -> None:
365362
"""Export Python protocols corresponding to a schema."""
366363

367-
schema: dict[str, MainSchemaTypes] = {}
364+
schema: dict[str, MainSchemaTypesAll] = {}
368365

369366
if schemas:
370367
schemas_data = load_yamlfile_from_disk_and_exit(paths=schemas, file_type=SchemaFile, console=console)

infrahub_sdk/node.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414
UninitializedError,
1515
)
1616
from .graphql import Mutation, Query
17-
from .schema import GenericSchema, RelationshipCardinality, RelationshipKind
17+
from .schema import GenericSchemaAPI, RelationshipCardinality, RelationshipKind
1818
from .utils import compare_lists, get_flat_value
1919
from .uuidt import UUIDT
2020

2121
if TYPE_CHECKING:
2222
from typing_extensions import Self
2323

2424
from .client import InfrahubClient, InfrahubClientSync
25-
from .schema import AttributeSchema, MainSchemaTypes, RelationshipSchema
25+
from .schema import AttributeSchemaAPI, MainSchemaTypesAPI, RelationshipSchemaAPI
2626

2727
# pylint: disable=too-many-lines
2828

@@ -46,7 +46,7 @@
4646
class Attribute:
4747
"""Represents an attribute of a Node, including its schema, value, and properties."""
4848

49-
def __init__(self, name: str, schema: AttributeSchema, data: Union[Any, dict]):
49+
def __init__(self, name: str, schema: AttributeSchemaAPI, data: Union[Any, dict]):
5050
"""
5151
Args:
5252
name (str): The name of the attribute.
@@ -143,7 +143,7 @@ def _generate_mutation_query(self) -> dict[str, Any]:
143143
class RelatedNodeBase:
144144
"""Base class for representing a related node in a relationship."""
145145

146-
def __init__(self, branch: str, schema: RelationshipSchema, data: Union[Any, dict], name: Optional[str] = None):
146+
def __init__(self, branch: str, schema: RelationshipSchemaAPI, data: Union[Any, dict], name: Optional[str] = None):
147147
"""
148148
Args:
149149
branch (str): The branch where the related node resides.
@@ -300,7 +300,7 @@ def __init__(
300300
self,
301301
client: InfrahubClient,
302302
branch: str,
303-
schema: RelationshipSchema,
303+
schema: RelationshipSchemaAPI,
304304
data: Union[Any, dict],
305305
name: Optional[str] = None,
306306
):
@@ -347,7 +347,7 @@ def __init__(
347347
self,
348348
client: InfrahubClientSync,
349349
branch: str,
350-
schema: RelationshipSchema,
350+
schema: RelationshipSchemaAPI,
351351
data: Union[Any, dict],
352352
name: Optional[str] = None,
353353
):
@@ -390,7 +390,7 @@ def get(self) -> InfrahubNodeSync:
390390
class RelationshipManagerBase:
391391
"""Base class for RelationshipManager and RelationshipManagerSync"""
392392

393-
def __init__(self, name: str, branch: str, schema: RelationshipSchema):
393+
def __init__(self, name: str, branch: str, schema: RelationshipSchemaAPI):
394394
"""
395395
Args:
396396
name (str): The name of the relationship.
@@ -473,7 +473,7 @@ def __init__(
473473
client: InfrahubClient,
474474
node: InfrahubNode,
475475
branch: str,
476-
schema: RelationshipSchema,
476+
schema: RelationshipSchemaAPI,
477477
data: Union[Any, dict],
478478
):
479479
"""
@@ -568,7 +568,7 @@ def __init__(
568568
client: InfrahubClientSync,
569569
node: InfrahubNodeSync,
570570
branch: str,
571-
schema: RelationshipSchema,
571+
schema: RelationshipSchemaAPI,
572572
data: Union[Any, dict],
573573
):
574574
"""
@@ -657,12 +657,12 @@ def remove(self, data: Union[str, RelatedNodeSync, dict]) -> None:
657657
class InfrahubNodeBase:
658658
"""Base class for InfrahubNode and InfrahubNodeSync"""
659659

660-
def __init__(self, schema: MainSchemaTypes, branch: str, data: Optional[dict] = None) -> None:
660+
def __init__(self, schema: MainSchemaTypesAPI, branch: str, data: Optional[dict] = None) -> None:
661661
"""
662662
Args:
663-
schema (MainSchemaTypes): The schema of the node.
664-
branch (str): The branch where the node resides.
665-
data (Optional[dict]): Optional data to initialize the node.
663+
schema: The schema of the node.
664+
branch: The branch where the node resides.
665+
data: Optional data to initialize the node.
666666
"""
667667
self._schema = schema
668668
self._data = data
@@ -1035,16 +1035,16 @@ class InfrahubNode(InfrahubNodeBase):
10351035
def __init__(
10361036
self,
10371037
client: InfrahubClient,
1038-
schema: MainSchemaTypes,
1038+
schema: MainSchemaTypesAPI,
10391039
branch: Optional[str] = None,
10401040
data: Optional[dict] = None,
10411041
) -> None:
10421042
"""
10431043
Args:
1044-
client (InfrahubClient): The client used to interact with the backend.
1045-
schema (MainSchemaTypes): The schema of the node.
1046-
branch (Optional[str]): The branch where the node resides.
1047-
data (Optional[dict]): Optional data to initialize the node.
1044+
client: The client used to interact with the backend.
1045+
schema: The schema of the node.
1046+
branch: The branch where the node resides.
1047+
data: Optional data to initialize the node.
10481048
"""
10491049
self._client = client
10501050
self.__class__ = type(f"{schema.kind}InfrahubNode", (self.__class__,), {})
@@ -1060,7 +1060,7 @@ async def from_graphql(
10601060
client: InfrahubClient,
10611061
branch: str,
10621062
data: dict,
1063-
schema: Optional[MainSchemaTypes] = None,
1063+
schema: Optional[MainSchemaTypesAPI] = None,
10641064
timeout: Optional[int] = None,
10651065
) -> Self:
10661066
if not schema:
@@ -1146,7 +1146,7 @@ async def save(
11461146
if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING:
11471147
update_group_context = True
11481148

1149-
if not isinstance(self._schema, GenericSchema):
1149+
if not isinstance(self._schema, GenericSchemaAPI):
11501150
if "CoreGroup" in self._schema.inherit_from:
11511151
await self._client.group_context.add_related_groups(
11521152
ids=[self.id], update_group_context=update_group_context
@@ -1183,7 +1183,7 @@ async def generate_query_data(
11831183
)
11841184
)
11851185

1186-
if isinstance(self._schema, GenericSchema) and fragment:
1186+
if isinstance(self._schema, GenericSchemaAPI) and fragment:
11871187
for child in self._schema.used_by:
11881188
child_schema = await self._client.schema.get(kind=child)
11891189
child_node = InfrahubNode(client=self._client, schema=child_schema)
@@ -1540,7 +1540,7 @@ class InfrahubNodeSync(InfrahubNodeBase):
15401540
def __init__(
15411541
self,
15421542
client: InfrahubClientSync,
1543-
schema: MainSchemaTypes,
1543+
schema: MainSchemaTypesAPI,
15441544
branch: Optional[str] = None,
15451545
data: Optional[dict] = None,
15461546
) -> None:
@@ -1565,7 +1565,7 @@ def from_graphql(
15651565
client: InfrahubClientSync,
15661566
branch: str,
15671567
data: dict,
1568-
schema: Optional[MainSchemaTypes] = None,
1568+
schema: Optional[MainSchemaTypesAPI] = None,
15691569
timeout: Optional[int] = None,
15701570
) -> Self:
15711571
if not schema:
@@ -1648,7 +1648,7 @@ def save(
16481648
if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING:
16491649
update_group_context = True
16501650

1651-
if not isinstance(self._schema, GenericSchema):
1651+
if not isinstance(self._schema, GenericSchemaAPI):
16521652
if "CoreGroup" in self._schema.inherit_from:
16531653
self._client.group_context.add_related_groups(ids=[self.id], update_group_context=update_group_context)
16541654
else:
@@ -1681,7 +1681,7 @@ def generate_query_data(
16811681
)
16821682
)
16831683

1684-
if isinstance(self._schema, GenericSchema) and fragment:
1684+
if isinstance(self._schema, GenericSchemaAPI) and fragment:
16851685
for child in self._schema.used_by:
16861686
child_schema = self._client.schema.get(kind=child)
16871687
child_node = InfrahubNodeSync(client=self._client, schema=child_schema)

infrahub_sdk/pytest_plugin/items/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
if TYPE_CHECKING:
1414
from pathlib import Path
1515

16-
from ...schema import InfrahubRepositoryConfigElement
16+
from ...schema.repository import InfrahubRepositoryConfigElement
1717
from ..models import InfrahubTest
1818

1919

infrahub_sdk/pytest_plugin/items/check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pytest import ExceptionInfo
1616

1717
from ...checks import InfrahubCheck
18-
from ...schema import InfrahubRepositoryConfigElement
18+
from ...schema.repository import InfrahubRepositoryConfigElement
1919
from ..models import InfrahubTest
2020

2121

infrahub_sdk/pytest_plugin/items/python_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
if TYPE_CHECKING:
1515
from pytest import ExceptionInfo
1616

17-
from ...schema import InfrahubRepositoryConfigElement
17+
from ...schema.repository import InfrahubRepositoryConfigElement
1818
from ...transforms import InfrahubTransform
1919
from ..models import InfrahubTest
2020

infrahub_sdk/query_groups.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
if TYPE_CHECKING:
1010
from .client import InfrahubClient, InfrahubClientSync
1111
from .node import InfrahubNode, InfrahubNodeSync, RelatedNodeBase
12-
from .schema import MainSchemaTypes
12+
from .schema import MainSchemaTypesAPI
1313

1414

1515
class InfrahubGroupContextBase:
@@ -63,7 +63,7 @@ def _generate_group_name(self, suffix: Optional[str] = None) -> str:
6363

6464
return group_name
6565

66-
def _generate_group_description(self, schema: MainSchemaTypes) -> str:
66+
def _generate_group_description(self, schema: MainSchemaTypesAPI) -> str:
6767
"""Generate the description of the group from the params
6868
and ensure it's not longer than the maximum length of the description field."""
6969
if not self.params:

0 commit comments

Comments
 (0)