66from  pydantic  import  BaseModel , Field 
77
88from  ..exceptions  import  ObjectValidationError , ValidationError 
9- from  ..schema  import  RelationshipSchema 
9+ from  ..schema  import  GenericSchemaAPI ,  RelationshipSchema 
1010from  ..yaml  import  InfrahubFile , InfrahubFileKind 
1111
1212if  TYPE_CHECKING :
@@ -168,14 +168,21 @@ async def validate_format(self, client: InfrahubClient, branch: str | None = Non
168168        schema  =  await  client .schema .get (kind = self .kind , branch = branch )
169169        for  idx , item  in  enumerate (self .data ):
170170            errors .extend (
171-                 await  self .validate_object (client = client , position = [idx  +  1 ], schema = schema , data = item , branch = branch )
171+                 await  self .validate_object (
172+                     client = client ,
173+                     position = [idx  +  1 ],
174+                     schema = schema ,
175+                     data = item ,
176+                     branch = branch ,
177+                     default_schema_kind = self .kind ,
178+                 )
172179            )
173180        return  errors 
174181
175182    async  def  process (self , client : InfrahubClient , branch : str  |  None  =  None ) ->  None :
176183        schema  =  await  client .schema .get (kind = self .kind , branch = branch )
177184        for  idx , item  in  enumerate (self .data ):
178-             await  self .create_node (client = client , schema = schema , data = item , position = [idx  +  1 ], branch = branch )
185+             await  self .create_node (client = client , schema = schema , data = item , position = [idx  +  1 ], branch = branch ,  default_schema_kind = self . kind )
179186
180187    @classmethod  
181188    async  def  validate_object (
@@ -186,6 +193,7 @@ async def validate_object(
186193        position : list [int  |  str ],
187194        context : dict  |  None  =  None ,
188195        branch : str  |  None  =  None ,
196+         default_schema_kind : str  |  None  =  None ,
189197    ) ->  list [ObjectValidationError ]:
190198        errors : list [ObjectValidationError ] =  []
191199        context  =  context .copy () if  context  else  {}
@@ -234,6 +242,7 @@ async def validate_object(
234242                        data = value ,
235243                        context = context ,
236244                        branch = branch ,
245+                         default_schema_kind = default_schema_kind ,
237246                    )
238247                )
239248
@@ -248,6 +257,7 @@ async def validate_related_nodes(
248257        data : dict  |  list [dict ],
249258        context : dict  |  None  =  None ,
250259        branch : str  |  None  =  None ,
260+         default_schema_kind : str  |  None  =  None ,
251261    ) ->  list [ObjectValidationError ]:
252262        context  =  context .copy () if  context  else  {}
253263        errors : list [ObjectValidationError ] =  []
@@ -260,7 +270,9 @@ async def validate_related_nodes(
260270
261271        if  isinstance (data , dict ) and  rel_info .format  ==  RelationshipDataFormat .ONE_OBJ :
262272            peer_kind  =  data .get ("kind" ) or  rel_info .peer_kind 
263-             peer_schema  =  await  client .schema .get (kind = peer_kind , branch = branch )
273+             peer_schema  =  await  cls .get_peer_schema (
274+                 client = client , peer_kind = peer_kind , branch = branch , default_schema_kind = default_schema_kind 
275+             )
264276
265277            rel_info .find_matching_relationship (peer_schema = peer_schema )
266278            context .update (rel_info .get_context (value = "placeholder" ))
@@ -273,13 +285,16 @@ async def validate_related_nodes(
273285                    data = data ["data" ],
274286                    context = context ,
275287                    branch = branch ,
288+                     default_schema_kind = default_schema_kind ,
276289                )
277290            )
278291            return  errors 
279292
280293        if  isinstance (data , dict ) and  rel_info .format  ==  RelationshipDataFormat .MANY_OBJ_DICT_LIST :
281294            peer_kind  =  data .get ("kind" ) or  rel_info .peer_kind 
282-             peer_schema  =  await  client .schema .get (kind = peer_kind , branch = branch )
295+             peer_schema  =  await  cls .get_peer_schema (
296+                 client = client , peer_kind = peer_kind , branch = branch , default_schema_kind = default_schema_kind 
297+             )
283298
284299            rel_info .find_matching_relationship (peer_schema = peer_schema )
285300            context .update (rel_info .get_context (value = "placeholder" ))
@@ -294,6 +309,7 @@ async def validate_related_nodes(
294309                        data = peer_data ,
295310                        context = context ,
296311                        branch = branch ,
312+                         default_schema_kind = default_schema_kind ,
297313                    )
298314                )
299315            return  errors 
@@ -302,7 +318,9 @@ async def validate_related_nodes(
302318            for  idx , item  in  enumerate (data ):
303319                context ["list_index" ] =  idx 
304320                peer_kind  =  item .get ("kind" ) or  rel_info .peer_kind 
305-                 peer_schema  =  await  client .schema .get (kind = peer_kind , branch = branch )
321+                 peer_schema  =  await  cls .get_peer_schema (
322+                     client = client , peer_kind = peer_kind , branch = branch , default_schema_kind = default_schema_kind 
323+                 )
306324
307325                rel_info .find_matching_relationship (peer_schema = peer_schema )
308326                context .update (rel_info .get_context (value = "placeholder" ))
@@ -315,6 +333,7 @@ async def validate_related_nodes(
315333                        data = item ["data" ],
316334                        context = context ,
317335                        branch = branch ,
336+                         default_schema_kind = default_schema_kind ,
318337                    )
319338                )
320339            return  errors 
@@ -345,7 +364,13 @@ async def create_node(
345364        context  =  context .copy () if  context  else  {}
346365
347366        errors  =  await  cls .validate_object (
348-             client = client , position = position , schema = schema , data = data , context = context , branch = branch 
367+             client = client ,
368+             position = position ,
369+             schema = schema ,
370+             data = data ,
371+             context = context ,
372+             branch = branch ,
373+             default_schema_kind = default_schema_kind ,
349374        )
350375        if  errors :
351376            messages  =  [str (error ) for  error  in  errors ]
@@ -480,7 +505,9 @@ async def create_related_nodes(
480505
481506        if  isinstance (data , dict ) and  rel_info .format  ==  RelationshipDataFormat .MANY_OBJ_DICT_LIST :
482507            peer_kind  =  data .get ("kind" ) or  rel_info .peer_kind 
483-             peer_schema  =  await  client .schema .get (kind = peer_kind , branch = branch )
508+             peer_schema  =  await  cls .get_peer_schema (
509+                 client = client , peer_kind = peer_kind , branch = branch , default_schema_kind = default_schema_kind 
510+             )
484511
485512            if  parent_node :
486513                rel_info .find_matching_relationship (peer_schema = peer_schema )
@@ -506,7 +533,9 @@ async def create_related_nodes(
506533                context ["list_index" ] =  idx 
507534
508535                peer_kind  =  item .get ("kind" ) or  rel_info .peer_kind 
509-                 peer_schema  =  await  client .schema .get (kind = peer_kind , branch = branch )
536+                 peer_schema  =  await  cls .get_peer_schema (
537+                     client = client , peer_kind = peer_kind , branch = branch , default_schema_kind = default_schema_kind 
538+                 )
510539
511540                if  parent_node :
512541                    rel_info .find_matching_relationship (peer_schema = peer_schema )
@@ -529,6 +558,23 @@ async def create_related_nodes(
529558            f"Relationship { rel_info .rel_schema .name } { rel_info .rel_schema .cardinality } { type (data )}  
530559        )
531560
561+     @classmethod  
562+     async  def  get_peer_schema (
563+         cls , client : InfrahubClient , peer_kind : str , branch : str  |  None  =  None , default_schema_kind : str  |  None  =  None 
564+     ) ->  MainSchemaTypesAPI :
565+         peer_schema  =  await  client .schema .get (kind = peer_kind , branch = branch )
566+         if  not  isinstance (peer_schema , GenericSchemaAPI ):
567+             return  peer_schema 
568+ 
569+         if  not  default_schema_kind :
570+             raise  ValueError (f"Found a peer schema as a generic { peer_kind }  )
571+ 
572+         # if the initial peer_kind was a generic, we try the default_schema_kind 
573+         peer_schema  =  await  client .schema .get (kind = default_schema_kind , branch = branch )
574+         if  isinstance (peer_schema , GenericSchemaAPI ):
575+             raise  ValueError (f"Default schema kind { default_schema_kind }  )
576+         return  peer_schema 
577+ 
532578
533579class  ObjectFile (InfrahubFile ):
534580    _spec : InfrahubObjectFileData  |  None  =  None 
0 commit comments