22
33import asyncio
44import json
5+ import warnings
56from collections .abc import MutableMapping
67from enum import Enum
78from time import sleep
@@ -90,6 +91,13 @@ class EnumMutation(str, Enum):
9091
9192
9293class InfrahubSchemaBase :
94+ client : InfrahubClient | InfrahubClientSync
95+ cache : dict [str , BranchSchema ]
96+
97+ def __init__ (self , client : InfrahubClient | InfrahubClientSync ):
98+ self .client = client
99+ self .cache = {}
100+
93101 def validate (self , data : dict [str , Any ]) -> None :
94102 SchemaRoot (** data )
95103
@@ -102,6 +110,23 @@ def validate_data_against_schema(self, schema: MainSchemaTypesAPI, data: dict) -
102110 message = f"{ key } is not a valid value for { identifier } " ,
103111 )
104112
113+ def set_cache (self , schema : dict [str , Any ] | SchemaRootAPI | BranchSchema , branch : str | None = None ) -> None :
114+ """
115+ Set the cache manually (primarily for unit testing)
116+
117+ Args:
118+ schema: The schema to set the cache as provided by the /api/schema endpoint either in dict or SchemaRootAPI format
119+ branch: The name of the branch to set the cache for.
120+ """
121+ branch = branch or self .client .default_branch
122+
123+ if isinstance (schema , SchemaRootAPI ):
124+ schema = BranchSchema .from_schema_root_api (data = schema )
125+ elif isinstance (schema , dict ):
126+ schema = BranchSchema .from_api_response (data = schema )
127+
128+ self .cache [branch ] = schema
129+
105130 def generate_payload_create (
106131 self ,
107132 schema : MainSchemaTypesAPI ,
@@ -187,11 +212,18 @@ def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapp
187212
188213 return data
189214
215+ @staticmethod
216+ def _deprecated_schema_timeout () -> None :
217+ warnings .warn (
218+ "The 'timeout' parameter is deprecated while fetching the schema and will be removed version 2.0.0 of the Infrahub Python SDK. "
219+ "Use client.default_timeout instead." ,
220+ DeprecationWarning ,
221+ stacklevel = 2 ,
222+ )
223+
190224
191225class InfrahubSchema (InfrahubSchemaBase ):
192- def __init__ (self , client : InfrahubClient ):
193- self .client = client
194- self .cache : dict [str , BranchSchema ] = {}
226+ client : InfrahubClient
195227
196228 async def get (
197229 self ,
@@ -204,16 +236,19 @@ async def get(
204236
205237 kind_str = self ._get_schema_name (schema = kind )
206238
239+ if timeout :
240+ self ._deprecated_schema_timeout ()
241+
207242 if refresh :
208- self .cache [branch ] = await self ._fetch (branch = branch , timeout = timeout )
243+ self .cache [branch ] = await self ._fetch (branch = branch )
209244
210245 if branch in self .cache and kind_str in self .cache [branch ].nodes :
211246 return self .cache [branch ].nodes [kind_str ]
212247
213248 # Fetching the latest schema from the server if we didn't fetch it earlier
214249 # because we coulnd't find the object on the local cache
215250 if not refresh :
216- self .cache [branch ] = await self ._fetch (branch = branch , timeout = timeout )
251+ self .cache [branch ] = await self ._fetch (branch = branch )
217252
218253 if branch in self .cache and kind_str in self .cache [branch ].nodes :
219254 return self .cache [branch ].nodes [kind_str ]
@@ -416,59 +451,45 @@ async def add_dropdown_option(
416451 )
417452
418453 async def fetch (
419- self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None
454+ self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None , populate_cache : bool = True
420455 ) -> MutableMapping [str , MainSchemaTypesAPI ]:
421456 """Fetch the schema from the server for a given branch.
422457
423458 Args:
424- branch (str): Name of the branch to fetch the schema for.
425- timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
459+ branch: Name of the branch to fetch the schema for.
460+ timeout: Overrides default timeout used when querying the schema. deprecated.
461+ populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.
426462
427463 Returns:
428464 dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
429465 """
430- branch_schema = await self ._fetch (branch = branch , namespaces = namespaces , timeout = timeout )
466+
467+ if timeout :
468+ self ._deprecated_schema_timeout ()
469+
470+ branch_schema = await self ._fetch (branch = branch , namespaces = namespaces )
471+
472+ if populate_cache :
473+ self .cache [branch ] = branch_schema
474+
431475 return branch_schema .nodes
432476
433- async def _fetch (
434- self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None
435- ) -> BranchSchema :
477+ async def _fetch (self , branch : str , namespaces : list [str ] | None = None ) -> BranchSchema :
436478 url_parts = [("branch" , branch )]
437479 if namespaces :
438480 url_parts .extend ([("namespaces" , ns ) for ns in namespaces ])
439481 query_params = urlencode (url_parts )
440482 url = f"{ self .client .address } /api/schema?{ query_params } "
441483
442- response = await self .client ._get (url = url , timeout = timeout )
484+ response = await self .client ._get (url = url )
443485
444486 data = self ._parse_schema_response (response = response , branch = branch )
445487
446- nodes : MutableMapping [str , MainSchemaTypesAPI ] = {}
447- for node_schema in data .get ("nodes" , []):
448- node = NodeSchemaAPI (** node_schema )
449- nodes [node .kind ] = node
450-
451- for generic_schema in data .get ("generics" , []):
452- generic = GenericSchemaAPI (** generic_schema )
453- nodes [generic .kind ] = generic
454-
455- for profile_schema in data .get ("profiles" , []):
456- profile = ProfileSchemaAPI (** profile_schema )
457- nodes [profile .kind ] = profile
458-
459- for template_schema in data .get ("templates" , []):
460- template = TemplateSchemaAPI (** template_schema )
461- nodes [template .kind ] = template
462-
463- schema_hash = data .get ("main" , "" )
464-
465- return BranchSchema (hash = schema_hash , nodes = nodes )
488+ return BranchSchema .from_api_response (data = data )
466489
467490
468491class InfrahubSchemaSync (InfrahubSchemaBase ):
469- def __init__ (self , client : InfrahubClientSync ):
470- self .client = client
471- self .cache : dict [str , BranchSchema ] = {}
492+ client : InfrahubClientSync
472493
473494 def all (
474495 self ,
@@ -506,10 +527,25 @@ def get(
506527 refresh : bool = False ,
507528 timeout : int | None = None ,
508529 ) -> MainSchemaTypesAPI :
530+ """
531+ Retrieve a specific schema object from the server.
532+
533+ Args:
534+ kind: The kind of schema object to retrieve.
535+ branch: The branch to retrieve the schema from.
536+ refresh: Whether to refresh the schema.
537+ timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).
538+
539+ Returns:
540+ MainSchemaTypes: The schema object.
541+ """
509542 branch = branch or self .client .default_branch
510543
511544 kind_str = self ._get_schema_name (schema = kind )
512545
546+ if timeout :
547+ self ._deprecated_schema_timeout ()
548+
513549 if refresh :
514550 self .cache [branch ] = self ._fetch (branch = branch )
515551
@@ -519,7 +555,7 @@ def get(
519555 # Fetching the latest schema from the server if we didn't fetch it earlier
520556 # because we coulnd't find the object on the local cache
521557 if not refresh :
522- self .cache [branch ] = self ._fetch (branch = branch , timeout = timeout )
558+ self .cache [branch ] = self ._fetch (branch = branch )
523559
524560 if branch in self .cache and kind_str in self .cache [branch ].nodes :
525561 return self .cache [branch ].nodes [kind_str ]
@@ -639,49 +675,39 @@ def add_dropdown_option(
639675 )
640676
641677 def fetch (
642- self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None
678+ self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None , populate_cache : bool = True
643679 ) -> MutableMapping [str , MainSchemaTypesAPI ]:
644680 """Fetch the schema from the server for a given branch.
645681
646682 Args:
647- branch (str): Name of the branch to fetch the schema for.
648- timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
683+ branch: Name of the branch to fetch the schema for.
684+ timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).
685+ populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.
649686
650687 Returns:
651688 dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
652689 """
653- branch_schema = self ._fetch (branch = branch , namespaces = namespaces , timeout = timeout )
690+ if timeout :
691+ self ._deprecated_schema_timeout ()
692+
693+ branch_schema = self ._fetch (branch = branch , namespaces = namespaces )
694+
695+ if populate_cache :
696+ self .cache [branch ] = branch_schema
697+
654698 return branch_schema .nodes
655699
656- def _fetch (self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None ) -> BranchSchema :
700+ def _fetch (self , branch : str , namespaces : list [str ] | None = None ) -> BranchSchema :
657701 url_parts = [("branch" , branch )]
658702 if namespaces :
659703 url_parts .extend ([("namespaces" , ns ) for ns in namespaces ])
660704 query_params = urlencode (url_parts )
661705 url = f"{ self .client .address } /api/schema?{ query_params } "
662- response = self .client ._get (url = url , timeout = timeout )
663- data = self ._parse_schema_response (response = response , branch = branch )
706+ response = self .client ._get (url = url )
664707
665- nodes : MutableMapping [str , MainSchemaTypesAPI ] = {}
666- for node_schema in data .get ("nodes" , []):
667- node = NodeSchemaAPI (** node_schema )
668- nodes [node .kind ] = node
669-
670- for generic_schema in data .get ("generics" , []):
671- generic = GenericSchemaAPI (** generic_schema )
672- nodes [generic .kind ] = generic
673-
674- for profile_schema in data .get ("profiles" , []):
675- profile = ProfileSchemaAPI (** profile_schema )
676- nodes [profile .kind ] = profile
677-
678- for template_schema in data .get ("templates" , []):
679- template = TemplateSchemaAPI (** template_schema )
680- nodes [template .kind ] = template
681-
682- schema_hash = data .get ("main" , "" )
708+ data = self ._parse_schema_response (response = response , branch = branch )
683709
684- return BranchSchema ( hash = schema_hash , nodes = nodes )
710+ return BranchSchema . from_api_response ( data = data )
685711
686712 def load (
687713 self , schemas : list [dict ], branch : str | None = None , wait_until_converged : bool = False
0 commit comments