diff --git a/docs/docs/python-sdk/examples/pydantic_car.py b/docs/docs/python-sdk/examples/pydantic_car.py new file mode 100644 index 00000000..58d0501e --- /dev/null +++ b/docs/docs/python-sdk/examples/pydantic_car.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from asyncio import run as aiorun +from typing import Annotated + +from pydantic import ConfigDict, Field +from rich import print as rprint + +from infrahub_sdk import InfrahubClient +from infrahub_sdk.schema import ( + AttributeKind, + GenericModel, + NodeModel, + NodeSchema, + from_pydantic, +) +from infrahub_sdk.schema.pydantic_utils import ( + Attribute, + GenericModel, + InfrahubConfig, + NodeModel, + Relationship, + SchemaModel, + analyze_field, + field_to_attribute, + field_to_relationship, + from_pydantic, + get_attribute_kind, + get_kind, + model_to_node, +) + + +class Tag(NodeModel): + model_config = InfrahubConfig(namespace="Test", human_readable_fields=["name__value"]) + + name: str = Attribute(unique=True, description="The name of the tag") + label: str | None = Field(description="The label of the tag") + description: str | None = Attribute(None, kind=AttributeKind.TEXTAREA) + + +class TestCar(NodeModel): + name: str = Field(description="The name of the car") + tags: list[Tag] + owner: TestPerson = Relationship(identifier="car__person")] + secondary_owner: TestPerson | None = None + + +class TestPerson(GenericModel): + name: str + + +class TestCarOwner(NodeModel, TestPerson): + cars: list[TestCar] = Relationship(identifier="car__person") + + +async def main() -> None: + client = InfrahubClient() + schema = from_pydantic(models=[TestPerson, TestCar, Tag, TestPerson, TestCarOwner]) + rprint(schema.to_schema_dict()) + response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True) + rprint(response) + + # Create a Tag + tag = await client.create("TestTag", name="Blue", label="Blue") + await tag.save(allow_upsert=True) + + +if __name__ == "__main__": + aiorun(main()) diff --git a/docs/docs/python-sdk/examples/pydantic_infra.py b/docs/docs/python-sdk/examples/pydantic_infra.py new file mode 100644 index 00000000..52599274 --- /dev/null +++ b/docs/docs/python-sdk/examples/pydantic_infra.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from typing import Annotated + +from pydantic import ConfigDict, Field +from rich import print as rprint + +from infrahub_sdk import InfrahubClient +from infrahub_sdk.async_typer import AsyncTyper +from infrahub_sdk.schema import ( + GenericSchema, + NodeSchema, + RelationshipKind, +) +from infrahub_sdk.schema.pydantic_utils import ( + Attribute, + GenericModel, + InfrahubConfig, + NodeModel, + Relationship, + SchemaModel, + analyze_field, + field_to_attribute, + field_to_relationship, + from_pydantic, + get_attribute_kind, + get_kind, + model_to_node, +) + +app = AsyncTyper() + + +class Site(NodeModel): + model_config = InfrahubConfig( + namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"] + ) + + name: str = Attribute(unique=True, description="The name of the site") + + +class Vlan(NodeModel): + model_config = InfrahubConfig( + namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"] + ) + + name: str + vlan_id: int + description: str | None = None + + +class Device(NodeModel): + model_config = InfrahubConfig( + name="Device", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"] + ) + + name: str = Attribute(unique=True, description="The name of the car") + site: Site = Relationship(kind=RelationshipKind.ATTRIBUTE, identifier="device__site") + interfaces: list[Interface] = Relationship(kind=RelationshipKind.COMPONENT, identifier="device__interfaces") + + +class Interface(GenericModel): + model_config = InfrahubConfig( + namespace="Infra", human_friendly_id=["device__name__value", "name__value"], display_labels=["name__value"] + ) + + device: Device = Relationship(kind=RelationshipKind.PARENT, identifier="device__interfaces") + name: str + description: str | None = None + + +class L2Interface(Interface): + model_config = InfrahubConfig(namespace="Infra") + + vlans: list[Vlan] = Field(default_factory=list) + + +class LoopbackInterface(Interface): + model_config = InfrahubConfig(namespace="Infra") + + +@app.command() +async def load_schema() -> None: + client = InfrahubClient() + schema = from_pydantic(models=[Site, Device, Interface, L2Interface, LoopbackInterface, Vlan]) + rprint(schema.to_schema_dict()) + response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True) + rprint(response) + + +@app.command() +async def load_data() -> None: + client = InfrahubClient() + + atl = await client.create("InfraSite", name="ATL") + await atl.save(allow_upsert=True) + cdg = await client.create("InfraSite", name="CDG") + await cdg.save(allow_upsert=True) + + device1 = await client.create("InfraDevice", name="atl1-dev1", site=atl) + await device1.save(allow_upsert=True) + device2 = await client.create("InfraDevice", name="atl1-dev2", site=atl) + await device2.save(allow_upsert=True) + + lo0dev1 = await client.create("InfraLoopbackInterface", name="lo0", device=device1) + await lo0dev1.save(allow_upsert=True) + lo0dev2 = await client.create("InfraLoopbackInterface", name="lo0", device=device2) + await lo0dev2.save(allow_upsert=True) + + for idx in range(1, 3): + interface = await client.create("InfraL2Interface", name=f"Ethernet{idx}", device=device1) + await interface.save(allow_upsert=True) + + +@app.command() +async def query_data() -> None: + client = InfrahubClient() + sites = await client.all(kind=Site) + rprint(sites) + + devices = await client.all(kind=Device) + for device in devices: + rprint(device) + + +if __name__ == "__main__": + app() diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 16c1c73a..2a870021 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -50,7 +50,7 @@ from .protocols_base import CoreNode, CoreNodeSync from .queries import QUERY_USER, get_commit_update_mutation from .query_groups import InfrahubGroupContext, InfrahubGroupContextSync -from .schema import InfrahubSchema, InfrahubSchemaSync, NodeSchemaAPI +from .schema import InfrahubSchema, InfrahubSchemaSync, NodeSchemaAPI, SchemaModel from .store import NodeStore, NodeStoreSync from .task.manager import InfrahubTaskManager, InfrahubTaskManagerSync from .timestamp import Timestamp @@ -63,6 +63,7 @@ from .context import RequestContext +SchemaModelType = TypeVar("SchemaModelType", bound=SchemaModel) SchemaType = TypeVar("SchemaType", bound=CoreNode) SchemaTypeSync = TypeVar("SchemaTypeSync", bound=CoreNodeSync) @@ -417,6 +418,63 @@ async def get( **kwargs: Any, ) -> SchemaType: ... + @overload + async def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[False], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType | None: ... + + @overload + async def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[True], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + + @overload + async def get( + self, + kind: type[SchemaModelType], + raise_when_missing: bool = ..., + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + @overload async def get( self, @@ -476,7 +534,7 @@ async def get( async def get( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, raise_when_missing: bool = True, at: Timestamp | None = None, branch: str | None = None, @@ -490,7 +548,7 @@ async def get( prefetch_relationships: bool = False, property: bool = False, **kwargs: Any, - ) -> InfrahubNode | SchemaType | None: + ) -> InfrahubNode | SchemaType | SchemaModelType | None: branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch) @@ -573,7 +631,7 @@ async def _process_nodes_and_relationships( async def count( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -623,6 +681,25 @@ async def all( order: Order | None = ..., ) -> list[SchemaType]: ... + @overload + async def all( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + ) -> list[SchemaModelType]: ... + @overload async def all( self, @@ -644,7 +721,7 @@ async def all( async def all( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -658,7 +735,7 @@ async def all( property: bool = False, parallel: bool = False, order: Order | None = None, - ) -> list[InfrahubNode] | list[SchemaType]: + ) -> list[InfrahubNode] | list[SchemaType] | list[SchemaModelType]: """Retrieve all nodes of a given kind Args: @@ -717,6 +794,27 @@ async def filters( **kwargs: Any, ) -> list[SchemaType]: ... + @overload + async def filters( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + **kwargs: Any, + ) -> list[SchemaModelType]: ... + @overload async def filters( self, @@ -740,7 +838,7 @@ async def filters( async def filters( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -756,7 +854,7 @@ async def filters( parallel: bool = False, order: Order | None = None, **kwargs: Any, - ) -> list[InfrahubNode] | list[SchemaType]: + ) -> list[InfrahubNode] | list[SchemaType] | list[SchemaModelType]: """Retrieve nodes of a given kind based on provided filters. Args: @@ -780,6 +878,7 @@ async def filters( list[InfrahubNodeSync]: List of Nodes that match the given filters. """ branch = branch or self.default_branch + schema = await self.schema.get(kind=kind, branch=branch) if at: at = Timestamp(at) @@ -867,7 +966,11 @@ async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]: related_nodes = list(set(related_nodes)) for node in related_nodes: if node.id: - self.store.set(node=node) + self.store.set(key=node.id, node=node) + + if isinstance(kind, type) and issubclass(kind, SchemaModel): + return [kind.from_node(node) for node in nodes] # type: ignore[return-value] + return nodes def clone(self, branch: str | None = None) -> InfrahubClient: @@ -1702,7 +1805,7 @@ def execute_graphql( def count( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -1752,6 +1855,25 @@ def all( order: Order | None = ..., ) -> list[SchemaTypeSync]: ... + @overload + def all( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + ) -> list[SchemaModelType]: ... + @overload def all( self, @@ -1773,7 +1895,7 @@ def all( def all( self, - kind: str | type[SchemaTypeSync], + kind: type[SchemaTypeSync | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -1787,7 +1909,7 @@ def all( property: bool = False, parallel: bool = False, order: Order | None = None, - ) -> list[InfrahubNodeSync] | list[SchemaTypeSync]: + ) -> list[InfrahubNodeSync] | list[SchemaTypeSync] | list[SchemaModelType]: """Retrieve all nodes of a given kind Args: @@ -1881,6 +2003,27 @@ def filters( **kwargs: Any, ) -> list[SchemaTypeSync]: ... + @overload + def filters( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + **kwargs: Any, + ) -> list[SchemaModelType]: ... + @overload def filters( self, @@ -1904,7 +2047,7 @@ def filters( def filters( self, - kind: str | type[SchemaTypeSync], + kind: type[SchemaTypeSync | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -1920,7 +2063,7 @@ def filters( parallel: bool = False, order: Order | None = None, **kwargs: Any, - ) -> list[InfrahubNodeSync] | list[SchemaTypeSync]: + ) -> list[InfrahubNodeSync] | list[SchemaTypeSync] | list[SchemaModelType]: """Retrieve nodes of a given kind based on provided filters. Args: @@ -2033,7 +2176,11 @@ def process_non_batch() -> tuple[list[InfrahubNodeSync], list[InfrahubNodeSync]] related_nodes = list(set(related_nodes)) for node in related_nodes: if node.id: - self.store.set(node=node) + self.store.set(key=node.id, node=node) + + if isinstance(kind, type) and issubclass(kind, SchemaModel): + return [kind.from_node(node) for node in nodes] # type: ignore[return-value] + return nodes @overload @@ -2093,6 +2240,63 @@ def get( **kwargs: Any, ) -> SchemaTypeSync: ... + @overload + def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[False], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType | None: ... + + @overload + def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[True], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + + @overload + def get( + self, + kind: type[SchemaModelType], + raise_when_missing: bool = ..., + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + @overload def get( self, @@ -2152,7 +2356,7 @@ def get( def get( self, - kind: str | type[SchemaTypeSync], + kind: type[SchemaTypeSync | SchemaModelType] | str, raise_when_missing: bool = True, at: Timestamp | None = None, branch: str | None = None, @@ -2166,7 +2370,7 @@ def get( prefetch_relationships: bool = False, property: bool = False, **kwargs: Any, - ) -> InfrahubNodeSync | SchemaTypeSync | None: + ) -> InfrahubNodeSync | SchemaTypeSync | SchemaModelType | None: branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index be1cfab9..b0c1e292 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -23,6 +23,7 @@ from ..graphql import Mutation from ..queries import SCHEMA_HASH_SYNC_STATUS from .main import ( + AttributeKind, AttributeSchema, AttributeSchemaAPI, BranchSchema, @@ -40,20 +41,29 @@ SchemaRootAPI, TemplateSchemaAPI, ) +from .pydantic_utils import ( + GenericModel, + NodeModel, + SchemaModel, + from_pydantic, +) if TYPE_CHECKING: - from ..client import InfrahubClient, InfrahubClientSync, SchemaType, SchemaTypeSync + from ..client import InfrahubClient, InfrahubClientSync, SchemaModelType, SchemaType, SchemaTypeSync from ..node import InfrahubNode, InfrahubNodeSync InfrahubNodeTypes = Union[InfrahubNode, InfrahubNodeSync] __all__ = [ + "AttributeKind", "AttributeSchema", "AttributeSchemaAPI", "BranchSupportType", + "GenericModel", "GenericSchema", "GenericSchemaAPI", + "NodeModel", "NodeSchema", "NodeSchemaAPI", "ProfileSchemaAPI", @@ -61,9 +71,11 @@ "RelationshipKind", "RelationshipSchema", "RelationshipSchemaAPI", + "SchemaModel", "SchemaRoot", "SchemaRootAPI", "TemplateSchemaAPI", + "from_pydantic", ] @@ -184,14 +196,17 @@ def _validate_load_schema_response(response: httpx.Response) -> SchemaLoadRespon raise InvalidResponseError(message=f"Invalid response received from server HTTP {response.status_code}") @staticmethod - def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str: + def _get_schema_name(schema: type[SchemaType | SchemaTypeSync | SchemaModelType] | str) -> str: if hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol: # type: ignore[union-attr] return schema.__name__ # type: ignore[union-attr] + if isinstance(schema, type) and issubclass(schema, SchemaModel): + return schema.get_kind() + if isinstance(schema, str): return schema - raise ValueError("schema must be a protocol or a string") + raise ValueError("schema must be a protocol, a SchemaModel, or a string") @staticmethod def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapping[str, Any]: @@ -227,7 +242,7 @@ class InfrahubSchema(InfrahubSchemaBase): async def get( self, - kind: type[SchemaType | SchemaTypeSync] | str, + kind: type[SchemaType | SchemaTypeSync | SchemaModelType] | str, branch: str | None = None, refresh: bool = False, timeout: int | None = None, @@ -522,7 +537,7 @@ def all( def get( self, - kind: type[SchemaType | SchemaTypeSync] | str, + kind: type[SchemaType | SchemaTypeSync | SchemaModelType] | str, branch: str | None = None, refresh: bool = False, timeout: int | None = None, diff --git a/infrahub_sdk/schema/main.py b/infrahub_sdk/schema/main.py index ba18cf49..007d74e2 100644 --- a/infrahub_sdk/schema/main.py +++ b/infrahub_sdk/schema/main.py @@ -344,7 +344,7 @@ class SchemaRoot(BaseModel): node_extensions: list[NodeExtensionSchema] = Field(default_factory=list) def to_schema_dict(self) -> dict[str, Any]: - return self.model_dump(exclude_unset=True, exclude_defaults=True) + return self.model_dump(exclude_defaults=True, mode="json") class SchemaRootAPI(BaseModel): diff --git a/infrahub_sdk/schema/pydantic_utils.py b/infrahub_sdk/schema/pydantic_utils.py new file mode 100644 index 00000000..1833d91e --- /dev/null +++ b/infrahub_sdk/schema/pydantic_utils.py @@ -0,0 +1,584 @@ +from __future__ import annotations + +import re +import typing +from dataclasses import dataclass +from types import UnionType +from typing import TYPE_CHECKING, Any, Callable, ForwardRef, Literal, TypeVar, Union + +from pydantic import BaseModel +from pydantic import ConfigDict as BaseConfig +from pydantic._internal._model_construction import ModelMetaclass # noqa: PLC2701 +from pydantic._internal._repr import Representation # noqa: PLC2701 +from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic.fields import PydanticUndefined as Undefined +from typing_extensions import Self + +from .main import ( + AttributeKind, + AttributeSchema, + BranchSupportType, + GenericSchema, + NodeSchema, + RelationshipKind, + RelationshipSchema, + SchemaRoot, + SchemaState, +) + +if TYPE_CHECKING: + from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync + +_T = TypeVar("_T") + +KIND_MAPPING: dict[type, AttributeKind] = { + int: AttributeKind.NUMBER, + float: AttributeKind.NUMBER, + str: AttributeKind.TEXT, + bool: AttributeKind.BOOLEAN, +} + +NAMESPACE_REGEX = r"^[A-Z][a-z0-9]+$" +NODE_KIND_REGEX = r"^[A-Z][a-zA-Z0-9]+$" + + +def __dataclass_transform__( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + field_descriptors: tuple[Union[type, Callable[..., Any]], ...] = (()), +) -> Callable[[_T], _T]: + return lambda a: a + + +class InfrahubConfig(BaseConfig, total=False): + generic: bool = False + name: str | None = None + namespace: str | None = None + display_labels: list[str] | None = None + description: str | None = None + state: SchemaState = SchemaState.PRESENT + label: str | None = None + include_in_menu: bool | None = None + menu_placement: str | None = None + + +class AttributeInfo(PydanticFieldInfo): + def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + unique = kwargs.pop("unique", False) + label = kwargs.pop("label", None) + kind = kwargs.pop("kind", None) + regex = kwargs.pop("regex", None) + branch = kwargs.pop("branch", None) + super().__init__(default=default, **kwargs) + self.unique = unique + self.label = label + self.kind = kind + self.regex = regex + self.branch = branch + + +class RelationshipInfo(Representation): + def __init__( + self, + *, + alias: str | None = None, + kind: RelationshipKind | None = None, + peer: str | None = None, + description: str | None = None, + identifier: str | None = None, + branch: BranchSupportType | None = None, + optional: bool = False, + ) -> None: + self.alias = alias + self.kind = kind + self.identifier = identifier + self.branch = branch + self.description = description + self.peer = peer + self.optional = optional + + +def Relationship( + *, + alias: str | None = None, + kind: RelationshipKind | None = None, + identifier: str | None = None, + branch: BranchSupportType | None = None, + peer: str | None = None, + description: str | None = None, + optional: bool = False, +) -> Any: + relationship_info = RelationshipInfo( + alias=alias, + kind=kind, + identifier=identifier, + branch=branch, + peer=peer, + description=description, + optional=optional, + ) + return relationship_info + + +def Attribute( + default: Any = Undefined, + *, + alias: str | None = None, + description: str | None = None, + state: SchemaState = SchemaState.PRESENT, + kind: AttributeKind | None = None, + label: str | None = None, + unique: bool = False, + branch: BranchSupportType | None = None, + regex: str | None = None, + pattern: str | None = None, +) -> Any: + current_schema_extra = {} + field_info = AttributeInfo( + default, + alias=alias, + description=description, + state=state, + kind=kind, + label=label, + unique=unique, + branch=branch, + regex=regex, + pattern=pattern, + **current_schema_extra, + ) + return field_info + + +@__dataclass_transform__(kw_only_default=True, field_descriptors=(Attribute, AttributeInfo)) +class InfrahubMetaclass(ModelMetaclass): + __infrahub_relationships__: dict[str, RelationshipInfo] + model_config: InfrahubConfig + model_fields: dict[str, AttributeInfo] + + def __new__( + cls, + name: str, + bases: tuple[type[Any], ...], + class_dict: dict[str, Any], + **kwargs: Any, + ) -> Any: + relationships: dict[str, RelationshipInfo] = {} + dict_for_pydantic = {} + original_annotations: dict[str, Any] = class_dict.get("__annotations__", {}) + pydantic_annotations = {} + relationship_annotations = {} + for k, v in class_dict.items(): + if isinstance(v, RelationshipInfo): + relationships[k] = v + else: + dict_for_pydantic[k] = v + for k, v in original_annotations.items(): + if k in relationships: + relationship_annotations[k] = v + else: + pydantic_annotations[k] = v + dict_used = { + **dict_for_pydantic, + "__infrahub_relationships__": relationships, + "__annotations__": pydantic_annotations, + } + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs: set[str] = { + key + for key in dir(BaseConfig) + if not (key.startswith("__") and key.endswith("__")) # skip dunder methods and attributes + } + config_kwargs = {key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs} + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + new_cls.__annotations__ = { + **relationship_annotations, + **pydantic_annotations, + **new_cls.__annotations__, + } + + # def get_config(name: str) -> Any: + # config_class_value = new_cls.model_config.get(name, Undefined) + # if config_class_value is not Undefined: + # return config_class_value + # kwarg_value = kwargs.get(name, Undefined) + # if kwarg_value is not Undefined: + # return kwarg_value + # return Undefined + + # new_cls.model_config["generic"] = get_config("generic") + + return new_cls + + +class SchemaModel(BaseModel, metaclass=InfrahubMetaclass): + id: str | None = Attribute(default=None, description="The ID of the node") + + @classmethod + def get_kind(cls) -> str: + return get_kind(cls) + + @classmethod + def from_node(cls, node: InfrahubNode | InfrahubNodeSync) -> Self: + data = {} + for field_name, field in cls.model_fields.items(): + field_info = analyze_field(field_name, field) + if field_name == "id": + data[field_name] = node.id + elif field_info.is_attribute: + attr = getattr(node, field_name) + data[field_name] = attr.value + + # elif field_info.is_relationship: + # rel = getattr(node, field_name) + # data[field_name] = rel.value + + return cls(**data) + + +class NodeModel(SchemaModel): + pass + + +class GenericModel(SchemaModel): + pass + + +@dataclass +class InfrahubFieldInfo: + name: str + types: list[type] + optional: bool + default: Any + field_kind: Literal["attribute", "relationship"] | None = None + + @property + def primary_type(self) -> type: + if not self.types: + raise ValueError("No types found") + + # if isinstance(self.primary_type, ForwardRef): + # raise TypeError("Forward References are not supported yet, please ensure the models are defined in the right order") + + if self.is_list: + return typing.get_args(self.types[0])[0] + + return self.types[0] + + @property + def is_attribute(self) -> bool: + if self.field_kind == "attribute": + return True + return self.primary_type in KIND_MAPPING + + @property + def is_relationship(self) -> bool: + if self.field_kind == "relationship": + return True + if isinstance(self.primary_type, ForwardRef): + return True + return issubclass(self.primary_type, BaseModel) + + @property + def is_list(self) -> bool: + return typing.get_origin(self.types[0]) is list + + def to_dict(self) -> dict: + return { + "name": self.name, + "primary_type": self.primary_type, + "optional": self.optional, + "default": self.default, + "is_attribute": self.is_attribute, + "is_relationship": self.is_relationship, + "is_list": self.is_list, + } + + +def analyze_field(field_name: str, field: AttributeInfo | RelationshipInfo | PydanticFieldInfo) -> InfrahubFieldInfo: + if isinstance(field, RelationshipInfo): + return InfrahubFieldInfo( + name=field.alias or field_name, + types=[field.peer] if field.peer else [], + optional=field.optional, + field_kind="relationship", + default=None, + ) + + clean_types = [] + if isinstance(field.annotation, UnionType) or ( + hasattr(field.annotation, "_name") and field.annotation._name == "Optional" # type: ignore[union-attr] + ): + clean_types = [t for t in field.annotation.__args__ if t is not type(None)] # type: ignore[union-attr] + else: + clean_types.append(field.annotation) + + return InfrahubFieldInfo( + name=field.alias or field_name, + types=clean_types, + optional=not field.is_required(), + default=field.default if field.default is not Undefined else None, + ) + + +def get_attribute_kind(field: AttributeInfo | PydanticFieldInfo) -> AttributeKind: + if isinstance(field, AttributeInfo) and field.kind: + return field.kind + + if field.annotation in KIND_MAPPING: + return KIND_MAPPING[field.annotation] + + if isinstance(field.annotation, UnionType) or ( + hasattr(field.annotation, "_name") and field.annotation._name == "Optional" # type: ignore[union-attr] + ): + valid_types = [t for t in field.annotation.__args__ if t is not type(None)] # type: ignore[union-attr] + if len(valid_types) == 1 and valid_types[0] in KIND_MAPPING: + return KIND_MAPPING[valid_types[0]] + + raise ValueError(f"Unknown field type: {field.annotation}") + + +def field_to_attribute( + field_name: str, field_info: InfrahubFieldInfo, field: AttributeInfo | PydanticFieldInfo +) -> AttributeSchema: + pattern = field._attributes_set.get("pattern", None) + max_length = field._attributes_set.get("max_length", None) + min_length = field._attributes_set.get("min_length", None) + + if isinstance(field, AttributeInfo): + return AttributeSchema( + name=field_name, + label=field.label, + description=field.description, + kind=get_attribute_kind(field), + optional=field_info.optional, # not field.is_required(), + unique=field.unique, + branch=field.branch, + default_value=field_info.default, + regex=str(pattern) if pattern else None, + max_length=int(str(max_length)) if max_length else None, + min_length=int(str(min_length)) if min_length else None, + ) + + return AttributeSchema( + name=field_name, + # label=field.label, + description=field.description, + kind=get_attribute_kind(field), + optional=not field.is_required(), + # unique=field.unique, + # branch=field.branch, + default_value=field_info.default, + regex=str(pattern) if pattern else None, + max_length=int(str(max_length)) if max_length else None, + min_length=int(str(min_length)) if min_length else None, + ) + + +def field_to_relationship( + field_name: str, + field_info: InfrahubFieldInfo, + field: RelationshipInfo | PydanticFieldInfo, +) -> RelationshipSchema: + if isinstance(field, RelationshipInfo): + return RelationshipSchema( + name=field_name, + description=field.description, + peer=field.peer or get_kind(field_info.primary_type), + identifier=field.identifier, + cardinality="many" if field_info.is_list else "one", + optional=field_info.optional, + branch=field.branch, + ) + + return RelationshipSchema( + name=field_name, + description=field.description, + peer=get_kind(field_info.primary_type), + cardinality="many" if field_info.is_list else "one", + optional=field_info.optional, + ) + + +def extract_validate_generic(model: type[BaseModel]) -> list[str]: + return [get_kind(ancestor) for ancestor in model.__bases__ if issubclass(ancestor, GenericModel)] + + +def validate_kind(kind: str) -> tuple[str, str]: + """Validate the kind of a model. + + TODO Move the function to the main module + """ + + # First, handle transition from a lowercase to uppercase + name_with_spaces = re.sub(r"([a-z])([A-Z])", r"\1 \2", kind) + + # Then, handle consecutive uppercase letters followed by a lowercase + # (e.g., "HTTPRequest" -> "HTTP Request") + name_with_spaces = re.sub(r"([A-Z])([A-Z][a-z])", r"\1 \2", name_with_spaces) + + name_parts = name_with_spaces.split(" ") + + if len(name_parts) == 1: + raise ValueError(f"Invalid kind: {kind}, must contain a Namespace and a Name") + kind_namespace = name_parts[0] + kind_name = "".join(name_parts[1:]) + + if not kind_namespace[0].isupper(): + raise ValueError(f"Invalid namespace: {kind_namespace}, must start with an uppercase letter") + + return kind_namespace, kind_name + + +def is_generic(model: type[BaseModel]) -> bool: + return GenericModel in model.__bases__ + + +def get_kind(model: type[BaseModel] | ForwardRef) -> str: + """Get the kind of a model. + + If the model name and namespace are set in model_config, return the full kind. + If the model namespace is set in model_config, use the name of the class as name. + If the model has no name or namespace, extract both from the name of the class. + """ + + model_class: type[BaseModel] + + if isinstance(model, type) and issubclass(model, BaseModel): + model_class = model + elif isinstance(model, ForwardRef): + return model.__forward_arg__ + else: + raise ValueError(f"Expected BaseModel class, got {model}") + + name = model_class.model_config.get("name", None) + namespace = model_class.model_config.get("namespace", None) + class_name = model_class.__name__ + + if name and namespace: + return f"{namespace}{name}" + if namespace and not name and not class_name.startswith(namespace): + return f"{namespace}{class_name}" + + namespace, name = validate_kind(model.__name__) + return f"{namespace}{name}" + + +def get_generics(model: type[BaseModel]) -> list[type[GenericModel]]: + return [ancestor for ancestor in model.__bases__ if issubclass(ancestor, GenericModel)] + + +def _add_fields( + node: NodeSchema | GenericSchema, model: type[BaseModel], inherited_fields: dict[str, dict[str, Any]] | None = None +) -> None: + for field_name, field in model.model_fields.items(): + if ( + inherited_fields + and field_name in inherited_fields + and field._attributes_set == inherited_fields[field_name] + ): + continue + + if field_name == "id": + continue + + field_info = analyze_field(field_name, field) + + if field_info.is_attribute: + node.attributes.append(field_to_attribute(field_name, field_info, field)) + elif field_info.is_relationship: + node.relationships.append(field_to_relationship(field_name, field_info, field)) + + +def model_to_node(model: type[BaseModel]) -> NodeSchema | GenericSchema: + # ------------------------------------------------------------ + # GenericSchema + # ------------------------------------------------------------ + + kind = get_kind(model) + namespace, name = validate_kind(kind) + + if GenericModel in model.__bases__: + generic = GenericSchema( + name=name, + namespace=namespace, + display_labels=model.model_config.get("display_labels", None), + description=model.model_config.get("description", None), + state=model.model_config.get("state", SchemaState.PRESENT), + label=model.model_config.get("label", None), + # include_in_menu=generic_schema.include_in_menu if generic_schema else None, + # menu_placement=generic_schema.menu_placement if generic_schema else None, + # documentation=generic_schema.documentation if generic_schema else None, + # order_by=generic_schema.order_by if generic_schema else None, + # parent=schema.parent if schema else None, + # children=schema.children if schema else None, + # icon=generic_schema.icon if generic_schema else None, + # generate_profile=schema.generate_profile if schema else None, + # branch=schema.branch if schema else None, + # default_filter=schema.default_filter if schema else None, + ) + _add_fields(node=generic, model=model) + return generic + + # ------------------------------------------------------------ + # NodeSchema + # ------------------------------------------------------------ + generics = get_generics(model) + + # list all inherited fields with a hash for each to track if they are identical on the node + inherited_fields = { + field_name: field._attributes_set for generic in generics for field_name, field in generic.model_fields.items() + } + + node = NodeSchema( + name=name, + namespace=namespace, + display_labels=model.model_config.get("display_labels", None), + description=model.model_config.get("description", None), + state=model.model_config.get("state", SchemaState.PRESENT), + label=model.model_config.get("label", None), + # include_in_menu=node_schema.include_in_menu if node_schema else None, + # menu_placement=node_schema.menu_placement if node_schema else None, + # documentation=node_schema.documentation if node_schema else None, + # order_by=node_schema.order_by if node_schema else None, + inherit_from=[get_kind(generic) for generic in generics], + # parent=node_schema.parent if node_schema else None, + # children=node_schema.children if node_schema else None, + # icon=node_schema.icon if node_schema else None, + # generate_profile=node_schema.generate_profile if node_schema else None, + # branch=node_schema.branch if node_schema else None, + # default_filter=schema.default_filter if schema else None, + ) + + _add_fields(node=node, model=model, inherited_fields=inherited_fields) + return node + + +def from_pydantic(models: list[type[BaseModel]]) -> SchemaRoot: + schema = SchemaRoot(version="1.0") + + for model in models: + node = model_to_node(model=model) + + if isinstance(node, NodeSchema): + schema.nodes.append(node) + elif isinstance(node, GenericSchema): + schema.generics.append(node) + + return schema + + +# class NodeSchema(BaseModel): +# name: str| None = None +# namespace: str| None = None +# display_labels: list[str] | None = None + +# class NodeMetaclass(ModelMetaclass): +# model_config: NodeConfig +# # model_schema: NodeSchema +# __config__: type[NodeConfig] +# # __schema__: NodeSchema diff --git a/tests/unit/sdk/test_pydantic.py b/tests/unit/sdk/test_pydantic.py new file mode 100644 index 00000000..fc074b8b --- /dev/null +++ b/tests/unit/sdk/test_pydantic.py @@ -0,0 +1,606 @@ +from __future__ import annotations + +from typing import ForwardRef, Optional + +import pytest +from pydantic import BaseModel + +from infrahub_sdk.schema.main import ( + AttributeKind, + AttributeSchema, + GenericSchema, + NodeSchema, + RelationshipSchema, + SchemaState, +) +from infrahub_sdk.schema.pydantic_utils import ( + Attribute, + GenericModel, + InfrahubConfig, + NodeModel, + Relationship, + SchemaModel, + analyze_field, + field_to_attribute, + field_to_relationship, + from_pydantic, + get_attribute_kind, + get_kind, + model_to_node, +) + + +class MyAllInOneModel(NodeModel): + name: str + age: int + is_active: bool + opt_age: int | None = None + default_name: str = "some_default" + old_opt_age: Optional[int] = None + + +class AcmeTag(NodeModel): + name: str = Attribute(default="test_tag", description="The name of the tag") + description: str | None = Attribute(None, kind=AttributeKind.TEXTAREA) + label: str = Attribute(unique=True, description="The label of the tag") + + +class AcmeCar(NodeModel): + name: str + tags: list[AcmeTag] + owner: AcmePerson + secondary_owner: AcmePerson | None = Relationship(peer="AcmePerson", optional=True) + + +class AcmePerson(NodeModel): + name: str + cars: list[AcmeCar] | None = None + + +# -------------------------------- + + +class Book(NodeModel): + model_config = InfrahubConfig(name="Book", namespace="Library", display_labels=["name__value"]) + + title: str + isbn: str = Attribute(..., unique=True) + created_at: str + author: LibraryAuthor + + +class AbstractPerson(GenericModel): + model_config = InfrahubConfig(namespace="Library") + firstname: str = Attribute(..., description="The first name of the person", pattern=r"^[a-zA-Z]+$") + lastname: str + + +class LibraryAuthor(AbstractPerson): + books: list[Book] + + +class LibraryReader(AbstractPerson): + favorite_books: list[Book] + favorite_authors: list[LibraryAuthor] + + +@pytest.mark.parametrize( + "field_name, expected_kind", + [ + pytest.param("name", "Text", id="name_field"), + pytest.param("age", "Number", id="age_field"), + pytest.param("is_active", "Boolean", id="is_active_field"), + pytest.param("opt_age", "Number", id="opt_age_field"), + pytest.param("default_name", "Text", id="default_name_field"), + pytest.param("old_opt_age", "Number", id="old_opt_age_field"), + ], +) +def test_get_field_kind(field_name, expected_kind): + assert get_attribute_kind(MyAllInOneModel.model_fields[field_name]) == expected_kind + + +@pytest.mark.parametrize( + "field_name, model, expected", + [ + pytest.param( + "name", + MyAllInOneModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "name", + "optional": False, + "primary_type": str, + }, + id="MyAllInOneModel_name", + ), + pytest.param( + "age", + MyAllInOneModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "age", + "optional": False, + "primary_type": int, + }, + id="MyAllInOneModel_age", + ), + pytest.param( + "is_active", + MyAllInOneModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "is_active", + "optional": False, + "primary_type": bool, + }, + id="MyAllInOneModel_is_active", + ), + pytest.param( + "opt_age", + MyAllInOneModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "opt_age", + "optional": True, + "primary_type": int, + }, + id="MyAllInOneModel_opt_age", + ), + pytest.param( + "default_name", + MyAllInOneModel, + { + "default": "some_default", + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "default_name", + "optional": True, + "primary_type": str, + }, + id="MyAllInOneModel_default_name", + ), + pytest.param( + "old_opt_age", + MyAllInOneModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "old_opt_age", + "optional": True, + "primary_type": int, + }, + id="MyAllInOneModel_old_opt_age", + ), + pytest.param( + "description", + AcmeTag, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "description", + "optional": True, + "primary_type": str, + }, + id="AcmeTag_description", + ), + pytest.param( + "name", + AcmeTag, + { + "default": "test_tag", + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "name", + "optional": True, + "primary_type": str, + }, + id="AcmeTag_name", + ), + pytest.param( + "label", + AcmeTag, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "label", + "optional": False, + "primary_type": str, + }, + id="AcmeTag_label", + ), + pytest.param( + "owner", + AcmeCar, + { + "default": None, + "is_attribute": False, + "is_list": False, + "is_relationship": True, + "name": "owner", + "optional": False, + "primary_type": ForwardRef("AcmePerson"), + }, + id="AcmeCar_owner", + ), + pytest.param( + "tags", + AcmeCar, + { + "default": None, + "is_attribute": False, + "is_list": True, + "is_relationship": True, + "name": "tags", + "optional": False, + "primary_type": AcmeTag, + }, + id="AcmeCar_tags", + ), + pytest.param( + "secondary_owner", + AcmeCar, + { + "default": None, + "is_attribute": False, + "is_list": False, + "is_relationship": True, + "name": "secondary_owner", + "optional": True, + "primary_type": "AcmePerson", + }, + id="AcmeCar_secondary_owner", + ), + ], +) +def test_analyze_field(field_name: str, model: type[BaseModel], expected: dict): + if field_name in model.model_fields: + field = model.model_fields[field_name] + elif issubclass(model, SchemaModel) and field_name in model.__infrahub_relationships__: + field = model.__infrahub_relationships__[field_name] + else: + raise ValueError(f"Field {field_name} not found in model {model}") + assert analyze_field(field_name=field_name, field=field).to_dict() == expected + + +@pytest.mark.parametrize( + "field_name, model, expected", + [ + pytest.param( + "name", + MyAllInOneModel, + AttributeSchema( + name="name", + kind=AttributeKind.TEXT, + optional=False, + ), + id="MyAllInOneModel_name", + ), + pytest.param( + "age", + MyAllInOneModel, + AttributeSchema( + name="age", + kind=AttributeKind.NUMBER, + optional=False, + ), + id="MyAllInOneModel_age", + ), + pytest.param( + "is_active", + MyAllInOneModel, + AttributeSchema( + name="is_active", + kind=AttributeKind.BOOLEAN, + optional=False, + ), + id="MyAllInOneModel_is_active", + ), + pytest.param( + "opt_age", + MyAllInOneModel, + AttributeSchema( + name="opt_age", + kind=AttributeKind.NUMBER, + optional=True, + ), + id="MyAllInOneModel_opt_age", + ), + pytest.param( + "default_name", + MyAllInOneModel, + AttributeSchema( + name="default_name", + kind=AttributeKind.TEXT, + optional=True, + default_value="some_default", + ), + id="MyAllInOneModel_default_name", + ), + pytest.param( + "old_opt_age", + MyAllInOneModel, + AttributeSchema( + name="old_opt_age", + kind=AttributeKind.NUMBER, + optional=True, + ), + id="MyAllInOneModel_old_opt_age", + ), + pytest.param( + "description", + AcmeTag, + AttributeSchema( + name="description", + kind=AttributeKind.TEXTAREA, + optional=True, + ), + id="AcmeTag_description", + ), + pytest.param( + "name", + AcmeTag, + AttributeSchema( + name="name", + description="The name of the tag", + kind=AttributeKind.TEXT, + optional=True, + default_value="test_tag", + ), + id="AcmeTag_name", + ), + pytest.param( + "label", + AcmeTag, + AttributeSchema( + name="label", + description="The label of the tag", + kind=AttributeKind.TEXT, + optional=False, + unique=True, + ), + id="AcmeTag_label", + ), + pytest.param( + "firstname", + AbstractPerson, + AttributeSchema( + name="firstname", + description="The first name of the person", + kind=AttributeKind.TEXT, + optional=False, + regex=r"^[a-zA-Z]+$", + ), + id="AbstractPerson_firstname", + ), + ], +) +def test_field_to_attribute(field_name: str, model: type[BaseModel], expected: AttributeSchema): + field = model.model_fields[field_name] + field_info = analyze_field(field_name, field) + assert field_to_attribute(field_name, field_info, field) == expected + + +@pytest.mark.parametrize( + "field_name, model, expected", + [ + pytest.param( + "owner", + AcmeCar, + RelationshipSchema( + name="owner", + peer="AcmePerson", + cardinality="one", + optional=False, + ), + id="AcmeCar_owner", + ), + pytest.param( + "tags", + AcmeCar, + RelationshipSchema( + name="tags", + peer="AcmeTag", + cardinality="many", + optional=False, + ), + id="AcmeCar_tags", + ), + pytest.param( + "secondary_owner", + AcmeCar, + RelationshipSchema( + name="secondary_owner", + peer="AcmePerson", + cardinality="one", + optional=True, + ), + id="AcmeCar_secondary_owner", + ), + ], +) +def test_field_to_relationship(field_name: str, model: type[BaseModel | SchemaModel], expected: RelationshipSchema): + if field_name in model.model_fields: + field = model.model_fields[field_name] + elif issubclass(model, SchemaModel) and field_name in model.__infrahub_relationships__: + field = model.__infrahub_relationships__[field_name] + else: + raise ValueError(f"Field {field_name} not found in model {model}") + field_info = analyze_field(field_name, field) + assert field_to_relationship(field_name, field_info, field) == expected + + +@pytest.mark.parametrize( + "model, expected", + [ + pytest.param(MyAllInOneModel, "MyAllInOneModel", id="MyAllInOneModel"), + pytest.param(Book, "LibraryBook", id="Book"), + pytest.param(LibraryAuthor, "LibraryAuthor", id="LibraryAuthor"), + pytest.param(LibraryReader, "LibraryReader", id="LibraryReader"), + pytest.param(AbstractPerson, "LibraryAbstractPerson", id="AbstractPerson"), + pytest.param(AcmeTag, "AcmeTag", id="AcmeTag"), + pytest.param(AcmeCar, "AcmeCar", id="AcmeCar"), + pytest.param(AcmePerson, "AcmePerson", id="AcmePerson"), + ], +) +def test_get_kind(model: type[BaseModel], expected: str): + assert get_kind(model) == expected + + +@pytest.mark.parametrize( + "model, expected", + [ + pytest.param( + MyAllInOneModel, + NodeSchema( + name="AllInOneModel", + namespace="My", + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema(name="name", kind=AttributeKind.TEXT, optional=False), + AttributeSchema(name="age", kind=AttributeKind.NUMBER, optional=False), + AttributeSchema(name="is_active", kind=AttributeKind.BOOLEAN, optional=False), + AttributeSchema(name="opt_age", kind=AttributeKind.NUMBER, optional=True), + AttributeSchema( + name="default_name", kind=AttributeKind.TEXT, optional=True, default_value="some_default" + ), + AttributeSchema(name="old_opt_age", kind=AttributeKind.NUMBER, optional=True), + ], + ), + id="MyAllInOneModel", + ), + pytest.param( + Book, + NodeSchema( + name="Book", + namespace="Library", + display_labels=["name__value"], + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema(name="title", kind=AttributeKind.TEXT, optional=False), + AttributeSchema(name="isbn", kind=AttributeKind.TEXT, optional=False, unique=True), + AttributeSchema(name="created_at", kind=AttributeKind.TEXT, optional=False), + ], + relationships=[ + RelationshipSchema( + name="author", + peer="LibraryAuthor", + cardinality="one", + optional=False, + relationships=[ + RelationshipSchema(name="books", peer="LibraryBook", cardinality="many", optional=False), + ], + ), + ], + ), + id="Book", + ), + pytest.param( + LibraryAuthor, + NodeSchema( + name="Author", + namespace="Library", + inherit_from=["LibraryAbstractPerson"], + state=SchemaState.PRESENT, + relationships=[ + RelationshipSchema(name="books", peer="LibraryBook", cardinality="many", optional=False), + ], + ), + id="LibraryAuthor", + ), + pytest.param( + LibraryReader, + NodeSchema( + name="Reader", + namespace="Library", + inherit_from=["LibraryAbstractPerson"], + state=SchemaState.PRESENT, + relationships=[ + RelationshipSchema(name="favorite_books", peer="LibraryBook", cardinality="many", optional=False), + RelationshipSchema( + name="favorite_authors", peer="LibraryAuthor", cardinality="many", optional=False + ), + ], + ), + id="LibraryReader", + ), + pytest.param( + AbstractPerson, + GenericSchema( + name="AbstractPerson", + namespace="Library", + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema( + name="firstname", + kind=AttributeKind.TEXT, + optional=False, + description="The first name of the person", + regex=r"^[a-zA-Z]+$", + ), + AttributeSchema(name="lastname", kind=AttributeKind.TEXT, optional=False), + ], + ), + id="AbstractPerson", + ), + pytest.param( + AcmeTag, + NodeSchema( + name="Tag", + namespace="Acme", + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema( + name="name", + kind=AttributeKind.TEXT, + default_value="test_tag", + optional=True, + description="The name of the tag", + ), + AttributeSchema(name="description", kind=AttributeKind.TEXTAREA, optional=True), + AttributeSchema( + name="label", + kind=AttributeKind.TEXT, + optional=False, + unique=True, + description="The label of the tag", + ), + ], + ), + id="AcmeTag", + ), + ], +) +def test_model_to_node(model: type[BaseModel], expected: NodeSchema): + node = model_to_node(model) + assert node == expected + + +def test_related_models(): + schemas = from_pydantic(models=[AcmePerson, AcmeCar, AcmeTag]) + assert len(schemas.nodes) == 3 + + +# def test_library_models(): +# schemas = from_pydantic(models=[Book, AbstractPerson, LibraryAuthor, LibraryReader]) +# assert len(schemas.nodes) == 3 +# assert len(schemas.generics) == 1