22
33import  asyncio 
44import  json 
5+ import  warnings 
56from  collections .abc  import  MutableMapping 
67from  enum  import  Enum 
78from  time  import  sleep 
@@ -90,6 +91,13 @@ class EnumMutation(str, Enum):
9091
9192
9293class  InfrahubSchemaBase :
94+     client : InfrahubClient  |  InfrahubClientSync 
95+     cache : dict [str , BranchSchema ]
96+ 
97+     def  __init__ (self , client : InfrahubClient  |  InfrahubClientSync ):
98+         self .client  =  client 
99+         self .cache  =  {}
100+ 
93101    def  validate (self , data : dict [str , Any ]) ->  None :
94102        SchemaRoot (** data )
95103
@@ -102,6 +110,23 @@ def validate_data_against_schema(self, schema: MainSchemaTypesAPI, data: dict) -
102110                    message = f"{ key } { identifier }  ,
103111                )
104112
113+     def  set_cache (self , schema : dict [str , Any ] |  SchemaRootAPI  |  BranchSchema , branch : str  |  None  =  None ) ->  None :
114+         """ 
115+         Set the cache manually (primarily for unit testing) 
116+ 
117+         Args: 
118+             schema: The schema to set the cache as provided by the /api/schema endpoint either in dict or SchemaRootAPI format 
119+             branch: The name of the branch to set the cache for. 
120+         """ 
121+         branch  =  branch  or  self .client .default_branch 
122+ 
123+         if  isinstance (schema , SchemaRootAPI ):
124+             schema  =  BranchSchema .from_schema_root_api (data = schema )
125+         elif  isinstance (schema , dict ):
126+             schema  =  BranchSchema .from_api_response (data = schema )
127+ 
128+         self .cache [branch ] =  schema 
129+ 
105130    def  generate_payload_create (
106131        self ,
107132        schema : MainSchemaTypesAPI ,
@@ -187,11 +212,18 @@ def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapp
187212
188213        return  data 
189214
215+     @staticmethod  
216+     def  _deprecated_schema_timeout () ->  None :
217+         warnings .warn (
218+             "The 'timeout' parameter is deprecated while fetching the schema and will be removed version 2.0.0 of the Infrahub Python SDK. " 
219+             "Use client.default_timeout instead." ,
220+             DeprecationWarning ,
221+             stacklevel = 2 ,
222+         )
223+ 
190224
191225class  InfrahubSchema (InfrahubSchemaBase ):
192-     def  __init__ (self , client : InfrahubClient ):
193-         self .client  =  client 
194-         self .cache : dict [str , BranchSchema ] =  {}
226+     client : InfrahubClient 
195227
196228    async  def  get (
197229        self ,
@@ -204,16 +236,19 @@ async def get(
204236
205237        kind_str  =  self ._get_schema_name (schema = kind )
206238
239+         if  timeout :
240+             self ._deprecated_schema_timeout ()
241+ 
207242        if  refresh :
208-             self .cache [branch ] =  await  self ._fetch (branch = branch ,  timeout = timeout )
243+             self .cache [branch ] =  await  self ._fetch (branch = branch )
209244
210245        if  branch  in  self .cache  and  kind_str  in  self .cache [branch ].nodes :
211246            return  self .cache [branch ].nodes [kind_str ]
212247
213248        # Fetching the latest schema from the server if we didn't fetch it earlier 
214249        #   because we coulnd't find the object on the local cache 
215250        if  not  refresh :
216-             self .cache [branch ] =  await  self ._fetch (branch = branch ,  timeout = timeout )
251+             self .cache [branch ] =  await  self ._fetch (branch = branch )
217252
218253        if  branch  in  self .cache  and  kind_str  in  self .cache [branch ].nodes :
219254            return  self .cache [branch ].nodes [kind_str ]
@@ -416,59 +451,45 @@ async def add_dropdown_option(
416451        )
417452
418453    async  def  fetch (
419-         self , branch : str , namespaces : list [str ] |  None  =  None , timeout : int  |  None  =  None 
454+         self , branch : str , namespaces : list [str ] |  None  =  None , timeout : int  |  None  =  None ,  populate_cache :  bool   =   True 
420455    ) ->  MutableMapping [str , MainSchemaTypesAPI ]:
421456        """Fetch the schema from the server for a given branch. 
422457
423458        Args: 
424-             branch (str): Name of the branch to fetch the schema for. 
425-             timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds. 
459+             branch: Name of the branch to fetch the schema for. 
460+             timeout: Overrides default timeout used when querying the schema. deprecated. 
461+             populate_cache: Whether to populate the cache with the fetched schema. Defaults to True. 
426462
427463        Returns: 
428464            dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind 
429465        """ 
430-         branch_schema  =  await  self ._fetch (branch = branch , namespaces = namespaces , timeout = timeout )
466+ 
467+         if  timeout :
468+             self ._deprecated_schema_timeout ()
469+ 
470+         branch_schema  =  await  self ._fetch (branch = branch , namespaces = namespaces )
471+ 
472+         if  populate_cache :
473+             self .cache [branch ] =  branch_schema 
474+ 
431475        return  branch_schema .nodes 
432476
433-     async  def  _fetch (
434-         self , branch : str , namespaces : list [str ] |  None  =  None , timeout : int  |  None  =  None 
435-     ) ->  BranchSchema :
477+     async  def  _fetch (self , branch : str , namespaces : list [str ] |  None  =  None ) ->  BranchSchema :
436478        url_parts  =  [("branch" , branch )]
437479        if  namespaces :
438480            url_parts .extend ([("namespaces" , ns ) for  ns  in  namespaces ])
439481        query_params  =  urlencode (url_parts )
440482        url  =  f"{ self .client .address } { query_params }  
441483
442-         response  =  await  self .client ._get (url = url ,  timeout = timeout )
484+         response  =  await  self .client ._get (url = url )
443485
444486        data  =  self ._parse_schema_response (response = response , branch = branch )
445487
446-         nodes : MutableMapping [str , MainSchemaTypesAPI ] =  {}
447-         for  node_schema  in  data .get ("nodes" , []):
448-             node  =  NodeSchemaAPI (** node_schema )
449-             nodes [node .kind ] =  node 
450- 
451-         for  generic_schema  in  data .get ("generics" , []):
452-             generic  =  GenericSchemaAPI (** generic_schema )
453-             nodes [generic .kind ] =  generic 
454- 
455-         for  profile_schema  in  data .get ("profiles" , []):
456-             profile  =  ProfileSchemaAPI (** profile_schema )
457-             nodes [profile .kind ] =  profile 
458- 
459-         for  template_schema  in  data .get ("templates" , []):
460-             template  =  TemplateSchemaAPI (** template_schema )
461-             nodes [template .kind ] =  template 
462- 
463-         schema_hash  =  data .get ("main" , "" )
464- 
465-         return  BranchSchema (hash = schema_hash , nodes = nodes )
488+         return  BranchSchema .from_api_response (data = data )
466489
467490
468491class  InfrahubSchemaSync (InfrahubSchemaBase ):
469-     def  __init__ (self , client : InfrahubClientSync ):
470-         self .client  =  client 
471-         self .cache : dict [str , BranchSchema ] =  {}
492+     client : InfrahubClientSync 
472493
473494    def  all (
474495        self ,
@@ -506,10 +527,25 @@ def get(
506527        refresh : bool  =  False ,
507528        timeout : int  |  None  =  None ,
508529    ) ->  MainSchemaTypesAPI :
530+         """ 
531+         Retrieve a specific schema object from the server. 
532+ 
533+         Args: 
534+             kind: The kind of schema object to retrieve. 
535+             branch: The branch to retrieve the schema from. 
536+             refresh: Whether to refresh the schema. 
537+             timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated). 
538+ 
539+         Returns: 
540+             MainSchemaTypes: The schema object. 
541+         """ 
509542        branch  =  branch  or  self .client .default_branch 
510543
511544        kind_str  =  self ._get_schema_name (schema = kind )
512545
546+         if  timeout :
547+             self ._deprecated_schema_timeout ()
548+ 
513549        if  refresh :
514550            self .cache [branch ] =  self ._fetch (branch = branch )
515551
@@ -519,7 +555,7 @@ def get(
519555        # Fetching the latest schema from the server if we didn't fetch it earlier 
520556        #   because we coulnd't find the object on the local cache 
521557        if  not  refresh :
522-             self .cache [branch ] =  self ._fetch (branch = branch ,  timeout = timeout )
558+             self .cache [branch ] =  self ._fetch (branch = branch )
523559
524560        if  branch  in  self .cache  and  kind_str  in  self .cache [branch ].nodes :
525561            return  self .cache [branch ].nodes [kind_str ]
@@ -639,49 +675,39 @@ def add_dropdown_option(
639675        )
640676
641677    def  fetch (
642-         self , branch : str , namespaces : list [str ] |  None  =  None , timeout : int  |  None  =  None 
678+         self , branch : str , namespaces : list [str ] |  None  =  None , timeout : int  |  None  =  None ,  populate_cache :  bool   =   True 
643679    ) ->  MutableMapping [str , MainSchemaTypesAPI ]:
644680        """Fetch the schema from the server for a given branch. 
645681
646682        Args: 
647-             branch (str): Name of the branch to fetch the schema for. 
648-             timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds. 
683+             branch: Name of the branch to fetch the schema for. 
684+             timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated). 
685+             populate_cache: Whether to populate the cache with the fetched schema. Defaults to True. 
649686
650687        Returns: 
651688            dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind 
652689        """ 
653-         branch_schema  =  self ._fetch (branch = branch , namespaces = namespaces , timeout = timeout )
690+         if  timeout :
691+             self ._deprecated_schema_timeout ()
692+ 
693+         branch_schema  =  self ._fetch (branch = branch , namespaces = namespaces )
694+ 
695+         if  populate_cache :
696+             self .cache [branch ] =  branch_schema 
697+ 
654698        return  branch_schema .nodes 
655699
656-     def  _fetch (self , branch : str , namespaces : list [str ] |  None  =  None ,  timeout :  int   |   None   =   None ) ->  BranchSchema :
700+     def  _fetch (self , branch : str , namespaces : list [str ] |  None  =  None ) ->  BranchSchema :
657701        url_parts  =  [("branch" , branch )]
658702        if  namespaces :
659703            url_parts .extend ([("namespaces" , ns ) for  ns  in  namespaces ])
660704        query_params  =  urlencode (url_parts )
661705        url  =  f"{ self .client .address } { query_params }  
662-         response  =  self .client ._get (url = url , timeout = timeout )
663-         data  =  self ._parse_schema_response (response = response , branch = branch )
706+         response  =  self .client ._get (url = url )
664707
665-         nodes : MutableMapping [str , MainSchemaTypesAPI ] =  {}
666-         for  node_schema  in  data .get ("nodes" , []):
667-             node  =  NodeSchemaAPI (** node_schema )
668-             nodes [node .kind ] =  node 
669- 
670-         for  generic_schema  in  data .get ("generics" , []):
671-             generic  =  GenericSchemaAPI (** generic_schema )
672-             nodes [generic .kind ] =  generic 
673- 
674-         for  profile_schema  in  data .get ("profiles" , []):
675-             profile  =  ProfileSchemaAPI (** profile_schema )
676-             nodes [profile .kind ] =  profile 
677- 
678-         for  template_schema  in  data .get ("templates" , []):
679-             template  =  TemplateSchemaAPI (** template_schema )
680-             nodes [template .kind ] =  template 
681- 
682-         schema_hash  =  data .get ("main" , "" )
708+         data  =  self ._parse_schema_response (response = response , branch = branch )
683709
684-         return  BranchSchema ( hash = schema_hash ,  nodes = nodes )
710+         return  BranchSchema . from_api_response ( data = data )
685711
686712    def  load (
687713        self , schemas : list [dict ], branch : str  |  None  =  None , wait_until_converged : bool  =  False 
0 commit comments