Skip to content

Commit 94d6199

Browse files
authored
Merge pull request #310 from opsmill/pog-sdk-selective-schema-hash-update-IHS-69
Add parameter to only optionally refresh the schema hash on schema.all
2 parents 79d9749 + e775467 commit 94d6199

File tree

5 files changed

+92
-26
lines changed

5 files changed

+92
-26
lines changed

changelog/152.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add 'schema_hash' parameter to client.schema.all so that the schema is only refreshed if the provided hash differs from the one the client has already cached.

infrahub_sdk/schema/__init__.py

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from collections import defaultdict
54
from collections.abc import MutableMapping
65
from enum import Enum
76
from time import sleep
@@ -22,6 +21,7 @@
2221
from .main import (
2322
AttributeSchema,
2423
AttributeSchemaAPI,
24+
BranchSchema,
2525
BranchSupportType,
2626
GenericSchema,
2727
GenericSchemaAPI,
@@ -169,7 +169,7 @@ def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str:
169169
class InfrahubSchema(InfrahubSchemaBase):
170170
def __init__(self, client: InfrahubClient):
171171
self.client = client
172-
self.cache: dict = defaultdict(lambda: dict)
172+
self.cache: dict[str, BranchSchema] = {}
173173

174174
async def get(
175175
self,
@@ -183,23 +183,27 @@ async def get(
183183
kind_str = self._get_schema_name(schema=kind)
184184

185185
if refresh:
186-
self.cache[branch] = await self.fetch(branch=branch, timeout=timeout)
186+
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
187187

188-
if branch in self.cache and kind_str in self.cache[branch]:
189-
return self.cache[branch][kind_str]
188+
if branch in self.cache and kind_str in self.cache[branch].nodes:
189+
return self.cache[branch].nodes[kind_str]
190190

191191
# Fetching the latest schema from the server if we didn't fetch it earlier
192192
# because we couldn't find the object in the local cache
193193
if not refresh:
194-
self.cache[branch] = await self.fetch(branch=branch, timeout=timeout)
194+
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
195195

196-
if branch in self.cache and kind_str in self.cache[branch]:
197-
return self.cache[branch][kind_str]
196+
if branch in self.cache and kind_str in self.cache[branch].nodes:
197+
return self.cache[branch].nodes[kind_str]
198198

199199
raise SchemaNotFoundError(identifier=kind_str)
200200

201201
async def all(
202-
self, branch: str | None = None, refresh: bool = False, namespaces: list[str] | None = None
202+
self,
203+
branch: str | None = None,
204+
refresh: bool = False,
205+
namespaces: list[str] | None = None,
206+
schema_hash: str | None = None,
203207
) -> MutableMapping[str, MainSchemaTypesAPI]:
204208
"""Retrieve the entire schema for a given branch.
205209
@@ -209,15 +213,19 @@ async def all(
209213
Args:
210214
branch (str, optional): Name of the branch to query. Defaults to default_branch.
211215
refresh (bool, optional): Force a refresh of the schema. Defaults to False.
216+
schema_hash (str, optional): Only refresh if the current schema doesn't match this hash.
212217
213218
Returns:
214219
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
215220
"""
216221
branch = branch or self.client.default_branch
222+
if refresh and branch in self.cache and schema_hash and self.cache[branch].hash == schema_hash:
223+
refresh = False
224+
217225
if refresh or branch not in self.cache:
218-
self.cache[branch] = await self.fetch(branch=branch, namespaces=namespaces)
226+
self.cache[branch] = await self._fetch(branch=branch, namespaces=namespaces)
219227

220-
return self.cache[branch]
228+
return self.cache[branch].nodes
221229

222230
async def load(
223231
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False
@@ -392,11 +400,17 @@ async def fetch(
392400
393401
Args:
394402
branch (str): Name of the branch to fetch the schema for.
395-
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
403+
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
396404
397405
Returns:
398406
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
399407
"""
408+
branch_schema = await self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
409+
return branch_schema.nodes
410+
411+
async def _fetch(
412+
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
413+
) -> BranchSchema:
400414
url_parts = [("branch", branch)]
401415
if namespaces:
402416
url_parts.extend([("namespaces", ns) for ns in namespaces])
@@ -425,16 +439,22 @@ async def fetch(
425439
template = TemplateSchemaAPI(**template_schema)
426440
nodes[template.kind] = template
427441

428-
return nodes
442+
schema_hash = data.get("main", "")
443+
444+
return BranchSchema(hash=schema_hash, nodes=nodes)
429445

430446

431447
class InfrahubSchemaSync(InfrahubSchemaBase):
432448
def __init__(self, client: InfrahubClientSync):
433449
self.client = client
434-
self.cache: dict = defaultdict(lambda: dict)
450+
self.cache: dict[str, BranchSchema] = {}
435451

436452
def all(
437-
self, branch: str | None = None, refresh: bool = False, namespaces: list[str] | None = None
453+
self,
454+
branch: str | None = None,
455+
refresh: bool = False,
456+
namespaces: list[str] | None = None,
457+
schema_hash: str | None = None,
438458
) -> MutableMapping[str, MainSchemaTypesAPI]:
439459
"""Retrieve the entire schema for a given branch.
440460
@@ -444,15 +464,19 @@ def all(
444464
Args:
445465
branch (str, optional): Name of the branch to query. Defaults to default_branch.
446466
refresh (bool, optional): Force a refresh of the schema. Defaults to False.
467+
schema_hash (str, optional): Only refresh if the current schema doesn't match this hash.
447468
448469
Returns:
449470
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
450471
"""
451472
branch = branch or self.client.default_branch
473+
if refresh and branch in self.cache and schema_hash and self.cache[branch].hash == schema_hash:
474+
refresh = False
475+
452476
if refresh or branch not in self.cache:
453-
self.cache[branch] = self.fetch(branch=branch, namespaces=namespaces)
477+
self.cache[branch] = self._fetch(branch=branch, namespaces=namespaces)
454478

455-
return self.cache[branch]
479+
return self.cache[branch].nodes
456480

457481
def get(
458482
self,
@@ -466,18 +490,18 @@ def get(
466490
kind_str = self._get_schema_name(schema=kind)
467491

468492
if refresh:
469-
self.cache[branch] = self.fetch(branch=branch)
493+
self.cache[branch] = self._fetch(branch=branch)
470494

471-
if branch in self.cache and kind_str in self.cache[branch]:
472-
return self.cache[branch][kind_str]
495+
if branch in self.cache and kind_str in self.cache[branch].nodes:
496+
return self.cache[branch].nodes[kind_str]
473497

474498
# Fetching the latest schema from the server if we didn't fetch it earlier
475499
# because we couldn't find the object in the local cache
476500
if not refresh:
477-
self.cache[branch] = self.fetch(branch=branch, timeout=timeout)
501+
self.cache[branch] = self._fetch(branch=branch, timeout=timeout)
478502

479-
if branch in self.cache and kind_str in self.cache[branch]:
480-
return self.cache[branch][kind_str]
503+
if branch in self.cache and kind_str in self.cache[branch].nodes:
504+
return self.cache[branch].nodes[kind_str]
481505

482506
raise SchemaNotFoundError(identifier=kind_str)
483507

@@ -600,17 +624,20 @@ def fetch(
600624
601625
Args:
602626
branch (str): Name of the branch to fetch the schema for.
603-
timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
627+
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
604628
605629
Returns:
606630
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
607631
"""
632+
branch_schema = self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
633+
return branch_schema.nodes
634+
635+
def _fetch(self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None) -> BranchSchema:
608636
url_parts = [("branch", branch)]
609637
if namespaces:
610638
url_parts.extend([("namespaces", ns) for ns in namespaces])
611639
query_params = urlencode(url_parts)
612640
url = f"{self.client.address}/api/schema?{query_params}"
613-
614641
response = self.client._get(url=url, timeout=timeout)
615642
response.raise_for_status()
616643

@@ -633,7 +660,9 @@ def fetch(
633660
template = TemplateSchemaAPI(**template_schema)
634661
nodes[template.kind] = template
635662

636-
return nodes
663+
schema_hash = data.get("main", "")
664+
665+
return BranchSchema(hash=schema_hash, nodes=nodes)
637666

638667
def load(
639668
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False

infrahub_sdk/schema/main.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4+
from collections.abc import MutableMapping
45
from enum import Enum
56
from typing import TYPE_CHECKING, Any, Union
67

@@ -348,3 +349,10 @@ class SchemaRootAPI(BaseModel):
348349
nodes: list[NodeSchemaAPI] = Field(default_factory=list)
349350
profiles: list[ProfileSchemaAPI] = Field(default_factory=list)
350351
templates: list[TemplateSchemaAPI] = Field(default_factory=list)
352+
353+
354+
class BranchSchema(BaseModel):
355+
hash: str = Field(...)
356+
nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = Field(
357+
default_factory=dict
358+
)

tests/fixtures/schema_01.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{
2+
"main": "c0272bc24cd943f21cf30affda06b12d",
23
"nodes": [
34
{
45
"name": "GraphQLQuery",

tests/unit/sdk/test_schema.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,33 @@ async def test_fetch_schema(mock_schema_query_01, client_type):
6464
assert isinstance(nodes["BuiltinTag"], NodeSchemaAPI)
6565

6666

67+
@pytest.mark.parametrize("client_type", client_types)
68+
async def test_fetch_schema_conditional_refresh(mock_schema_query_01: HTTPXMock, client_type: str) -> None:
69+
"""Verify that only one schema request is sent if we request to update the schema but already have the correct hash"""
70+
if client_type == "standard":
71+
client = InfrahubClient(config=Config(address="http://mock", insert_tracker=True))
72+
nodes = await client.schema.all(branch="main")
73+
schema_hash = client.schema.cache["main"].hash
74+
assert schema_hash
75+
nodes = await client.schema.all(branch="main", refresh=True, schema_hash=schema_hash)
76+
else:
77+
client = InfrahubClientSync(config=Config(address="http://mock", insert_tracker=True))
78+
nodes = client.schema.all(branch="main")
79+
schema_hash = client.schema.cache["main"].hash
80+
assert schema_hash
81+
nodes = client.schema.all(branch="main", refresh=True, schema_hash=schema_hash)
82+
83+
assert len(nodes) == 4
84+
assert sorted(nodes.keys()) == [
85+
"BuiltinLocation",
86+
"BuiltinTag",
87+
"CoreGraphQLQuery",
88+
"CoreRepository",
89+
]
90+
assert isinstance(nodes["BuiltinTag"], NodeSchemaAPI)
91+
assert len(mock_schema_query_01.get_requests()) == 1
92+
93+
6794
@pytest.mark.parametrize("client_type", client_types)
6895
async def test_schema_data_validation(rfile_schema, client_type):
6996
if client_type == "standard":

0 commit comments

Comments
 (0)