Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/15.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
The SDK client query methods (get, filters, all) default behaviour has changed. The query methods will store the retrieved nodes in the internal store by default, where previously this behaviour had to be enabled explicitly using the `populate_store` argument.
44 changes: 31 additions & 13 deletions docs/docs/python-sdk/guides/store.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,39 @@ The store can also be used to store Nodes that we retrieve using the client quer

## Storing nodes in the store

Nodes retrieved from Infrahub using a query method, will be stored in the internal store when you set the `populate_store` argument to `True`.
Nodes stored in the store using this method can be retrieved using their `id` as the key in the store.
Nodes retrieved from Infrahub using a the SDK client's query methods, such as the `get` `filters` or `all` method, will be automatically stored in the internal store. Nodes stored in the store using this method can be retrieved using their `id` as the key in the store.

<Tabs groupId="async-sync">
<TabItem value="Async" default>

```python
tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=True)
tag = await client.get(kind="BuiltinTag", name__value="RED")
```

</TabItem>
<TabItem value="Sync" default>

```python
tag = client.get(kind="BuiltinTag", name__value="RED")
```

</TabItem>
</Tabs>

This behaviour may not be desirable in all scenarios, therefor you can explicitly disable this behaviour by setting the `populate_store` argument to `False` when calling the query methods.

<Tabs groupId="async-sync">
<TabItem value="Async" default>

```python
tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
```

</TabItem>
<TabItem value="Sync" default>

```python
tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=True)
tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
```

</TabItem>
Expand All @@ -42,15 +60,15 @@ You can store nodes in the object store manually using the `set` method. This ha
<TabItem value="Async" default>

```python
tag = await client.get(kind="BuiltinTag", name__value="RED")
tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
client.store.set(key=tag.name.value, node=tag)
```

</TabItem>
<TabItem value="Sync" default>

```python
tag = client.get(kind="BuiltinTag", name__value="RED")
tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
client.store.set(key=tag.name.value, node=tag)
```

Expand All @@ -60,13 +78,13 @@ You can store nodes in the object store manually using the `set` method. This ha
## Retrieving object from the store

Nodes can be retrieved from the internal store using the key that was used to store them.
For nodes that are stored using the `populate_store` argument on a query method, this will be their `id`.
For nodes that are stored by the client's query methods, this will be their `id`.

<Tabs groupId="async-sync">
<TabItem value="Async" default>

```python
tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=True)
tag = await client.get(kind="BuiltinTag", name__value="RED")
tag_in_store = client.store.get(key=tag.id)
assert tag == tag_in_store
```
Expand All @@ -75,7 +93,7 @@ For nodes that are stored using the `populate_store` argument on a query method,
<TabItem value="Sync" default>

```python
tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=True)
tag = client.get(kind="BuiltinTag", name__value="RED")
tag_in_store = client.store.get(key=tag.id)
assert tag == tag_in_store
```
Expand All @@ -89,7 +107,7 @@ For nodes that have been added manually to the store, this will be the key that
<TabItem value="Async" default>

```python
tag = await client.get(kind="BuiltinTag", name__value="RED")
tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
client.store.set(key=tag.name.value, node=tag)
tag_in_store = client.store.get(key=tag.name.value)
assert tag == tag_in_store
Expand All @@ -99,7 +117,7 @@ For nodes that have been added manually to the store, this will be the key that
<TabItem value="Sync" default>

```python
tag = client.get(kind="BuiltinTag", name__value="RED")
tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=False)
client.store.set(key=tag.name.value, node=tag)
tag_in_store = client.store.get(key=tag.name.value)
assert tag == tag_in_store
Expand Down Expand Up @@ -142,7 +160,7 @@ You can use a custom store, outside of the Infrahub SDK client. Storing or retri
from infrahub_sdk.store import NodeStore
store = NodeStore()

device = await client.get(kind="TestDevice", name__value="atl1-edge1")
device = await client.get(kind="TestDevice", name__value="atl1-edge1", populate_store=False)
store.set(key=device.name.value, node=device)
store.get(key=device.name.value)
```
Expand All @@ -154,7 +172,7 @@ You can use a custom store, outside of the Infrahub SDK client. Storing or retri
from infrahub_sdk.store import NodeStoreSync
store = NodeStoreSync()

device = await client.get(kind="TestDevice", name__value="atl1-edge1")
device = await client.get(kind="TestDevice", name__value="atl1-edge1", populate_store=False)
store.set(key=device.name.value, node=device)
store.get(key=device.name.value)
```
Expand Down
12 changes: 6 additions & 6 deletions infrahub_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ async def get(
hfid: list[str] | None = None,
include: list[str] | None = None,
exclude: list[str] | None = None,
populate_store: bool = False,
populate_store: bool = True,
fragment: bool = False,
prefetch_relationships: bool = False,
property: bool = False,
Expand Down Expand Up @@ -611,7 +611,7 @@ async def all(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
populate_store: bool = False,
populate_store: bool = True,
offset: int | None = None,
limit: int | None = None,
include: list[str] | None = None,
Expand Down Expand Up @@ -707,7 +707,7 @@ async def filters(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
populate_store: bool = False,
populate_store: bool = True,
offset: int | None = None,
limit: int | None = None,
include: list[str] | None = None,
Expand Down Expand Up @@ -1726,7 +1726,7 @@ def all(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
populate_store: bool = False,
populate_store: bool = True,
offset: int | None = None,
limit: int | None = None,
include: list[str] | None = None,
Expand Down Expand Up @@ -1857,7 +1857,7 @@ def filters(
at: Timestamp | None = None,
branch: str | None = None,
timeout: int | None = None,
populate_store: bool = False,
populate_store: bool = True,
offset: int | None = None,
limit: int | None = None,
include: list[str] | None = None,
Expand Down Expand Up @@ -2110,7 +2110,7 @@ def get(
hfid: list[str] | None = None,
include: list[str] | None = None,
exclude: list[str] | None = None,
populate_store: bool = False,
populate_store: bool = True,
fragment: bool = False,
prefetch_relationships: bool = False,
property: bool = False,
Expand Down
65 changes: 34 additions & 31 deletions tests/unit/sdk/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,17 @@ async def test_method_get_user_permissions(clients, mock_query_infrahub_user, cl
@pytest.mark.parametrize("client_type", client_types)
async def test_method_all_with_limit(clients, mock_query_repository_page1_2, client_type):
if client_type == "standard":
repos = await clients.standard.all(kind="CoreRepository", limit=3)
repos = await clients.standard.all(kind="CoreRepository", populate_store=False, limit=3)
assert not clients.standard.store._store["CoreRepository"]

repos = await clients.standard.all(kind="CoreRepository", populate_store=True, limit=3)
repos = await clients.standard.all(kind="CoreRepository", limit=3)
assert len(clients.standard.store._store["CoreRepository"]) == 3
else:
repos = clients.sync.all(kind="CoreRepository", limit=3)
repos = clients.sync.all(kind="CoreRepository", populate_store=False, limit=3)
assert not clients.sync.store._store["CoreRepository"]

repos = clients.sync.all(kind="CoreRepository", populate_store=True, limit=3)
repos = clients.sync.all(kind="CoreRepository", limit=3)
assert len(clients.sync.store._store["CoreRepository"]) == 3

assert len(repos) == 3


Expand All @@ -148,16 +147,16 @@ async def test_method_all_multiple_pages(
clients, mock_query_repository_page1_2, mock_query_repository_page2_2, client_type
):
if client_type == "standard":
repos = await clients.standard.all(kind="CoreRepository")
repos = await clients.standard.all(kind="CoreRepository", populate_store=False)
assert not clients.standard.store._store["CoreRepository"]

repos = await clients.standard.all(kind="CoreRepository", populate_store=True)
repos = await clients.standard.all(kind="CoreRepository")
assert len(clients.standard.store._store["CoreRepository"]) == 5
else:
repos = clients.sync.all(kind="CoreRepository")
repos = clients.sync.all(kind="CoreRepository", populate_store=False)
assert not clients.sync.store._store["CoreRepository"]

repos = clients.sync.all(kind="CoreRepository", populate_store=True)
repos = clients.sync.all(kind="CoreRepository")
assert len(clients.sync.store._store["CoreRepository"]) == 5

assert len(repos) == 5
Expand All @@ -168,16 +167,16 @@ async def test_method_all_batching(
clients, mock_query_location_batch_count, mock_query_location_batch, client_type, use_parallel
):
if client_type == "standard":
locations = await clients.standard.all(kind="BuiltinLocation", parallel=use_parallel)
locations = await clients.standard.all(kind="BuiltinLocation", populate_store=False, parallel=use_parallel)
assert not clients.standard.store._store["BuiltinLocation"]

locations = await clients.standard.all(kind="BuiltinLocation", populate_store=True, parallel=use_parallel)
locations = await clients.standard.all(kind="BuiltinLocation", parallel=use_parallel)
assert len(clients.standard.store._store["BuiltinLocation"]) == 30
else:
locations = clients.sync.all(kind="BuiltinLocation", parallel=use_parallel)
locations = clients.sync.all(kind="BuiltinLocation", populate_store=False, parallel=use_parallel)
assert not clients.sync.store._store["BuiltinLocation"]

locations = clients.sync.all(kind="BuiltinLocation", populate_store=True, parallel=use_parallel)
locations = clients.sync.all(kind="BuiltinLocation", parallel=use_parallel)
assert len(clients.sync.store._store["BuiltinLocation"]) == 30

assert len(locations) == 30
Expand All @@ -186,16 +185,16 @@ async def test_method_all_batching(
@pytest.mark.parametrize("client_type", client_types)
async def test_method_all_single_page(clients, mock_query_repository_page1_1, client_type):
if client_type == "standard":
repos = await clients.standard.all(kind="CoreRepository")
repos = await clients.standard.all(kind="CoreRepository", populate_store=False)
assert not clients.standard.store._store["CoreRepository"]

repos = await clients.standard.all(kind="CoreRepository", populate_store=True)
repos = await clients.standard.all(kind="CoreRepository")
assert len(clients.standard.store._store["CoreRepository"]) == 2
else:
repos = clients.sync.all(kind="CoreRepository")
repos = clients.sync.all(kind="CoreRepository", populate_store=False)
assert not clients.sync.store._store["CoreRepository"]

repos = clients.sync.all(kind="CoreRepository", populate_store=True)
repos = clients.sync.all(kind="CoreRepository")
assert len(clients.sync.store._store["CoreRepository"]) == 2

assert len(repos) == 2
Expand Down Expand Up @@ -241,20 +240,22 @@ async def test_method_get_by_id(httpx_mock: HTTPXMock, clients, mock_schema_quer
)

if client_type == "standard":
repo = await clients.standard.get(kind="CoreRepository", id=response_id)
repo = await clients.standard.get(kind="CoreRepository", id=response_id, populate_store=False)
assert isinstance(repo, InfrahubNode)
with pytest.raises(NodeNotFoundError):
assert clients.standard.store.get(key=response_id)

repo = await clients.standard.get(kind="CoreRepository", id=response_id, populate_store=True)
repo = await clients.standard.get(kind="CoreRepository", id=response_id)
assert isinstance(repo, InfrahubNode)
assert clients.standard.store.get(key=response_id)
else:
repo = clients.sync.get(kind="CoreRepository", id=response_id)
repo = clients.sync.get(kind="CoreRepository", id=response_id, populate_store=False)
assert isinstance(repo, InfrahubNodeSync)
with pytest.raises(NodeNotFoundError):
assert clients.sync.store.get(key=response_id)

repo = clients.sync.get(kind="CoreRepository", id=response_id, populate_store=True)
repo = clients.sync.get(kind="CoreRepository", id=response_id)
assert isinstance(repo, InfrahubNodeSync)
assert clients.sync.store.get(key=response_id)


Expand Down Expand Up @@ -287,20 +288,22 @@ async def test_method_get_by_hfid(httpx_mock: HTTPXMock, clients, mock_schema_qu
)

if client_type == "standard":
repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"])
repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=False)
assert isinstance(repo, InfrahubNode)
with pytest.raises(NodeNotFoundError):
assert clients.standard.store.get(key=response_id)

repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=True)
repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"])
assert isinstance(repo, InfrahubNode)
assert clients.standard.store.get(key=response_id)
else:
repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"])
repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=False)
assert isinstance(repo, InfrahubNodeSync)
with pytest.raises(NodeNotFoundError):
assert clients.sync.store.get(key="infrahub-demo-core")

repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=True)
repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"])
assert isinstance(repo, InfrahubNodeSync)
assert clients.sync.store.get(key=response_id)


Expand Down Expand Up @@ -332,20 +335,20 @@ async def test_method_get_by_default_filter(httpx_mock: HTTPXMock, clients, mock
)

if client_type == "standard":
repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core")
repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=False)
assert isinstance(repo, InfrahubNode)
with pytest.raises(NodeNotFoundError):
assert clients.standard.store.get(key=response_id)

repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=True)
repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core")
assert clients.standard.store.get(key=response_id)
else:
repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core")
repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=False)
assert isinstance(repo, InfrahubNodeSync)
with pytest.raises(NodeNotFoundError):
assert clients.sync.store.get(key="infrahub-demo-core")

repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=True)
repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core")
assert clients.sync.store.get(key=response_id)


Expand Down Expand Up @@ -431,6 +434,7 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re
"bfae43e8-5ebb-456c-a946-bf64e930710a",
"9486cfce-87db-479d-ad73-07d80ba96a0f",
],
populate_store=False,
)
assert len(repos) == 2
assert not clients.standard.store._store["CoreRepository"]
Expand All @@ -441,7 +445,6 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re
"bfae43e8-5ebb-456c-a946-bf64e930710a",
"9486cfce-87db-479d-ad73-07d80ba96a0f",
],
populate_store=True,
)
assert len(clients.standard.store._store["CoreRepository"]) == 2
assert len(repos) == 2
Expand All @@ -452,6 +455,7 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re
"bfae43e8-5ebb-456c-a946-bf64e930710a",
"9486cfce-87db-479d-ad73-07d80ba96a0f",
],
populate_store=False,
)
assert len(repos) == 2
assert not clients.sync.store._store["CoreRepository"]
Expand All @@ -462,7 +466,6 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re
"bfae43e8-5ebb-456c-a946-bf64e930710a",
"9486cfce-87db-479d-ad73-07d80ba96a0f",
],
populate_store=True,
)
assert len(clients.sync.store._store["CoreRepository"]) == 2
assert len(repos) == 2
Expand Down