22
33import asyncio
44import json
5+ import warnings
56from collections .abc import MutableMapping
67from enum import Enum
78from time import sleep
@@ -90,6 +91,13 @@ class EnumMutation(str, Enum):
9091
9192
9293class InfrahubSchemaBase :
94+ client : InfrahubClient | InfrahubClientSync
95+ cache : dict [str , BranchSchema ]
96+
97+ def __init__ (self , client : InfrahubClient | InfrahubClientSync ):
98+ self .client = client
99+ self .cache = {}
100+
93101 def validate (self , data : dict [str , Any ]) -> None :
94102 SchemaRoot (** data )
95103
@@ -102,6 +110,14 @@ def validate_data_against_schema(self, schema: MainSchemaTypesAPI, data: dict) -
102110 message = f"{ key } is not a valid value for { identifier } " ,
103111 )
104112
113+ def set_cache (self , schema : dict [str , Any ] | BranchSchema , branch : str | None = None ) -> None :
114+ branch = branch or self .client .default_branch
115+
116+ if isinstance (schema , dict ):
117+ schema = BranchSchema .from_api_response (data = schema )
118+
119+ self .cache [branch ] = schema
120+
105121 def generate_payload_create (
106122 self ,
107123 schema : MainSchemaTypesAPI ,
@@ -187,11 +203,18 @@ def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapp
187203
188204 return data
189205
206+ @staticmethod
207+ def _deprecated_schema_timeout () -> None :
208+ warnings .warn (
209+ "The 'timeout' parameter is deprecated while fetching the schema and will be removed version 2.0.0 of the Infrahub Python SDK. "
210+ "Use client.default_timeout instead." ,
211+ DeprecationWarning ,
212+ stacklevel = 2 ,
213+ )
214+
190215
191216class InfrahubSchema (InfrahubSchemaBase ):
192- def __init__ (self , client : InfrahubClient ):
193- self .client = client
194- self .cache : dict [str , BranchSchema ] = {}
217+ client : InfrahubClient
195218
196219 async def get (
197220 self ,
@@ -204,16 +227,19 @@ async def get(
204227
205228 kind_str = self ._get_schema_name (schema = kind )
206229
230+ if timeout :
231+ self ._deprecated_schema_timeout ()
232+
207233 if refresh :
208- self .cache [branch ] = await self ._fetch (branch = branch , timeout = timeout )
234+ self .cache [branch ] = await self ._fetch (branch = branch )
209235
210236 if branch in self .cache and kind_str in self .cache [branch ].nodes :
211237 return self .cache [branch ].nodes [kind_str ]
212238
213239 # Fetching the latest schema from the server if we didn't fetch it earlier
214240 # because we coulnd't find the object on the local cache
215241 if not refresh :
216- self .cache [branch ] = await self ._fetch (branch = branch , timeout = timeout )
242+ self .cache [branch ] = await self ._fetch (branch = branch )
217243
218244 if branch in self .cache and kind_str in self .cache [branch ].nodes :
219245 return self .cache [branch ].nodes [kind_str ]
@@ -416,59 +442,45 @@ async def add_dropdown_option(
416442 )
417443
418444 async def fetch (
419- self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None
445+ self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None , populate_cache : bool = True
420446 ) -> MutableMapping [str , MainSchemaTypesAPI ]:
421447 """Fetch the schema from the server for a given branch.
422448
423449 Args:
424- branch (str): Name of the branch to fetch the schema for.
425- timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
450+ branch: Name of the branch to fetch the schema for.
451+ timeout: Overrides default timeout used when querying the schema. deprecated.
452+ populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.
426453
427454 Returns:
428455 dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
429456 """
430- branch_schema = await self ._fetch (branch = branch , namespaces = namespaces , timeout = timeout )
457+
458+ if timeout :
459+ self ._deprecated_schema_timeout ()
460+
461+ branch_schema = await self ._fetch (branch = branch , namespaces = namespaces )
462+
463+ if populate_cache :
464+ self .cache [branch ] = branch_schema
465+
431466 return branch_schema .nodes
432467
433- async def _fetch (
434- self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None
435- ) -> BranchSchema :
468+ async def _fetch (self , branch : str , namespaces : list [str ] | None = None ) -> BranchSchema :
436469 url_parts = [("branch" , branch )]
437470 if namespaces :
438471 url_parts .extend ([("namespaces" , ns ) for ns in namespaces ])
439472 query_params = urlencode (url_parts )
440473 url = f"{ self .client .address } /api/schema?{ query_params } "
441474
442- response = await self .client ._get (url = url , timeout = timeout )
475+ response = await self .client ._get (url = url )
443476
444477 data = self ._parse_schema_response (response = response , branch = branch )
445478
446- nodes : MutableMapping [str , MainSchemaTypesAPI ] = {}
447- for node_schema in data .get ("nodes" , []):
448- node = NodeSchemaAPI (** node_schema )
449- nodes [node .kind ] = node
450-
451- for generic_schema in data .get ("generics" , []):
452- generic = GenericSchemaAPI (** generic_schema )
453- nodes [generic .kind ] = generic
454-
455- for profile_schema in data .get ("profiles" , []):
456- profile = ProfileSchemaAPI (** profile_schema )
457- nodes [profile .kind ] = profile
458-
459- for template_schema in data .get ("templates" , []):
460- template = TemplateSchemaAPI (** template_schema )
461- nodes [template .kind ] = template
462-
463- schema_hash = data .get ("main" , "" )
464-
465- return BranchSchema (hash = schema_hash , nodes = nodes )
479+ return BranchSchema .from_api_response (data = data )
466480
467481
468482class InfrahubSchemaSync (InfrahubSchemaBase ):
469- def __init__ (self , client : InfrahubClientSync ):
470- self .client = client
471- self .cache : dict [str , BranchSchema ] = {}
483+ client : InfrahubClientSync
472484
473485 def all (
474486 self ,
@@ -506,10 +518,25 @@ def get(
506518 refresh : bool = False ,
507519 timeout : int | None = None ,
508520 ) -> MainSchemaTypesAPI :
521+ """
522+ Retrieve a specific schema object from the server.
523+
524+ Args:
525+ kind: The kind of schema object to retrieve.
526+ branch: The branch to retrieve the schema from.
527+ refresh: Whether to refresh the schema.
528+ timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).
529+
530+ Returns:
531+ MainSchemaTypes: The schema object.
532+ """
509533 branch = branch or self .client .default_branch
510534
511535 kind_str = self ._get_schema_name (schema = kind )
512536
537+ if timeout :
538+ self ._deprecated_schema_timeout ()
539+
513540 if refresh :
514541 self .cache [branch ] = self ._fetch (branch = branch )
515542
@@ -519,7 +546,7 @@ def get(
519546 # Fetching the latest schema from the server if we didn't fetch it earlier
520547 # because we coulnd't find the object on the local cache
521548 if not refresh :
522- self .cache [branch ] = self ._fetch (branch = branch , timeout = timeout )
549+ self .cache [branch ] = self ._fetch (branch = branch )
523550
524551 if branch in self .cache and kind_str in self .cache [branch ].nodes :
525552 return self .cache [branch ].nodes [kind_str ]
@@ -639,49 +666,39 @@ def add_dropdown_option(
639666 )
640667
641668 def fetch (
642- self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None
669+ self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None , populate_cache : bool = True
643670 ) -> MutableMapping [str , MainSchemaTypesAPI ]:
644671 """Fetch the schema from the server for a given branch.
645672
646673 Args:
647- branch (str): Name of the branch to fetch the schema for.
648- timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
674+ branch: Name of the branch to fetch the schema for.
675+ timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).
676+ populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.
649677
650678 Returns:
651679 dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
652680 """
653- branch_schema = self ._fetch (branch = branch , namespaces = namespaces , timeout = timeout )
681+ if timeout :
682+ self ._deprecated_schema_timeout ()
683+
684+ branch_schema = self ._fetch (branch = branch , namespaces = namespaces )
685+
686+ if populate_cache :
687+ self .cache [branch ] = branch_schema
688+
654689 return branch_schema .nodes
655690
656- def _fetch (self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None ) -> BranchSchema :
691+ def _fetch (self , branch : str , namespaces : list [str ] | None = None ) -> BranchSchema :
657692 url_parts = [("branch" , branch )]
658693 if namespaces :
659694 url_parts .extend ([("namespaces" , ns ) for ns in namespaces ])
660695 query_params = urlencode (url_parts )
661696 url = f"{ self .client .address } /api/schema?{ query_params } "
662- response = self .client ._get (url = url , timeout = timeout )
663- data = self ._parse_schema_response (response = response , branch = branch )
664-
665- nodes : MutableMapping [str , MainSchemaTypesAPI ] = {}
666- for node_schema in data .get ("nodes" , []):
667- node = NodeSchemaAPI (** node_schema )
668- nodes [node .kind ] = node
697+ response = self .client ._get (url = url )
669698
670- for generic_schema in data .get ("generics" , []):
671- generic = GenericSchemaAPI (** generic_schema )
672- nodes [generic .kind ] = generic
673-
674- for profile_schema in data .get ("profiles" , []):
675- profile = ProfileSchemaAPI (** profile_schema )
676- nodes [profile .kind ] = profile
677-
678- for template_schema in data .get ("templates" , []):
679- template = TemplateSchemaAPI (** template_schema )
680- nodes [template .kind ] = template
681-
682- schema_hash = data .get ("main" , "" )
699+ data = self ._parse_schema_response (response = response , branch = branch )
683700
684- return BranchSchema ( hash = schema_hash , nodes = nodes )
701+ return BranchSchema . from_api_response ( data = data )
685702
686703 def load (
687704 self , schemas : list [dict ], branch : str | None = None , wait_until_converged : bool = False
0 commit comments