diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8b242f64..a9496730 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,20 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the chang
+## [1.9.0](https://github.com/opsmill/infrahub-sdk-python/tree/v1.9.0) - 2025-03-21
+
+### Added
+
+- Add 'schema_hash' parameter to client.schema.all to optionally refresh the schema only if the provided hash differs from what the client has already cached. ([#152](https://github.com/opsmill/infrahub-sdk-python/issues/152))
+
+### Changed
+
+- CoreStandardGroups created or updated by a generator in Infrahub are now stored as a member of the CoreGeneratorGroup. Previously they were being stored as children of the CoreGeneratorGroup.
+
+### Fixed
+
+- The SDK client query methods (get, filters, all) default behaviour has changed. The query methods will store the retrieved nodes in the internal store by default, where previously this behaviour had to be enabled explicitly using the `populate_store` argument. ([#15](https://github.com/opsmill/infrahub-sdk-python/issues/15))
+
## [1.8.0](https://github.com/opsmill/infrahub-sdk-python/tree/v1.8.0) - 2025-03-19
### Deprecated
diff --git a/docs/docs/python-sdk/guides/store.mdx b/docs/docs/python-sdk/guides/store.mdx
index f2d8546f..07bcfc58 100644
--- a/docs/docs/python-sdk/guides/store.mdx
+++ b/docs/docs/python-sdk/guides/store.mdx
@@ -14,21 +14,39 @@ The store can also be used to store Nodes that we retrieve using the client quer
## Storing nodes in the store
-Nodes retrieved from Infrahub using a query method, will be stored in the internal store when you set the `populate_store` argument to `True`.
-Nodes stored in the store using this method can be retrieved using their `id` as the key in the store.
+Nodes retrieved from Infrahub using the SDK client's query methods, such as the `get`, `filters` or `all` methods, will be automatically stored in the internal store. Nodes stored in the store using this method can be retrieved using their `id` as the key in the store.
```python
- tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=True)
+ tag = await client.get(kind="BuiltinTag", name__value="RED")
+ ```
+
+
+
+
+ ```python
+ tag = client.get(kind="BuiltinTag", name__value="RED")
+ ```
+
+
+
+
+This behaviour may not be desirable in all scenarios, therefore you can explicitly disable this behaviour by setting the `populate_store` argument to `False` when calling the query methods.
+
+
+
+
+ ```python
+ tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
```
```python
- tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=True)
+ tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
```
@@ -42,7 +60,7 @@ You can store nodes in the object store manually using the `set` method. This ha
```python
- tag = await client.get(kind="BuiltinTag", name__value="RED")
+ tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
client.store.set(key=tag.name.value, node=tag)
```
@@ -50,7 +68,7 @@ You can store nodes in the object store manually using the `set` method. This ha
```python
- tag = client.get(kind="BuiltinTag", name__value="RED")
+ tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
client.store.set(key=tag.name.value, node=tag)
```
@@ -60,13 +78,13 @@ You can store nodes in the object store manually using the `set` method. This ha
## Retrieving object from the store
Nodes can be retrieved from the internal store using the key that was used to store them.
-For nodes that are stored using the `populate_store` argument on a query method, this will be their `id`.
+For nodes that are stored by the client's query methods, this will be their `id`.
```python
- tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=True)
+ tag = await client.get(kind="BuiltinTag", name__value="RED")
tag_in_store = client.store.get(key=tag.id)
assert tag == tag_in_store
```
@@ -75,7 +93,7 @@ For nodes that are stored using the `populate_store` argument on a query method,
```python
- tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=True)
+ tag = client.get(kind="BuiltinTag", name__value="RED")
tag_in_store = client.store.get(key=tag.id)
assert tag == tag_in_store
```
@@ -89,7 +107,7 @@ For nodes that have been added manually to the store, this will be the key that
```python
- tag = await client.get(kind="BuiltinTag", name__value="RED")
+ tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
client.store.set(key=tag.name.value, node=tag)
tag_in_store = client.store.get(key=tag.name.value)
assert tag == tag_in_store
@@ -99,7 +117,7 @@ For nodes that have been added manually to the store, this will be the key that
```python
- tag = client.get(kind="BuiltinTag", name__value="RED")
+ tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
client.store.set(key=tag.name.value, node=tag)
tag_in_store = client.store.get(key=tag.name.value)
assert tag == tag_in_store
@@ -142,7 +160,7 @@ You can use a custom store, outside of the Infrahub SDK client. Storing or retri
from infrahub_sdk.store import NodeStore
store = NodeStore()
- device = await client.get(kind="TestDevice", name__value="atl1-edge1")
+ device = await client.get(kind="TestDevice", name__value="atl1-edge1", populate_store=False)
store.set(key=device.name.value, node=device)
store.get(key=device.name.value)
```
@@ -154,7 +172,7 @@ You can use a custom store, outside of the Infrahub SDK client. Storing or retri
from infrahub_sdk.store import NodeStoreSync
store = NodeStoreSync()
- device = await client.get(kind="TestDevice", name__value="atl1-edge1")
+ device = await client.get(kind="TestDevice", name__value="atl1-edge1", populate_store=False)
store.set(key=device.name.value, node=device)
store.get(key=device.name.value)
```
diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py
index 00664009..1835ff00 100644
--- a/infrahub_sdk/client.py
+++ b/infrahub_sdk/client.py
@@ -470,7 +470,7 @@ async def get(
hfid: list[str] | None = None,
include: list[str] | None = None,
exclude: list[str] | None = None,
- populate_store: bool = False,
+ populate_store: bool = True,
fragment: bool = False,
prefetch_relationships: bool = False,
property: bool = False,
@@ -624,7 +624,7 @@ async def all(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
- populate_store: bool = False,
+ populate_store: bool = True,
offset: int | None = None,
limit: int | None = None,
include: list[str] | None = None,
@@ -720,7 +720,7 @@ async def filters(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
- populate_store: bool = False,
+ populate_store: bool = True,
offset: int | None = None,
limit: int | None = None,
include: list[str] | None = None,
@@ -1745,7 +1745,7 @@ def all(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
- populate_store: bool = False,
+ populate_store: bool = True,
offset: int | None = None,
limit: int | None = None,
include: list[str] | None = None,
@@ -1876,7 +1876,7 @@ def filters(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
- populate_store: bool = False,
+ populate_store: bool = True,
offset: int | None = None,
limit: int | None = None,
include: list[str] | None = None,
@@ -2129,7 +2129,7 @@ def get(
hfid: list[str] | None = None,
include: list[str] | None = None,
exclude: list[str] | None = None,
- populate_store: bool = False,
+ populate_store: bool = True,
fragment: bool = False,
prefetch_relationships: bool = False,
property: bool = False,
diff --git a/infrahub_sdk/query_groups.py b/infrahub_sdk/query_groups.py
index b1b4687a..4bb732b2 100644
--- a/infrahub_sdk/query_groups.py
+++ b/infrahub_sdk/query_groups.py
@@ -19,7 +19,6 @@ def __init__(self) -> None:
self.related_node_ids: list[str] = []
self.related_group_ids: list[str] = []
self.unused_member_ids: list[str] | None = None
- self.unused_child_ids: list[str] | None = None
self.previous_members: list[RelatedNodeBase] | None = None
self.previous_children: list[RelatedNodeBase] | None = None
self.identifier: str | None = None
@@ -88,7 +87,7 @@ def __init__(self, client: InfrahubClient) -> None:
async def get_group(self, store_peers: bool = False) -> InfrahubNode | None:
group_name = self._generate_group_name()
try:
- group = await self.client.get(kind=self.group_type, name__value=group_name, include=["members", "children"])
+ group = await self.client.get(kind=self.group_type, name__value=group_name, include=["members"])
except NodeNotFoundError:
return None
@@ -96,7 +95,6 @@ async def get_group(self, store_peers: bool = False) -> InfrahubNode | None:
return group
self.previous_members = group.members.peers # type: ignore[attr-defined]
- self.previous_children = group.children.peers # type: ignore[attr-defined]
return group
async def delete_unused(self) -> None:
@@ -105,11 +103,6 @@ async def delete_unused(self) -> None:
if member.id in self.unused_member_ids and member.typename:
await self.client.delete(kind=member.typename, id=member.id)
- if self.previous_children and self.unused_child_ids:
- for child in self.previous_children:
- if child.id in self.unused_child_ids and child.typename:
- await self.client.delete(kind=child.typename, id=child.id)
-
async def add_related_nodes(self, ids: list[str], update_group_context: bool | None = None) -> None:
"""
Add related Nodes IDs to the context.
@@ -140,15 +133,9 @@ async def update_group(self) -> None:
"""
Create or update (using upsert) a CoreStandardGroup to store all the Nodes and Groups used during an execution.
"""
- children: list[str] = []
- members: list[str] = []
-
- if self.related_group_ids:
- children = self.related_group_ids
- if self.related_node_ids:
- members = self.related_node_ids
+ members: list[str] = self.related_group_ids + self.related_node_ids
- if not children and not members:
+ if not members:
return
group_name = self._generate_group_name()
@@ -164,7 +151,6 @@ async def update_group(self) -> None:
name=group_name,
description=description,
members=members,
- children=children,
)
await group.save(allow_upsert=True, update_group_context=False)
@@ -173,7 +159,6 @@ async def update_group(self) -> None:
# Calculate how many nodes should be deleted
self.unused_member_ids = set(existing_group.members.peer_ids) - set(members) # type: ignore
- self.unused_child_ids = set(existing_group.children.peer_ids) - set(children) # type: ignore
if not self.delete_unused_nodes:
return
@@ -194,7 +179,7 @@ def __init__(self, client: InfrahubClientSync) -> None:
def get_group(self, store_peers: bool = False) -> InfrahubNodeSync | None:
group_name = self._generate_group_name()
try:
- group = self.client.get(kind=self.group_type, name__value=group_name, include=["members", "children"])
+ group = self.client.get(kind=self.group_type, name__value=group_name, include=["members"])
except NodeNotFoundError:
return None
@@ -202,7 +187,6 @@ def get_group(self, store_peers: bool = False) -> InfrahubNodeSync | None:
return group
self.previous_members = group.members.peers # type: ignore[attr-defined]
- self.previous_children = group.children.peers # type: ignore[attr-defined]
return group
def delete_unused(self) -> None:
@@ -211,11 +195,6 @@ def delete_unused(self) -> None:
if member.id in self.unused_member_ids and member.typename:
self.client.delete(kind=member.typename, id=member.id)
- if self.previous_children and self.unused_child_ids:
- for child in self.previous_children:
- if child.id in self.unused_child_ids and child.typename:
- self.client.delete(kind=child.typename, id=child.id)
-
def add_related_nodes(self, ids: list[str], update_group_context: bool | None = None) -> None:
"""
Add related Nodes IDs to the context.
@@ -246,15 +225,9 @@ def update_group(self) -> None:
"""
Create or update (using upsert) a CoreStandardGroup to store all the Nodes and Groups used during an execution.
"""
- children: list[str] = []
- members: list[str] = []
-
- if self.related_group_ids:
- children = self.related_group_ids
- if self.related_node_ids:
- members = self.related_node_ids
+ members: list[str] = self.related_node_ids + self.related_group_ids
- if not children and not members:
+ if not members:
return
group_name = self._generate_group_name()
@@ -270,7 +243,6 @@ def update_group(self) -> None:
name=group_name,
description=description,
members=members,
- children=children,
)
group.save(allow_upsert=True, update_group_context=False)
@@ -279,7 +251,6 @@ def update_group(self) -> None:
# Calculate how many nodes should be deleted
self.unused_member_ids = set(existing_group.members.peer_ids) - set(members) # type: ignore
- self.unused_child_ids = set(existing_group.children.peer_ids) - set(children) # type: ignore
if not self.delete_unused_nodes:
return
diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py
index 9b23fe49..080d7237 100644
--- a/infrahub_sdk/schema/__init__.py
+++ b/infrahub_sdk/schema/__init__.py
@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
-from collections import defaultdict
from collections.abc import MutableMapping
from enum import Enum
from time import sleep
@@ -22,6 +21,7 @@
from .main import (
AttributeSchema,
AttributeSchemaAPI,
+ BranchSchema,
BranchSupportType,
GenericSchema,
GenericSchemaAPI,
@@ -169,7 +169,7 @@ def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str:
class InfrahubSchema(InfrahubSchemaBase):
def __init__(self, client: InfrahubClient):
self.client = client
- self.cache: dict = defaultdict(lambda: dict)
+ self.cache: dict[str, BranchSchema] = {}
async def get(
self,
@@ -183,23 +183,27 @@ async def get(
kind_str = self._get_schema_name(schema=kind)
if refresh:
- self.cache[branch] = await self.fetch(branch=branch, timeout=timeout)
+ self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
- if branch in self.cache and kind_str in self.cache[branch]:
- return self.cache[branch][kind_str]
+ if branch in self.cache and kind_str in self.cache[branch].nodes:
+ return self.cache[branch].nodes[kind_str]
# Fetching the latest schema from the server if we didn't fetch it earlier
# because we coulnd't find the object on the local cache
if not refresh:
- self.cache[branch] = await self.fetch(branch=branch, timeout=timeout)
+ self.cache[branch] = await self._fetch(branch=branch, timeout=timeout)
- if branch in self.cache and kind_str in self.cache[branch]:
- return self.cache[branch][kind_str]
+ if branch in self.cache and kind_str in self.cache[branch].nodes:
+ return self.cache[branch].nodes[kind_str]
raise SchemaNotFoundError(identifier=kind_str)
async def all(
- self, branch: str | None = None, refresh: bool = False, namespaces: list[str] | None = None
+ self,
+ branch: str | None = None,
+ refresh: bool = False,
+ namespaces: list[str] | None = None,
+ schema_hash: str | None = None,
) -> MutableMapping[str, MainSchemaTypesAPI]:
"""Retrieve the entire schema for a given branch.
@@ -209,15 +213,19 @@ async def all(
Args:
branch (str, optional): Name of the branch to query. Defaults to default_branch.
refresh (bool, optional): Force a refresh of the schema. Defaults to False.
+ schema_hash (str, optional): Only refresh if the current schema doesn't match this hash.
Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
branch = branch or self.client.default_branch
+ if refresh and branch in self.cache and schema_hash and self.cache[branch].hash == schema_hash:
+ refresh = False
+
if refresh or branch not in self.cache:
- self.cache[branch] = await self.fetch(branch=branch, namespaces=namespaces)
+ self.cache[branch] = await self._fetch(branch=branch, namespaces=namespaces)
- return self.cache[branch]
+ return self.cache[branch].nodes
async def load(
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False
@@ -392,11 +400,17 @@ async def fetch(
Args:
branch (str): Name of the branch to fetch the schema for.
- timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
+ timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
+ branch_schema = await self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
+ return branch_schema.nodes
+
+ async def _fetch(
+ self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None
+ ) -> BranchSchema:
url_parts = [("branch", branch)]
if namespaces:
url_parts.extend([("namespaces", ns) for ns in namespaces])
@@ -425,16 +439,22 @@ async def fetch(
template = TemplateSchemaAPI(**template_schema)
nodes[template.kind] = template
- return nodes
+ schema_hash = data.get("main", "")
+
+ return BranchSchema(hash=schema_hash, nodes=nodes)
class InfrahubSchemaSync(InfrahubSchemaBase):
def __init__(self, client: InfrahubClientSync):
self.client = client
- self.cache: dict = defaultdict(lambda: dict)
+ self.cache: dict[str, BranchSchema] = {}
def all(
- self, branch: str | None = None, refresh: bool = False, namespaces: list[str] | None = None
+ self,
+ branch: str | None = None,
+ refresh: bool = False,
+ namespaces: list[str] | None = None,
+ schema_hash: str | None = None,
) -> MutableMapping[str, MainSchemaTypesAPI]:
"""Retrieve the entire schema for a given branch.
@@ -444,15 +464,19 @@ def all(
Args:
branch (str, optional): Name of the branch to query. Defaults to default_branch.
refresh (bool, optional): Force a refresh of the schema. Defaults to False.
+ schema_hash (str, optional): Only refresh if the current schema doesn't match this hash.
Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
branch = branch or self.client.default_branch
+ if refresh and branch in self.cache and schema_hash and self.cache[branch].hash == schema_hash:
+ refresh = False
+
if refresh or branch not in self.cache:
- self.cache[branch] = self.fetch(branch=branch, namespaces=namespaces)
+ self.cache[branch] = self._fetch(branch=branch, namespaces=namespaces)
- return self.cache[branch]
+ return self.cache[branch].nodes
def get(
self,
@@ -466,18 +490,18 @@ def get(
kind_str = self._get_schema_name(schema=kind)
if refresh:
- self.cache[branch] = self.fetch(branch=branch)
+ self.cache[branch] = self._fetch(branch=branch)
- if branch in self.cache and kind_str in self.cache[branch]:
- return self.cache[branch][kind_str]
+ if branch in self.cache and kind_str in self.cache[branch].nodes:
+ return self.cache[branch].nodes[kind_str]
# Fetching the latest schema from the server if we didn't fetch it earlier
# because we coulnd't find the object on the local cache
if not refresh:
- self.cache[branch] = self.fetch(branch=branch, timeout=timeout)
+ self.cache[branch] = self._fetch(branch=branch, timeout=timeout)
- if branch in self.cache and kind_str in self.cache[branch]:
- return self.cache[branch][kind_str]
+ if branch in self.cache and kind_str in self.cache[branch].nodes:
+ return self.cache[branch].nodes[kind_str]
raise SchemaNotFoundError(identifier=kind_str)
@@ -600,17 +624,20 @@ def fetch(
Args:
branch (str): Name of the branch to fetch the schema for.
- timeout (int, optional): Overrides default timeout used when querying the graphql API. Specified in seconds.
+ timeout (int, optional): Overrides default timeout used when querying the GraphQL API. Specified in seconds.
Returns:
dict[str, MainSchemaTypes]: Dictionary of all schema organized by kind
"""
+ branch_schema = self._fetch(branch=branch, namespaces=namespaces, timeout=timeout)
+ return branch_schema.nodes
+
+ def _fetch(self, branch: str, namespaces: list[str] | None = None, timeout: int | None = None) -> BranchSchema:
url_parts = [("branch", branch)]
if namespaces:
url_parts.extend([("namespaces", ns) for ns in namespaces])
query_params = urlencode(url_parts)
url = f"{self.client.address}/api/schema?{query_params}"
-
response = self.client._get(url=url, timeout=timeout)
response.raise_for_status()
@@ -633,7 +660,9 @@ def fetch(
template = TemplateSchemaAPI(**template_schema)
nodes[template.kind] = template
- return nodes
+ schema_hash = data.get("main", "")
+
+ return BranchSchema(hash=schema_hash, nodes=nodes)
def load(
self, schemas: list[dict], branch: str | None = None, wait_until_converged: bool = False
diff --git a/infrahub_sdk/schema/main.py b/infrahub_sdk/schema/main.py
index 57aaa890..af5556b3 100644
--- a/infrahub_sdk/schema/main.py
+++ b/infrahub_sdk/schema/main.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import warnings
+from collections.abc import MutableMapping
from enum import Enum
from typing import TYPE_CHECKING, Any, Union
@@ -348,3 +349,10 @@ class SchemaRootAPI(BaseModel):
nodes: list[NodeSchemaAPI] = Field(default_factory=list)
profiles: list[ProfileSchemaAPI] = Field(default_factory=list)
templates: list[TemplateSchemaAPI] = Field(default_factory=list)
+
+
+class BranchSchema(BaseModel):
+ hash: str = Field(...)
+ nodes: MutableMapping[str, GenericSchemaAPI | NodeSchemaAPI | ProfileSchemaAPI | TemplateSchemaAPI] = Field(
+ default_factory=dict
+ )
diff --git a/pyproject.toml b/pyproject.toml
index abfadc14..8c528248 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "infrahub-sdk"
-version = "1.8.0"
+version = "1.9.0"
description = "Python Client to interact with Infrahub"
authors = ["OpsMill "]
readme = "README.md"
diff --git a/tests/fixtures/schema_01.json b/tests/fixtures/schema_01.json
index bf6f016e..344ebeab 100644
--- a/tests/fixtures/schema_01.json
+++ b/tests/fixtures/schema_01.json
@@ -1,4 +1,5 @@
{
+ "main": "c0272bc24cd943f21cf30affda06b12d",
"nodes": [
{
"name": "GraphQLQuery",
diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py
index 9f0f4e33..a585ee87 100644
--- a/tests/unit/sdk/test_client.py
+++ b/tests/unit/sdk/test_client.py
@@ -134,18 +134,17 @@ async def test_method_get_user_permissions(clients, mock_query_infrahub_user, cl
@pytest.mark.parametrize("client_type", client_types)
async def test_method_all_with_limit(clients, mock_query_repository_page1_2, client_type):
if client_type == "standard":
- repos = await clients.standard.all(kind="CoreRepository", limit=3)
+ repos = await clients.standard.all(kind="CoreRepository", populate_store=False, limit=3)
assert not clients.standard.store._store["CoreRepository"]
- repos = await clients.standard.all(kind="CoreRepository", populate_store=True, limit=3)
+ repos = await clients.standard.all(kind="CoreRepository", limit=3)
assert len(clients.standard.store._store["CoreRepository"]) == 3
else:
- repos = clients.sync.all(kind="CoreRepository", limit=3)
+ repos = clients.sync.all(kind="CoreRepository", populate_store=False, limit=3)
assert not clients.sync.store._store["CoreRepository"]
- repos = clients.sync.all(kind="CoreRepository", populate_store=True, limit=3)
+ repos = clients.sync.all(kind="CoreRepository", limit=3)
assert len(clients.sync.store._store["CoreRepository"]) == 3
-
assert len(repos) == 3
@@ -154,16 +153,16 @@ async def test_method_all_multiple_pages(
clients, mock_query_repository_page1_2, mock_query_repository_page2_2, client_type
):
if client_type == "standard":
- repos = await clients.standard.all(kind="CoreRepository")
+ repos = await clients.standard.all(kind="CoreRepository", populate_store=False)
assert not clients.standard.store._store["CoreRepository"]
- repos = await clients.standard.all(kind="CoreRepository", populate_store=True)
+ repos = await clients.standard.all(kind="CoreRepository")
assert len(clients.standard.store._store["CoreRepository"]) == 5
else:
- repos = clients.sync.all(kind="CoreRepository")
+ repos = clients.sync.all(kind="CoreRepository", populate_store=False)
assert not clients.sync.store._store["CoreRepository"]
- repos = clients.sync.all(kind="CoreRepository", populate_store=True)
+ repos = clients.sync.all(kind="CoreRepository")
assert len(clients.sync.store._store["CoreRepository"]) == 5
assert len(repos) == 5
@@ -174,16 +173,16 @@ async def test_method_all_batching(
clients, mock_query_location_batch_count, mock_query_location_batch, client_type, use_parallel
):
if client_type == "standard":
- locations = await clients.standard.all(kind="BuiltinLocation", parallel=use_parallel)
+ locations = await clients.standard.all(kind="BuiltinLocation", populate_store=False, parallel=use_parallel)
assert not clients.standard.store._store["BuiltinLocation"]
- locations = await clients.standard.all(kind="BuiltinLocation", populate_store=True, parallel=use_parallel)
+ locations = await clients.standard.all(kind="BuiltinLocation", parallel=use_parallel)
assert len(clients.standard.store._store["BuiltinLocation"]) == 30
else:
- locations = clients.sync.all(kind="BuiltinLocation", parallel=use_parallel)
+ locations = clients.sync.all(kind="BuiltinLocation", populate_store=False, parallel=use_parallel)
assert not clients.sync.store._store["BuiltinLocation"]
- locations = clients.sync.all(kind="BuiltinLocation", populate_store=True, parallel=use_parallel)
+ locations = clients.sync.all(kind="BuiltinLocation", parallel=use_parallel)
assert len(clients.sync.store._store["BuiltinLocation"]) == 30
assert len(locations) == 30
@@ -192,16 +191,16 @@ async def test_method_all_batching(
@pytest.mark.parametrize("client_type", client_types)
async def test_method_all_single_page(clients, mock_query_repository_page1_1, client_type):
if client_type == "standard":
- repos = await clients.standard.all(kind="CoreRepository")
+ repos = await clients.standard.all(kind="CoreRepository", populate_store=False)
assert not clients.standard.store._store["CoreRepository"]
- repos = await clients.standard.all(kind="CoreRepository", populate_store=True)
+ repos = await clients.standard.all(kind="CoreRepository")
assert len(clients.standard.store._store["CoreRepository"]) == 2
else:
- repos = clients.sync.all(kind="CoreRepository")
+ repos = clients.sync.all(kind="CoreRepository", populate_store=False)
assert not clients.sync.store._store["CoreRepository"]
- repos = clients.sync.all(kind="CoreRepository", populate_store=True)
+ repos = clients.sync.all(kind="CoreRepository")
assert len(clients.sync.store._store["CoreRepository"]) == 2
assert len(repos) == 2
@@ -247,20 +246,22 @@ async def test_method_get_by_id(httpx_mock: HTTPXMock, clients, mock_schema_quer
)
if client_type == "standard":
- repo = await clients.standard.get(kind="CoreRepository", id=response_id)
+ repo = await clients.standard.get(kind="CoreRepository", id=response_id, populate_store=False)
assert isinstance(repo, InfrahubNode)
with pytest.raises(NodeNotFoundError):
assert clients.standard.store.get(key=response_id)
- repo = await clients.standard.get(kind="CoreRepository", id=response_id, populate_store=True)
+ repo = await clients.standard.get(kind="CoreRepository", id=response_id)
+ assert isinstance(repo, InfrahubNode)
assert clients.standard.store.get(key=response_id)
else:
- repo = clients.sync.get(kind="CoreRepository", id=response_id)
+ repo = clients.sync.get(kind="CoreRepository", id=response_id, populate_store=False)
assert isinstance(repo, InfrahubNodeSync)
with pytest.raises(NodeNotFoundError):
assert clients.sync.store.get(key=response_id)
- repo = clients.sync.get(kind="CoreRepository", id=response_id, populate_store=True)
+ repo = clients.sync.get(kind="CoreRepository", id=response_id)
+ assert isinstance(repo, InfrahubNodeSync)
assert clients.sync.store.get(key=response_id)
@@ -293,20 +294,22 @@ async def test_method_get_by_hfid(httpx_mock: HTTPXMock, clients, mock_schema_qu
)
if client_type == "standard":
- repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"])
+ repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=False)
assert isinstance(repo, InfrahubNode)
with pytest.raises(NodeNotFoundError):
assert clients.standard.store.get(key=response_id)
- repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=True)
+ repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"])
+ assert isinstance(repo, InfrahubNode)
assert clients.standard.store.get(key=response_id)
else:
- repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"])
+ repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=False)
assert isinstance(repo, InfrahubNodeSync)
with pytest.raises(NodeNotFoundError):
assert clients.sync.store.get(key="infrahub-demo-core")
- repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=True)
+ repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"])
+ assert isinstance(repo, InfrahubNodeSync)
assert clients.sync.store.get(key=response_id)
@@ -338,20 +341,20 @@ async def test_method_get_by_default_filter(httpx_mock: HTTPXMock, clients, mock
)
if client_type == "standard":
- repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core")
+ repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=False)
assert isinstance(repo, InfrahubNode)
with pytest.raises(NodeNotFoundError):
assert clients.standard.store.get(key=response_id)
- repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=True)
+ repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core")
assert clients.standard.store.get(key=response_id)
else:
- repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core")
+ repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=False)
assert isinstance(repo, InfrahubNodeSync)
with pytest.raises(NodeNotFoundError):
assert clients.sync.store.get(key="infrahub-demo-core")
- repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=True)
+ repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core")
assert clients.sync.store.get(key=response_id)
@@ -437,6 +440,7 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re
"bfae43e8-5ebb-456c-a946-bf64e930710a",
"9486cfce-87db-479d-ad73-07d80ba96a0f",
],
+ populate_store=False,
)
assert len(repos) == 2
assert not clients.standard.store._store["CoreRepository"]
@@ -447,7 +451,6 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re
"bfae43e8-5ebb-456c-a946-bf64e930710a",
"9486cfce-87db-479d-ad73-07d80ba96a0f",
],
- populate_store=True,
)
assert len(clients.standard.store._store["CoreRepository"]) == 2
assert len(repos) == 2
@@ -458,6 +461,7 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re
"bfae43e8-5ebb-456c-a946-bf64e930710a",
"9486cfce-87db-479d-ad73-07d80ba96a0f",
],
+ populate_store=False,
)
assert len(repos) == 2
assert not clients.sync.store._store["CoreRepository"]
@@ -468,7 +472,6 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re
"bfae43e8-5ebb-456c-a946-bf64e930710a",
"9486cfce-87db-479d-ad73-07d80ba96a0f",
],
- populate_store=True,
)
assert len(clients.sync.store._store["CoreRepository"]) == 2
assert len(repos) == 2
diff --git a/tests/unit/sdk/test_schema.py b/tests/unit/sdk/test_schema.py
index fcffb1c0..ee07a137 100644
--- a/tests/unit/sdk/test_schema.py
+++ b/tests/unit/sdk/test_schema.py
@@ -64,6 +64,33 @@ async def test_fetch_schema(mock_schema_query_01, client_type):
assert isinstance(nodes["BuiltinTag"], NodeSchemaAPI)
+@pytest.mark.parametrize("client_type", client_types)
+async def test_fetch_schema_conditional_refresh(mock_schema_query_01: HTTPXMock, client_type: str) -> None:
+ """Verify that only one schema request is sent if we request to update the schema but already have the correct hash"""
+ if client_type == "standard":
+ client = InfrahubClient(config=Config(address="http://mock", insert_tracker=True))
+ nodes = await client.schema.all(branch="main")
+ schema_hash = client.schema.cache["main"].hash
+ assert schema_hash
+ nodes = await client.schema.all(branch="main", refresh=True, schema_hash=schema_hash)
+ else:
+ client = InfrahubClientSync(config=Config(address="http://mock", insert_tracker=True))
+ nodes = client.schema.all(branch="main")
+ schema_hash = client.schema.cache["main"].hash
+ assert schema_hash
+ nodes = client.schema.all(branch="main", refresh=True, schema_hash=schema_hash)
+
+ assert len(nodes) == 4
+ assert sorted(nodes.keys()) == [
+ "BuiltinLocation",
+ "BuiltinTag",
+ "CoreGraphQLQuery",
+ "CoreRepository",
+ ]
+ assert isinstance(nodes["BuiltinTag"], NodeSchemaAPI)
+ assert len(mock_schema_query_01.get_requests()) == 1
+
+
@pytest.mark.parametrize("client_type", client_types)
async def test_schema_data_validation(rfile_schema, client_type):
if client_type == "standard":