11from __future__ import annotations
22
33import asyncio
4- from collections import defaultdict
54from collections .abc import MutableMapping
65from enum import Enum
76from time import sleep
2221from .main import (
2322 AttributeSchema ,
2423 AttributeSchemaAPI ,
24+ BranchSchema ,
2525 BranchSupportType ,
2626 GenericSchema ,
2727 GenericSchemaAPI ,
@@ -169,7 +169,7 @@ def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str:
169169class InfrahubSchema (InfrahubSchemaBase ):
170170 def __init__ (self , client : InfrahubClient ):
171171 self .client = client
172- self .cache : dict = defaultdict ( lambda : dict )
172+ self .cache : dict [ str , BranchSchema ] = {}
173173
174174 async def get (
175175 self ,
@@ -183,23 +183,27 @@ async def get(
183183 kind_str = self ._get_schema_name (schema = kind )
184184
185185 if refresh :
186- self .cache [branch ] = await self .fetch (branch = branch , timeout = timeout )
186+ self .cache [branch ] = await self ._fetch (branch = branch , timeout = timeout )
187187
188- if branch in self .cache and kind_str in self .cache [branch ]:
189- return self .cache [branch ][kind_str ]
188+ if branch in self .cache and kind_str in self .cache [branch ]. nodes :
189+ return self .cache [branch ]. nodes [kind_str ]
190190
191191 # Fetching the latest schema from the server if we didn't fetch it earlier
192192 # because we coulnd't find the object on the local cache
193193 if not refresh :
194- self .cache [branch ] = await self .fetch (branch = branch , timeout = timeout )
194+ self .cache [branch ] = await self ._fetch (branch = branch , timeout = timeout )
195195
196- if branch in self .cache and kind_str in self .cache [branch ]:
197- return self .cache [branch ][kind_str ]
196+ if branch in self .cache and kind_str in self .cache [branch ]. nodes :
197+ return self .cache [branch ]. nodes [kind_str ]
198198
199199 raise SchemaNotFoundError (identifier = kind_str )
200200
201201 async def all (
202- self , branch : str | None = None , refresh : bool = False , namespaces : list [str ] | None = None
202+ self ,
203+ branch : str | None = None ,
204+ refresh : bool = False ,
205+ namespaces : list [str ] | None = None ,
206+ schema_hash : str | None = None ,
203207 ) -> MutableMapping [str , MainSchemaTypesAPI ]:
204208 """Retrieve the entire schema for a given branch.
205209
@@ -209,15 +213,19 @@ async def all(
209213 Args:
210214 branch (str, optional): Name of the branch to query. Defaults to default_branch.
211215 refresh (bool, optional): Force a refresh of the schema. Defaults to False.
216+ schema_hash (str, optional): Only refresh if the current schema doesn't match this hash.
212217
213218 Returns:
214219 dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
215220 """
216221 branch = branch or self .client .default_branch
222+ if refresh and branch in self .cache and schema_hash and self .cache [branch ].hash == schema_hash :
223+ refresh = False
224+
217225 if refresh or branch not in self .cache :
218- self .cache [branch ] = await self .fetch (branch = branch , namespaces = namespaces )
226+ self .cache [branch ] = await self ._fetch (branch = branch , namespaces = namespaces )
219227
220- return self .cache [branch ]
228+ return self .cache [branch ]. nodes
221229
222230 async def load (
223231 self , schemas : list [dict ], branch : str | None = None , wait_until_converged : bool = False
@@ -392,11 +400,17 @@ async def fetch(
392400
393401 Args:
394402 branch (str): Name of the branch to fetch the schema for.
395- timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
403+ timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
396404
397405 Returns:
398406 dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
399407 """
408+ branch_schema = await self ._fetch (branch = branch , namespaces = namespaces , timeout = timeout )
409+ return branch_schema .nodes
410+
411+ async def _fetch (
412+ self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None
413+ ) -> BranchSchema :
400414 url_parts = [("branch" , branch )]
401415 if namespaces :
402416 url_parts .extend ([("namespaces" , ns ) for ns in namespaces ])
@@ -425,16 +439,22 @@ async def fetch(
425439 template = TemplateSchemaAPI (** template_schema )
426440 nodes [template .kind ] = template
427441
428- return nodes
442+ schema_hash = data .get ("main" , "" )
443+
444+ return BranchSchema (hash = schema_hash , nodes = nodes )
429445
430446
431447class InfrahubSchemaSync (InfrahubSchemaBase ):
432448 def __init__ (self , client : InfrahubClientSync ):
433449 self .client = client
434- self .cache : dict = defaultdict ( lambda : dict )
450+ self .cache : dict [ str , BranchSchema ] = {}
435451
436452 def all (
437- self , branch : str | None = None , refresh : bool = False , namespaces : list [str ] | None = None
453+ self ,
454+ branch : str | None = None ,
455+ refresh : bool = False ,
456+ namespaces : list [str ] | None = None ,
457+ schema_hash : str | None = None ,
438458 ) -> MutableMapping [str , MainSchemaTypesAPI ]:
439459 """Retrieve the entire schema for a given branch.
440460
@@ -444,15 +464,19 @@ def all(
444464 Args:
445465 branch (str, optional): Name of the branch to query. Defaults to default_branch.
446466 refresh (bool, optional): Force a refresh of the schema. Defaults to False.
467+ schema_hash (str, optional): Only refresh if the current schema doesn't match this hash.
447468
448469 Returns:
449470 dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
450471 """
451472 branch = branch or self .client .default_branch
473+ if refresh and branch in self .cache and schema_hash and self .cache [branch ].hash == schema_hash :
474+ refresh = False
475+
452476 if refresh or branch not in self .cache :
453- self .cache [branch ] = self .fetch (branch = branch , namespaces = namespaces )
477+ self .cache [branch ] = self ._fetch (branch = branch , namespaces = namespaces )
454478
455- return self .cache [branch ]
479+ return self .cache [branch ]. nodes
456480
457481 def get (
458482 self ,
@@ -466,18 +490,18 @@ def get(
466490 kind_str = self ._get_schema_name (schema = kind )
467491
468492 if refresh :
469- self .cache [branch ] = self .fetch (branch = branch )
493+ self .cache [branch ] = self ._fetch (branch = branch )
470494
471- if branch in self .cache and kind_str in self .cache [branch ]:
472- return self .cache [branch ][kind_str ]
495+ if branch in self .cache and kind_str in self .cache [branch ]. nodes :
496+ return self .cache [branch ]. nodes [kind_str ]
473497
474498 # Fetching the latest schema from the server if we didn't fetch it earlier
475499 # because we coulnd't find the object on the local cache
476500 if not refresh :
477- self .cache [branch ] = self .fetch (branch = branch , timeout = timeout )
501+ self .cache [branch ] = self ._fetch (branch = branch , timeout = timeout )
478502
479- if branch in self .cache and kind_str in self .cache [branch ]:
480- return self .cache [branch ][kind_str ]
503+ if branch in self .cache and kind_str in self .cache [branch ]. nodes :
504+ return self .cache [branch ]. nodes [kind_str ]
481505
482506 raise SchemaNotFoundError (identifier = kind_str )
483507
@@ -600,17 +624,20 @@ def fetch(
600624
601625 Args:
602626 branch (str): Name of the branch to fetch the schema for.
603- timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
627+ timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
604628
605629 Returns:
606630 dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
607631 """
632+ branch_schema = self ._fetch (branch = branch , namespaces = namespaces , timeout = timeout )
633+ return branch_schema .nodes
634+
635+ def _fetch (self , branch : str , namespaces : list [str ] | None = None , timeout : int | None = None ) -> BranchSchema :
608636 url_parts = [("branch" , branch )]
609637 if namespaces :
610638 url_parts .extend ([("namespaces" , ns ) for ns in namespaces ])
611639 query_params = urlencode (url_parts )
612640 url = f"{ self .client .address } /api/schema?{ query_params } "
613-
614641 response = self .client ._get (url = url , timeout = timeout )
615642 response .raise_for_status ()
616643
@@ -633,7 +660,9 @@ def fetch(
633660 template = TemplateSchemaAPI (** template_schema )
634661 nodes [template .kind ] = template
635662
636- return nodes
663+ schema_hash = data .get ("main" , "" )
664+
665+ return BranchSchema (hash = schema_hash , nodes = nodes )
637666
638667 def load (
639668 self , schemas : list [dict ], branch : str | None = None , wait_until_converged : bool = False
0 commit comments