Skip to content

Commit a96e547

Browse files
authored
Merge pull request #360 from opsmill/dga-20250416-set-schema-cache
Add method to populate the cache of the schema manually
2 parents 96013cf + 11fa9f8 commit a96e547

File tree

8 files changed

+184
-83
lines changed

8 files changed

+184
-83
lines changed

changelog/+schema-fetch.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
By default, schema.fetch will now populate the cache (this behavior can be changed with `populate_cache`)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add method `client.schema.set_cache()` to populate the cache manually (primarily for unit testing)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
The 'timeout' parameter while creating a node or fetching the schema has been deprecated. the default_timeout will be used instead.

infrahub_sdk/checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def client(self, value: InfrahubClient) -> None:
8383
async def init(cls, client: InfrahubClient | None = None, *args: Any, **kwargs: Any) -> InfrahubCheck:
8484
"""Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically."""
8585
warnings.warn(
86-
"InfrahubCheck.init has been deprecated and will be removed in the version in Infrahub SDK 2.0.0",
86+
"InfrahubCheck.init has been deprecated and will be removed in version 2.0.0 of the Infrahub Python SDK",
8787
DeprecationWarning,
8888
stacklevel=1,
8989
)

infrahub_sdk/schema/__init__.py

Lines changed: 89 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import json
5+
import warnings
56
from collections.abc import MutableMapping
67
from enum import Enum
78
from time import sleep
@@ -90,6 +91,13 @@ class EnumMutation(str, Enum):
9091

9192

9293
class InfrahubSchemaBase:
94+
client: InfrahubClient | InfrahubClientSync
95+
cache: dict[str, BranchSchema]
96+
97+
def __init__(self, client: InfrahubClient | InfrahubClientSync):
98+
self.client = client
99+
self.cache = {}
100+
93101
def validate(self, data: dict[str, Any]) -> None:
94102
SchemaRoot(**data)
95103

@@ -102,6 +110,23 @@ def validate_data_against_schema(self, schema: MainSchemaTypesAPI, data: dict) -
102110
message=f"{key} is not a valid value for {identifier}",
103111
)
104112

113+
def set_cache(self, schema: dict[str, Any] | SchemaRootAPI | BranchSchema, branch: str | None = None) -> None:
114+
"""
115+
Set the cache manually (primarily for unit testing)
116+
117+
Args:
118+
schema: The schema to set the cache as provided by the /api/schema endpoint either in dict or SchemaRootAPI format
119+
branch: The name of the branch to set the cache for.
120+
"""
121+
branch = branch or self.client.default_branch
122+
123+
if isinstance(schema, SchemaRootAPI):
124+
schema = BranchSchema.from_schema_root_api(data=schema)
125+
elif isinstance(schema, dict):
126+
schema = BranchSchema.from_api_response(data=schema)
127+
128+
self.cache[branch] = schema
129+
105130
def generate_payload_create(
106131
self,
107132
schema: MainSchemaTypesAPI,
@@ -187,11 +212,18 @@ def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapp
187212

188213
return data
189214

215+
@staticmethod
216+
def _deprecated_schema_timeout() -> None:
217+
warnings.warn(
218+
"The 'timeout' parameter is deprecated while fetching the schema and will be removed version 2.0.0 of the Infrahub Python SDK. "
219+
"Use client.default_timeout instead.",
220+
DeprecationWarning,
221+
stacklevel=2,
222+
)
223+
190224

191225
class InfrahubSchema(InfrahubSchemaBase):
192-
def __init__(self, client: InfrahubClient):
193-
self.client = client
194-
self.cache: dict[str, BranchSchema] = {}
226+
client: InfrahubClient
195227

196228
async def get(
197229
self,
@@ -204,16 +236,19 @@ async def get(
204236

205237
kind_str = self._get_schema_name(schema=kind)
206238

239+
if timeout:
240+
self._deprecated_schema_timeout()
241+
207242
if refresh:
208-
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
243+
self.cache[branch] = await self._fetch(branch=branch)
209244

210245
if branch in self.cache and kind_str in self.cache[branch].nodes:
211246
return self.cache[branch].nodes[kind_str]
212247

213248
# Fetching the latest schema from the server if we didn't fetch it earlier
214249
# because we coulnd't find the object on the local cache
215250
if not refresh:
216-
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
251+
self.cache[branch] = await self._fetch(branch=branch)
217252

218253
if branch in self.cache and kind_str in self.cache[branch].nodes:
219254
return self.cache[branch].nodes[kind_str]
@@ -416,59 +451,45 @@ async def add_dropdown_option(
416451
)
417452

418453
async def fetch(
419-
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
454+
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True
420455
) -> MutableMapping[str, MainSchemaTypesAPI]:
421456
"""Fetch the schema from the server for a given branch.
422457
423458
Args:
424-
branch (str): Name of the branch to fetch the schema for.
425-
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
459+
branch: Name of the branch to fetch the schema for.
460+
timeout: Overrides default timeout used when querying the schema. deprecated.
461+
populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.
426462
427463
Returns:
428464
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
429465
"""
430-
branch_schema = await self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
466+
467+
if timeout:
468+
self._deprecated_schema_timeout()
469+
470+
branch_schema = await self._fetch(branch=branch, namespaces=namespaces)
471+
472+
if populate_cache:
473+
self.cache[branch] = branch_schema
474+
431475
return branch_schema.nodes
432476

433-
async def _fetch(
434-
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
435-
) -> BranchSchema:
477+
async def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema:
436478
url_parts = [("branch", branch)]
437479
if namespaces:
438480
url_parts.extend([("namespaces", ns) for ns in namespaces])
439481
query_params = urlencode(url_parts)
440482
url = f"{self.client.address}/api/schema?{query_params}"
441483

442-
response = await self.client._get(url=url, timeout=timeout)
484+
response = await self.client._get(url=url)
443485

444486
data = self._parse_schema_response(response=response, branch=branch)
445487

446-
nodes: MutableMapping[str, MainSchemaTypesAPI] = {}
447-
for node_schema in data.get("nodes", []):
448-
node = NodeSchemaAPI(**node_schema)
449-
nodes[node.kind] = node
450-
451-
for generic_schema in data.get("generics", []):
452-
generic = GenericSchemaAPI(**generic_schema)
453-
nodes[generic.kind] = generic
454-
455-
for profile_schema in data.get("profiles", []):
456-
profile = ProfileSchemaAPI(**profile_schema)
457-
nodes[profile.kind] = profile
458-
459-
for template_schema in data.get("templates", []):
460-
template = TemplateSchemaAPI(**template_schema)
461-
nodes[template.kind] = template
462-
463-
schema_hash = data.get("main", "")
464-
465-
return BranchSchema(hash=schema_hash, nodes=nodes)
488+
return BranchSchema.from_api_response(data=data)
466489

467490

468491
class InfrahubSchemaSync(InfrahubSchemaBase):
469-
def __init__(self, client: InfrahubClientSync):
470-
self.client = client
471-
self.cache: dict[str, BranchSchema] = {}
492+
client: InfrahubClientSync
472493

473494
def all(
474495
self,
@@ -506,10 +527,25 @@ def get(
506527
refresh: bool = False,
507528
timeout: int | None = None,
508529
) -> MainSchemaTypesAPI:
530+
"""
531+
Retrieve a specific schema object from the server.
532+
533+
Args:
534+
kind: The kind of schema object to retrieve.
535+
branch: The branch to retrieve the schema from.
536+
refresh: Whether to refresh the schema.
537+
timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).
538+
539+
Returns:
540+
MainSchemaTypes: The schema object.
541+
"""
509542
branch = branch or self.client.default_branch
510543

511544
kind_str = self._get_schema_name(schema=kind)
512545

546+
if timeout:
547+
self._deprecated_schema_timeout()
548+
513549
if refresh:
514550
self.cache[branch] = self._fetch(branch=branch)
515551

@@ -519,7 +555,7 @@ def get(
519555
# Fetching the latest schema from the server if we didn't fetch it earlier
520556
# because we coulnd't find the object on the local cache
521557
if not refresh:
522-
self.cache[branch] = self._fetch(branch=branch, timeout=timeout)
558+
self.cache[branch] = self._fetch(branch=branch)
523559

524560
if branch in self.cache and kind_str in self.cache[branch].nodes:
525561
return self.cache[branch].nodes[kind_str]
@@ -639,49 +675,39 @@ def add_dropdown_option(
639675
)
640676

641677
def fetch(
642-
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
678+
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True
643679
) -> MutableMapping[str, MainSchemaTypesAPI]:
644680
"""Fetch the schema from the server for a given branch.
645681
646682
Args:
647-
branch (str): Name of the branch to fetch the schema for.
648-
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
683+
branch: Name of the branch to fetch the schema for.
684+
timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).
685+
populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.
649686
650687
Returns:
651688
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
652689
"""
653-
branch_schema = self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
690+
if timeout:
691+
self._deprecated_schema_timeout()
692+
693+
branch_schema = self._fetch(branch=branch, namespaces=namespaces)
694+
695+
if populate_cache:
696+
self.cache[branch] = branch_schema
697+
654698
return branch_schema.nodes
655699

656-
def _fetch(self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None) -> BranchSchema:
700+
def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema:
657701
url_parts = [("branch", branch)]
658702
if namespaces:
659703
url_parts.extend([("namespaces", ns) for ns in namespaces])
660704
query_params = urlencode(url_parts)
661705
url = f"{self.client.address}/api/schema?{query_params}"
662-
response = self.client._get(url=url, timeout=timeout)
663-
data = self._parse_schema_response(response=response, branch=branch)
706+
response = self.client._get(url=url)
664707

665-
nodes: MutableMapping[str, MainSchemaTypesAPI] = {}
666-
for node_schema in data.get("nodes", []):
667-
node = NodeSchemaAPI(**node_schema)
668-
nodes[node.kind] = node
669-
670-
for generic_schema in data.get("generics", []):
671-
generic = GenericSchemaAPI(**generic_schema)
672-
nodes[generic.kind] = generic
673-
674-
for profile_schema in data.get("profiles", []):
675-
profile = ProfileSchemaAPI(**profile_schema)
676-
nodes[profile.kind] = profile
677-
678-
for template_schema in data.get("templates", []):
679-
template = TemplateSchemaAPI(**template_schema)
680-
nodes[template.kind] = template
681-
682-
schema_hash = data.get("main", "")
708+
data = self._parse_schema_response(response=response, branch=branch)
683709

684-
return BranchSchema(hash=schema_hash, nodes=nodes)
710+
return BranchSchema.from_api_response(data=data)
685711

686712
def load(
687713
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False

infrahub_sdk/schema/main.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING, Any, Union
77

88
from pydantic import BaseModel, ConfigDict, Field
9+
from typing_extensions import Self
910

1011
if TYPE_CHECKING:
1112
from ..node import InfrahubNode, InfrahubNodeSync
@@ -348,7 +349,7 @@ def to_schema_dict(self) -> dict[str, Any]:
348349
class SchemaRootAPI(BaseModel):
349350
model_config = ConfigDict(use_enum_values=True)
350351

351-
version: str
352+
main: str | None = None
352353
generics: list[GenericSchemaAPI] = Field(default_factory=list)
353354
nodes: list[NodeSchemaAPI] = Field(default_factory=list)
354355
profiles: list[ProfileSchemaAPI] = Field(default_factory=list)
@@ -360,3 +361,32 @@ class BranchSchema(BaseModel):
360361
nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = Field(
361362
default_factory=dict
362363
)
364+
365+
@classmethod
366+
def from_api_response(cls, data: MutableMapping[str, Any]) -> Self:
367+
"""
368+
Convert an API response from /api/schema into a BranchSchema object.
369+
"""
370+
return cls.from_schema_root_api(data=SchemaRootAPI(**data))
371+
372+
@classmethod
373+
def from_schema_root_api(cls, data: SchemaRootAPI) -> Self:
374+
"""
375+
Convert a SchemaRootAPI object to a BranchSchema object.
376+
"""
377+
nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = {}
378+
for node in data.nodes:
379+
nodes[node.kind] = node
380+
381+
for generic in data.generics:
382+
nodes[generic.kind] = generic
383+
384+
for profile in data.profiles:
385+
nodes[profile.kind] = profile
386+
387+
for template in data.templates:
388+
nodes[template.kind] = template
389+
390+
schema_hash = data.main or ""
391+
392+
return cls(hash=schema_hash, nodes=nodes)

0 commit comments

Comments
 (0)