Skip to content

Commit 711e2cc

Browse files
committed
Add method schema.set_cache() to populate the cache manually (primarily for unit testing)
1 parent c72da0c commit 711e2cc

File tree

8 files changed

+168
-82
lines changed

8 files changed

+168
-82
lines changed

changelog/+schema-fetch.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
By default, schema.fetch will now populate the cache (this behavior can be changed with `populate_cache`)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add method schema.set_cache() to populate the cache manually (primarily for unit testing)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
The 'timeout' parameter while fetching the schema has been deprecated. the default_timeout will be used instead.

infrahub_sdk/checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def client(self, value: InfrahubClient) -> None:
8383
async def init(cls, client: InfrahubClient | None = None, *args: Any, **kwargs: Any) -> InfrahubCheck:
8484
"""Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically."""
8585
warnings.warn(
86-
"InfrahubCheck.init has been deprecated and will be removed in the version in Infrahub SDK 2.0.0",
86+
"InfrahubCheck.init has been deprecated and will be removed in version 2.0.0 of the Infrahub Python SDK",
8787
DeprecationWarning,
8888
stacklevel=1,
8989
)

infrahub_sdk/schema/__init__.py

Lines changed: 80 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import json
5+
import warnings
56
from collections.abc import MutableMapping
67
from enum import Enum
78
from time import sleep
@@ -90,6 +91,13 @@ class EnumMutation(str, Enum):
9091

9192

9293
class InfrahubSchemaBase:
94+
client: InfrahubClient | InfrahubClientSync
95+
cache: dict[str, BranchSchema]
96+
97+
def __init__(self, client: InfrahubClient | InfrahubClientSync):
98+
self.client = client
99+
self.cache = {}
100+
93101
def validate(self, data: dict[str, Any]) -> None:
94102
SchemaRoot(**data)
95103

@@ -102,6 +110,14 @@ def validate_data_against_schema(self, schema: MainSchemaTypesAPI, data: dict) -
102110
message=f"{key} is not a valid value for {identifier}",
103111
)
104112

113+
def set_cache(self, schema: dict[str, Any] | BranchSchema, branch: str | None = None) -> None:
114+
branch = branch or self.client.default_branch
115+
116+
if isinstance(schema, dict):
117+
schema = BranchSchema.from_api_response(data=schema)
118+
119+
self.cache[branch] = schema
120+
105121
def generate_payload_create(
106122
self,
107123
schema: MainSchemaTypesAPI,
@@ -187,11 +203,18 @@ def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapp
187203

188204
return data
189205

206+
@staticmethod
207+
def _deprecated_schema_timeout() -> None:
208+
warnings.warn(
209+
"The 'timeout' parameter is deprecated while fetching the schema and will be removed version 2.0.0 of the Infrahub Python SDK. "
210+
"Use client.default_timeout instead.",
211+
DeprecationWarning,
212+
stacklevel=2,
213+
)
214+
190215

191216
class InfrahubSchema(InfrahubSchemaBase):
192-
def __init__(self, client: InfrahubClient):
193-
self.client = client
194-
self.cache: dict[str, BranchSchema] = {}
217+
client: InfrahubClient
195218

196219
async def get(
197220
self,
@@ -204,16 +227,19 @@ async def get(
204227

205228
kind_str = self._get_schema_name(schema=kind)
206229

230+
if timeout:
231+
self._deprecated_schema_timeout()
232+
207233
if refresh:
208-
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
234+
self.cache[branch] = await self._fetch(branch=branch)
209235

210236
if branch in self.cache and kind_str in self.cache[branch].nodes:
211237
return self.cache[branch].nodes[kind_str]
212238

213239
# Fetching the latest schema from the server if we didn't fetch it earlier
214240
# because we coulnd't find the object on the local cache
215241
if not refresh:
216-
self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
242+
self.cache[branch] = await self._fetch(branch=branch)
217243

218244
if branch in self.cache and kind_str in self.cache[branch].nodes:
219245
return self.cache[branch].nodes[kind_str]
@@ -416,59 +442,45 @@ async def add_dropdown_option(
416442
)
417443

418444
async def fetch(
419-
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
445+
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True
420446
) -> MutableMapping[str, MainSchemaTypesAPI]:
421447
"""Fetch the schema from the server for a given branch.
422448
423449
Args:
424-
branch (str): Name of the branch to fetch the schema for.
425-
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
450+
branch: Name of the branch to fetch the schema for.
451+
timeout: Overrides default timeout used when querying the schema. deprecated.
452+
populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.
426453
427454
Returns:
428455
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
429456
"""
430-
branch_schema = await self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
457+
458+
if timeout:
459+
self._deprecated_schema_timeout()
460+
461+
branch_schema = await self._fetch(branch=branch, namespaces=namespaces)
462+
463+
if populate_cache:
464+
self.cache[branch] = branch_schema
465+
431466
return branch_schema.nodes
432467

433-
async def _fetch(
434-
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
435-
) -> BranchSchema:
468+
async def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema:
436469
url_parts = [("branch", branch)]
437470
if namespaces:
438471
url_parts.extend([("namespaces", ns) for ns in namespaces])
439472
query_params = urlencode(url_parts)
440473
url = f"{self.client.address}/api/schema?{query_params}"
441474

442-
response = await self.client._get(url=url, timeout=timeout)
475+
response = await self.client._get(url=url)
443476

444477
data = self._parse_schema_response(response=response, branch=branch)
445478

446-
nodes: MutableMapping[str, MainSchemaTypesAPI] = {}
447-
for node_schema in data.get("nodes", []):
448-
node = NodeSchemaAPI(**node_schema)
449-
nodes[node.kind] = node
450-
451-
for generic_schema in data.get("generics", []):
452-
generic = GenericSchemaAPI(**generic_schema)
453-
nodes[generic.kind] = generic
454-
455-
for profile_schema in data.get("profiles", []):
456-
profile = ProfileSchemaAPI(**profile_schema)
457-
nodes[profile.kind] = profile
458-
459-
for template_schema in data.get("templates", []):
460-
template = TemplateSchemaAPI(**template_schema)
461-
nodes[template.kind] = template
462-
463-
schema_hash = data.get("main", "")
464-
465-
return BranchSchema(hash=schema_hash, nodes=nodes)
479+
return BranchSchema.from_api_response(data=data)
466480

467481

468482
class InfrahubSchemaSync(InfrahubSchemaBase):
469-
def __init__(self, client: InfrahubClientSync):
470-
self.client = client
471-
self.cache: dict[str, BranchSchema] = {}
483+
client: InfrahubClientSync
472484

473485
def all(
474486
self,
@@ -506,10 +518,25 @@ def get(
506518
refresh: bool = False,
507519
timeout: int | None = None,
508520
) -> MainSchemaTypesAPI:
521+
"""
522+
Retrieve a specific schema object from the server.
523+
524+
Args:
525+
kind: The kind of schema object to retrieve.
526+
branch: The branch to retrieve the schema from.
527+
refresh: Whether to refresh the schema.
528+
timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).
529+
530+
Returns:
531+
MainSchemaTypes: The schema object.
532+
"""
509533
branch = branch or self.client.default_branch
510534

511535
kind_str = self._get_schema_name(schema=kind)
512536

537+
if timeout:
538+
self._deprecated_schema_timeout()
539+
513540
if refresh:
514541
self.cache[branch] = self._fetch(branch=branch)
515542

@@ -519,7 +546,7 @@ def get(
519546
# Fetching the latest schema from the server if we didn't fetch it earlier
520547
# because we coulnd't find the object on the local cache
521548
if not refresh:
522-
self.cache[branch] = self._fetch(branch=branch, timeout=timeout)
549+
self.cache[branch] = self._fetch(branch=branch)
523550

524551
if branch in self.cache and kind_str in self.cache[branch].nodes:
525552
return self.cache[branch].nodes[kind_str]
@@ -639,49 +666,39 @@ def add_dropdown_option(
639666
)
640667

641668
def fetch(
642-
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
669+
self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None, populate_cache: bool = True
643670
) -> MutableMapping[str, MainSchemaTypesAPI]:
644671
"""Fetch the schema from the server for a given branch.
645672
646673
Args:
647-
branch (str): Name of the branch to fetch the schema for.
648-
timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
674+
branch: Name of the branch to fetch the schema for.
675+
timeout: Overrides default timeout used when querying the GraphQL API. Specified in seconds (deprecated).
676+
populate_cache: Whether to populate the cache with the fetched schema. Defaults to True.
649677
650678
Returns:
651679
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
652680
"""
653-
branch_schema = self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
681+
if timeout:
682+
self._deprecated_schema_timeout()
683+
684+
branch_schema = self._fetch(branch=branch, namespaces=namespaces)
685+
686+
if populate_cache:
687+
self.cache[branch] = branch_schema
688+
654689
return branch_schema.nodes
655690

656-
def _fetch(self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None) -> BranchSchema:
691+
def _fetch(self, branch: str, namespaces: list[str] | None = None) -> BranchSchema:
657692
url_parts = [("branch", branch)]
658693
if namespaces:
659694
url_parts.extend([("namespaces", ns) for ns in namespaces])
660695
query_params = urlencode(url_parts)
661696
url = f"{self.client.address}/api/schema?{query_params}"
662-
response = self.client._get(url=url, timeout=timeout)
663-
data = self._parse_schema_response(response=response, branch=branch)
664-
665-
nodes: MutableMapping[str, MainSchemaTypesAPI] = {}
666-
for node_schema in data.get("nodes", []):
667-
node = NodeSchemaAPI(**node_schema)
668-
nodes[node.kind] = node
697+
response = self.client._get(url=url)
669698

670-
for generic_schema in data.get("generics", []):
671-
generic = GenericSchemaAPI(**generic_schema)
672-
nodes[generic.kind] = generic
673-
674-
for profile_schema in data.get("profiles", []):
675-
profile = ProfileSchemaAPI(**profile_schema)
676-
nodes[profile.kind] = profile
677-
678-
for template_schema in data.get("templates", []):
679-
template = TemplateSchemaAPI(**template_schema)
680-
nodes[template.kind] = template
681-
682-
schema_hash = data.get("main", "")
699+
data = self._parse_schema_response(response=response, branch=branch)
683700

684-
return BranchSchema(hash=schema_hash, nodes=nodes)
701+
return BranchSchema.from_api_response(data=data)
685702

686703
def load(
687704
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False

infrahub_sdk/schema/main.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import TYPE_CHECKING, Any, Union
77

88
from pydantic import BaseModel, ConfigDict, Field
9+
from typing_extensions import Self
910

1011
if TYPE_CHECKING:
1112
from ..node import InfrahubNode, InfrahubNodeSync
@@ -356,3 +357,26 @@ class BranchSchema(BaseModel):
356357
nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = Field(
357358
default_factory=dict
358359
)
360+
361+
@classmethod
362+
def from_api_response(cls, data: MutableMapping[str, Any]) -> Self:
363+
nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = {}
364+
for node_schema in data.get("nodes", []):
365+
node = NodeSchemaAPI(**node_schema)
366+
nodes[node.kind] = node
367+
368+
for generic_schema in data.get("generics", []):
369+
generic = GenericSchemaAPI(**generic_schema)
370+
nodes[generic.kind] = generic
371+
372+
for profile_schema in data.get("profiles", []):
373+
profile = ProfileSchemaAPI(**profile_schema)
374+
nodes[profile.kind] = profile
375+
376+
for template_schema in data.get("templates", []):
377+
template = TemplateSchemaAPI(**template_schema)
378+
nodes[template.kind] = template
379+
380+
schema_hash = data.get("main", "")
381+
382+
return cls(hash=schema_hash, nodes=nodes)

tests/unit/sdk/conftest.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,38 +1795,58 @@ async def mock_query_repository_page2_2(
17951795

17961796

17971797
@pytest.fixture
1798-
async def mock_schema_query_01(httpx_mock: HTTPXMock) -> HTTPXMock:
1798+
async def schema_query_01_data() -> dict:
17991799
response_text = (get_fixtures_dir() / "schema_01.json").read_text(encoding="UTF-8")
1800+
return ujson.loads(response_text)
18001801

1802+
1803+
@pytest.fixture
1804+
async def schema_query_02_data() -> dict:
1805+
response_text = (get_fixtures_dir() / "schema_02.json").read_text(encoding="UTF-8")
1806+
return ujson.loads(response_text)
1807+
1808+
1809+
@pytest.fixture
1810+
async def schema_query_04_data() -> dict:
1811+
response_text = (get_fixtures_dir() / "schema_04.json").read_text(encoding="UTF-8")
1812+
return ujson.loads(response_text)
1813+
1814+
1815+
@pytest.fixture
1816+
async def schema_query_05_data() -> dict:
1817+
response_text = (get_fixtures_dir() / "schema_05.json").read_text(encoding="UTF-8")
1818+
return ujson.loads(response_text)
1819+
1820+
1821+
@pytest.fixture
1822+
async def mock_schema_query_01(httpx_mock: HTTPXMock, schema_query_01_data: dict) -> HTTPXMock:
18011823
httpx_mock.add_response(
18021824
method="GET",
18031825
url="http://mock/api/schema?branch=main",
1804-
json=ujson.loads(response_text),
1826+
json=schema_query_01_data,
18051827
is_reusable=True,
18061828
)
18071829
return httpx_mock
18081830

18091831

18101832
@pytest.fixture
1811-
async def mock_schema_query_02(httpx_mock: HTTPXMock) -> HTTPXMock:
1812-
response_text = (get_fixtures_dir() / "schema_02.json").read_text(encoding="UTF-8")
1833+
async def mock_schema_query_02(httpx_mock: HTTPXMock, schema_query_02_data: dict) -> HTTPXMock:
18131834
httpx_mock.add_response(
18141835
method="GET",
18151836
url=re.compile(r"^http://mock/api/schema\?branch=(main|cr1234)"),
1816-
json=ujson.loads(response_text),
1837+
json=schema_query_02_data,
18171838
is_reusable=True,
18181839
)
18191840
return httpx_mock
18201841

18211842

18221843
@pytest.fixture
1823-
async def mock_schema_query_05(httpx_mock: HTTPXMock) -> HTTPXMock:
1824-
response_text = (get_fixtures_dir() / "schema_05.json").read_text(encoding="UTF-8")
1825-
1844+
async def mock_schema_query_05(httpx_mock: HTTPXMock, schema_query_05_data: dict) -> HTTPXMock:
18261845
httpx_mock.add_response(
18271846
method="GET",
18281847
url="http://mock/api/schema?branch=main",
1829-
json=ujson.loads(response_text),
1848+
json=schema_query_05_data,
1849+
is_reusable=True,
18301850
)
18311851
return httpx_mock
18321852

@@ -1933,13 +1953,11 @@ async def mock_rest_api_artifact_fetch(httpx_mock: HTTPXMock) -> HTTPXMock:
19331953

19341954

19351955
@pytest.fixture
1936-
async def mock_rest_api_artifact_generate(httpx_mock: HTTPXMock) -> HTTPXMock:
1937-
schema_response = (get_fixtures_dir() / "schema_04.json").read_text(encoding="UTF-8")
1938-
1956+
async def mock_rest_api_artifact_generate(httpx_mock: HTTPXMock, schema_query_04_data: dict) -> HTTPXMock:
19391957
httpx_mock.add_response(
19401958
method="GET",
19411959
url="http://mock/api/schema?branch=main",
1942-
json=ujson.loads(schema_response),
1960+
json=schema_query_04_data,
19431961
is_reusable=True,
19441962
)
19451963

0 commit comments

Comments
 (0)