diff --git a/changelog/15.fixed.md b/changelog/15.fixed.md new file mode 100644 index 00000000..4105142d --- /dev/null +++ b/changelog/15.fixed.md @@ -0,0 +1 @@ +The SDK client query methods (get, filters, all) default behaviour has changed. The query methods will store the retrieved nodes in the internal store by default, where previously this behaviour had to be enabled explicitly using the `populate_store` argument. diff --git a/docs/docs/python-sdk/guides/store.mdx b/docs/docs/python-sdk/guides/store.mdx index f2d8546f..07bcfc58 100644 --- a/docs/docs/python-sdk/guides/store.mdx +++ b/docs/docs/python-sdk/guides/store.mdx @@ -14,21 +14,39 @@ The store can also be used to store Nodes that we retrieve using the client quer ## Storing nodes in the store -Nodes retrieved from Infrahub using a query method, will be stored in the internal store when you set the `populate_store` argument to `True`. -Nodes stored in the store using this method can be retrieved using their `id` as the key in the store. +Nodes retrieved from Infrahub using a the SDK client's query methods, such as the `get` `filters` or `all` method, will be automatically stored in the internal store. Nodes stored in the store using this method can be retrieved using their `id` as the key in the store. ```python - tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=True) + tag = await client.get(kind="BuiltinTag", name__value="RED") + ``` + + + + + ```python + tag = client.get(kind="BuiltinTag", name__value="RED") + ``` + + + + +This behaviour may not be desirable in all scenarios, therefor you can explicitly disable this behaviour by setting the `populate_store` argument to `False` when calling the query methods. + + + + + ```python + tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=False) ``` ```python - tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=True) + tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=False) ``` @@ -42,7 +60,7 @@ You can store nodes in the object store manually using the `set` method. This ha ```python - tag = await client.get(kind="BuiltinTag", name__value="RED") + tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=False) client.store.set(key=tag.name.value, node=tag) ``` @@ -50,7 +68,7 @@ You can store nodes in the object store manually using the `set` method. This ha ```python - tag = client.get(kind="BuiltinTag", name__value="RED") + tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=False) client.store.set(key=tag.name.value, node=tag) ``` @@ -60,13 +78,13 @@ You can store nodes in the object store manually using the `set` method. This ha ## Retrieving object from the store Nodes can be retrieved from the internal store using the key that was used to store them. -For nodes that are stored using the `populate_store` argument on a query method, this will be their `id`. +For nodes that are stored by the client's query methods, this will be their `id`. ```python - tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=True) + tag = await client.get(kind="BuiltinTag", name__value="RED") tag_in_store = client.store.get(key=tag.id) assert tag == tag_in_store ``` @@ -75,7 +93,7 @@ For nodes that are stored using the `populate_store` argument on a query method, ```python - tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=True) + tag = client.get(kind="BuiltinTag", name__value="RED") tag_in_store = client.store.get(key=tag.id) assert tag == tag_in_store ``` @@ -89,7 +107,7 @@ For nodes that have been added manually to the store, this will be the key that ```python - tag = await client.get(kind="BuiltinTag", name__value="RED") + tag = await client.get(kind="BuiltinTag", name__value="RED", populate_store=False) client.store.set(key=tag.name.value, node=tag) tag_in_store = client.store.get(key=tag.name.value) assert tag == tag_in_store @@ -99,7 +117,7 @@ For nodes that have been added manually to the store, this will be the key that ```python - tag = client.get(kind="BuiltinTag", name__value="RED") + tag = client.get(kind="BuiltinTag", name__value="RED", populate_store=False) client.store.set(key=tag.name.value, node=tag) tag_in_store = client.store.get(key=tag.name.value) assert tag == tag_in_store @@ -142,7 +160,7 @@ You can use a custom store, outside of the Infrahub SDK client. Storing or retri from infrahub_sdk.store import NodeStore store = NodeStore() - device = await client.get(kind="TestDevice", name__value="atl1-edge1") + device = await client.get(kind="TestDevice", name__value="atl1-edge1", populate_store=False) store.set(key=device.name.value, node=device) store.get(key=device.name.value) ``` @@ -154,7 +172,7 @@ You can use a custom store, outside of the Infrahub SDK client. Storing or retri from infrahub_sdk.store import NodeStoreSync store = NodeStoreSync() - device = await client.get(kind="TestDevice", name__value="atl1-edge1") + device = await client.get(kind="TestDevice", name__value="atl1-edge1", populate_store=False) store.set(key=device.name.value, node=device) store.get(key=device.name.value) ``` diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 0d689afc..b0b9530a 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -457,7 +457,7 @@ async def get( hfid: list[str] | None = None, include: list[str] | None = None, exclude: list[str] | None = None, - populate_store: bool = False, + populate_store: bool = True, fragment: bool = False, prefetch_relationships: bool = False, property: bool = False, @@ -611,7 +611,7 @@ async def all( at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, - populate_store: bool = False, + populate_store: bool = True, offset: int | None = None, limit: int | None = None, include: list[str] | None = None, @@ -707,7 +707,7 @@ async def filters( at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, - populate_store: bool = False, + populate_store: bool = True, offset: int | None = None, limit: int | None = None, include: list[str] | None = None, @@ -1726,7 +1726,7 @@ def all( at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, - populate_store: bool = False, + populate_store: bool = True, offset: int | None = None, limit: int | None = None, include: list[str] | None = None, @@ -1857,7 +1857,7 @@ def filters( at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, - populate_store: bool = False, + populate_store: bool = True, offset: int | None = None, limit: int | None = None, include: list[str] | None = None, @@ -2110,7 +2110,7 @@ def get( hfid: list[str] | None = None, include: list[str] | None = None, exclude: list[str] | None = None, - populate_store: bool = False, + populate_store: bool = True, fragment: bool = False, prefetch_relationships: bool = False, property: bool = False, diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py index 4d849221..77607e74 100644 --- a/tests/unit/sdk/test_client.py +++ b/tests/unit/sdk/test_client.py @@ -128,18 +128,17 @@ async def test_method_get_user_permissions(clients, mock_query_infrahub_user, cl @pytest.mark.parametrize("client_type", client_types) async def test_method_all_with_limit(clients, mock_query_repository_page1_2, client_type): if client_type == "standard": - repos = await clients.standard.all(kind="CoreRepository", limit=3) + repos = await clients.standard.all(kind="CoreRepository", populate_store=False, limit=3) assert not clients.standard.store._store["CoreRepository"] - repos = await clients.standard.all(kind="CoreRepository", populate_store=True, limit=3) + repos = await clients.standard.all(kind="CoreRepository", limit=3) assert len(clients.standard.store._store["CoreRepository"]) == 3 else: - repos = clients.sync.all(kind="CoreRepository", limit=3) + repos = clients.sync.all(kind="CoreRepository", populate_store=False, limit=3) assert not clients.sync.store._store["CoreRepository"] - repos = clients.sync.all(kind="CoreRepository", populate_store=True, limit=3) + repos = clients.sync.all(kind="CoreRepository", limit=3) assert len(clients.sync.store._store["CoreRepository"]) == 3 - assert len(repos) == 3 @@ -148,16 +147,16 @@ async def test_method_all_multiple_pages( clients, mock_query_repository_page1_2, mock_query_repository_page2_2, client_type ): if client_type == "standard": - repos = await clients.standard.all(kind="CoreRepository") + repos = await clients.standard.all(kind="CoreRepository", populate_store=False) assert not clients.standard.store._store["CoreRepository"] - repos = await clients.standard.all(kind="CoreRepository", populate_store=True) + repos = await clients.standard.all(kind="CoreRepository") assert len(clients.standard.store._store["CoreRepository"]) == 5 else: - repos = clients.sync.all(kind="CoreRepository") + repos = clients.sync.all(kind="CoreRepository", populate_store=False) assert not clients.sync.store._store["CoreRepository"] - repos = clients.sync.all(kind="CoreRepository", populate_store=True) + repos = clients.sync.all(kind="CoreRepository") assert len(clients.sync.store._store["CoreRepository"]) == 5 assert len(repos) == 5 @@ -168,16 +167,16 @@ async def test_method_all_batching( clients, mock_query_location_batch_count, mock_query_location_batch, client_type, use_parallel ): if client_type == "standard": - locations = await clients.standard.all(kind="BuiltinLocation", parallel=use_parallel) + locations = await clients.standard.all(kind="BuiltinLocation", populate_store=False, parallel=use_parallel) assert not clients.standard.store._store["BuiltinLocation"] - locations = await clients.standard.all(kind="BuiltinLocation", populate_store=True, parallel=use_parallel) + locations = await clients.standard.all(kind="BuiltinLocation", parallel=use_parallel) assert len(clients.standard.store._store["BuiltinLocation"]) == 30 else: - locations = clients.sync.all(kind="BuiltinLocation", parallel=use_parallel) + locations = clients.sync.all(kind="BuiltinLocation", populate_store=False, parallel=use_parallel) assert not clients.sync.store._store["BuiltinLocation"] - locations = clients.sync.all(kind="BuiltinLocation", populate_store=True, parallel=use_parallel) + locations = clients.sync.all(kind="BuiltinLocation", parallel=use_parallel) assert len(clients.sync.store._store["BuiltinLocation"]) == 30 assert len(locations) == 30 @@ -186,16 +185,16 @@ async def test_method_all_batching( @pytest.mark.parametrize("client_type", client_types) async def test_method_all_single_page(clients, mock_query_repository_page1_1, client_type): if client_type == "standard": - repos = await clients.standard.all(kind="CoreRepository") + repos = await clients.standard.all(kind="CoreRepository", populate_store=False) assert not clients.standard.store._store["CoreRepository"] - repos = await clients.standard.all(kind="CoreRepository", populate_store=True) + repos = await clients.standard.all(kind="CoreRepository") assert len(clients.standard.store._store["CoreRepository"]) == 2 else: - repos = clients.sync.all(kind="CoreRepository") + repos = clients.sync.all(kind="CoreRepository", populate_store=False) assert not clients.sync.store._store["CoreRepository"] - repos = clients.sync.all(kind="CoreRepository", populate_store=True) + repos = clients.sync.all(kind="CoreRepository") assert len(clients.sync.store._store["CoreRepository"]) == 2 assert len(repos) == 2 @@ -241,20 +240,22 @@ async def test_method_get_by_id(httpx_mock: HTTPXMock, clients, mock_schema_quer ) if client_type == "standard": - repo = await clients.standard.get(kind="CoreRepository", id=response_id) + repo = await clients.standard.get(kind="CoreRepository", id=response_id, populate_store=False) assert isinstance(repo, InfrahubNode) with pytest.raises(NodeNotFoundError): assert clients.standard.store.get(key=response_id) - repo = await clients.standard.get(kind="CoreRepository", id=response_id, populate_store=True) + repo = await clients.standard.get(kind="CoreRepository", id=response_id) + assert isinstance(repo, InfrahubNode) assert clients.standard.store.get(key=response_id) else: - repo = clients.sync.get(kind="CoreRepository", id=response_id) + repo = clients.sync.get(kind="CoreRepository", id=response_id, populate_store=False) assert isinstance(repo, InfrahubNodeSync) with pytest.raises(NodeNotFoundError): assert clients.sync.store.get(key=response_id) - repo = clients.sync.get(kind="CoreRepository", id=response_id, populate_store=True) + repo = clients.sync.get(kind="CoreRepository", id=response_id) + assert isinstance(repo, InfrahubNodeSync) assert clients.sync.store.get(key=response_id) @@ -287,20 +288,22 @@ async def test_method_get_by_hfid(httpx_mock: HTTPXMock, clients, mock_schema_qu ) if client_type == "standard": - repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"]) + repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=False) assert isinstance(repo, InfrahubNode) with pytest.raises(NodeNotFoundError): assert clients.standard.store.get(key=response_id) - repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=True) + repo = await clients.standard.get(kind="CoreRepository", hfid=["infrahub-demo-core"]) + assert isinstance(repo, InfrahubNode) assert clients.standard.store.get(key=response_id) else: - repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"]) + repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=False) assert isinstance(repo, InfrahubNodeSync) with pytest.raises(NodeNotFoundError): assert clients.sync.store.get(key="infrahub-demo-core") - repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"], populate_store=True) + repo = clients.sync.get(kind="CoreRepository", hfid=["infrahub-demo-core"]) + assert isinstance(repo, InfrahubNodeSync) assert clients.sync.store.get(key=response_id) @@ -332,20 +335,20 @@ async def test_method_get_by_default_filter(httpx_mock: HTTPXMock, clients, mock ) if client_type == "standard": - repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core") + repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=False) assert isinstance(repo, InfrahubNode) with pytest.raises(NodeNotFoundError): assert clients.standard.store.get(key=response_id) - repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=True) + repo = await clients.standard.get(kind="CoreRepository", id="infrahub-demo-core") assert clients.standard.store.get(key=response_id) else: - repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core") + repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=False) assert isinstance(repo, InfrahubNodeSync) with pytest.raises(NodeNotFoundError): assert clients.sync.store.get(key="infrahub-demo-core") - repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core", populate_store=True) + repo = clients.sync.get(kind="CoreRepository", id="infrahub-demo-core") assert clients.sync.store.get(key=response_id) @@ -431,6 +434,7 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re "bfae43e8-5ebb-456c-a946-bf64e930710a", "9486cfce-87db-479d-ad73-07d80ba96a0f", ], + populate_store=False, ) assert len(repos) == 2 assert not clients.standard.store._store["CoreRepository"] @@ -441,7 +445,6 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re "bfae43e8-5ebb-456c-a946-bf64e930710a", "9486cfce-87db-479d-ad73-07d80ba96a0f", ], - populate_store=True, ) assert len(clients.standard.store._store["CoreRepository"]) == 2 assert len(repos) == 2 @@ -452,6 +455,7 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re "bfae43e8-5ebb-456c-a946-bf64e930710a", "9486cfce-87db-479d-ad73-07d80ba96a0f", ], + populate_store=False, ) assert len(repos) == 2 assert not clients.sync.store._store["CoreRepository"] @@ -462,7 +466,6 @@ async def test_method_filters_many(httpx_mock: HTTPXMock, clients, mock_query_re "bfae43e8-5ebb-456c-a946-bf64e930710a", "9486cfce-87db-479d-ad73-07d80ba96a0f", ], - populate_store=True, ) assert len(clients.sync.store._store["CoreRepository"]) == 2 assert len(repos) == 2