11from collections .abc import Mapping
2- from typing import Any , Optional
2+ from typing import Any , Optional , Union
33
44import jinja2
55
66from . import protocols as sdk_protocols
77from .ctl .constants import PROTOCOLS_TEMPLATE
88from .schema import (
9- AttributeSchema ,
9+ AttributeSchemaAPI ,
1010 GenericSchema ,
11- MainSchemaTypes ,
11+ GenericSchemaAPI ,
12+ MainSchemaTypesAll ,
1213 NodeSchema ,
13- ProfileSchema ,
14- RelationshipSchema ,
14+ NodeSchemaAPI ,
15+ ProfileSchemaAPI ,
16+ RelationshipSchemaAPI ,
1517)
1618
1719ATTRIBUTE_KIND_MAP = {
4042
4143
4244class CodeGenerator :
43- def __init__ (self , schema : dict [str , MainSchemaTypes ]):
44- self .generics : dict [str , GenericSchema ] = {}
45- self .nodes : dict [str , NodeSchema ] = {}
46- self .profiles : dict [str , ProfileSchema ] = {}
45+ def __init__ (self , schema : dict [str , MainSchemaTypesAll ]):
46+ self .generics : dict [str , Union [ GenericSchemaAPI , GenericSchema ] ] = {}
47+ self .nodes : dict [str , Union [ NodeSchemaAPI , NodeSchema ] ] = {}
48+ self .profiles : dict [str , ProfileSchemaAPI ] = {}
4749
4850 for name , schema_type in schema .items ():
49- if isinstance (schema_type , GenericSchema ):
51+ if isinstance (schema_type , ( GenericSchemaAPI , GenericSchema ) ):
5052 self .generics [name ] = schema_type
51- if isinstance (schema_type , NodeSchema ):
53+ if isinstance (schema_type , ( NodeSchemaAPI , NodeSchema ) ):
5254 self .nodes [name ] = schema_type
53- if isinstance (schema_type , ProfileSchema ):
55+ if isinstance (schema_type , ProfileSchemaAPI ):
5456 self .profiles [name ] = schema_type
5557
5658 self .base_protocols = [
@@ -92,7 +94,7 @@ def _jinja2_filter_inheritance(value: dict[str, Any]) -> str:
9294 return ", " .join (inherit_from )
9395
9496 @staticmethod
95- def _jinja2_filter_render_attribute (value : AttributeSchema ) -> str :
97+ def _jinja2_filter_render_attribute (value : AttributeSchemaAPI ) -> str :
9698 attribute_kind : str = ATTRIBUTE_KIND_MAP [value .kind ]
9799
98100 if value .optional :
@@ -101,7 +103,7 @@ def _jinja2_filter_render_attribute(value: AttributeSchema) -> str:
101103 return f"{ value .name } : { attribute_kind } "
102104
103105 @staticmethod
104- def _jinja2_filter_render_relationship (value : RelationshipSchema , sync : bool = False ) -> str :
106+ def _jinja2_filter_render_relationship (value : RelationshipSchemaAPI , sync : bool = False ) -> str :
105107 name = value .name
106108 cardinality = value .cardinality
107109
@@ -116,12 +118,12 @@ def _jinja2_filter_render_relationship(value: RelationshipSchema, sync: bool = F
116118
117119 @staticmethod
118120 def _sort_and_filter_models (
119- models : Mapping [str , MainSchemaTypes ], filters : Optional [list [str ]] = None
120- ) -> list [MainSchemaTypes ]:
121+ models : Mapping [str , MainSchemaTypesAll ], filters : Optional [list [str ]] = None
122+ ) -> list [MainSchemaTypesAll ]:
121123 if filters is None :
122124 filters = ["CoreNode" ]
123125
124- filtered : list [MainSchemaTypes ] = []
126+ filtered : list [MainSchemaTypesAll ] = []
125127 for name , model in models .items ():
126128 if name in filters :
127129 continue
0 commit comments