diff --git a/changelog/11.changed.md b/changelog/11.changed.md new file mode 100644 index 00000000..a72fa814 --- /dev/null +++ b/changelog/11.changed.md @@ -0,0 +1,7 @@ +Method client.get() can now return `None` instead of raising an exception when `raise_when_missing` is set to False + +```python +response = await clients.get( + kind="CoreRepository", name__value="infrahub-demo", raise_when_missing=False +) +``` \ No newline at end of file diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 972c0d73..20e3457f 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -363,6 +363,41 @@ async def delete(self, kind: Union[str, type[SchemaType]], id: str, branch: Opti async def get( self, kind: type[SchemaType], + raise_when_missing: Literal[False], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> Optional[SchemaType]: ... + + @overload + async def get( + self, + kind: type[SchemaType], + raise_when_missing: Literal[True], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> SchemaType: ... + + @overload + async def get( + self, + kind: type[SchemaType], + raise_when_missing: bool = ..., at: Optional[Timestamp] = ..., branch: Optional[str] = ..., id: Optional[str] = ..., @@ -379,6 +414,41 @@ async def get( async def get( self, kind: str, + raise_when_missing: Literal[False], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> Optional[InfrahubNode]: ... + + @overload + async def get( + self, + kind: str, + raise_when_missing: Literal[True], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> InfrahubNode: ... + + @overload + async def get( + self, + kind: str, + raise_when_missing: bool = ..., at: Optional[Timestamp] = ..., branch: Optional[str] = ..., id: Optional[str] = ..., @@ -394,6 +464,7 @@ async def get( async def get( self, kind: Union[str, type[SchemaType]], + raise_when_missing: bool = True, at: Optional[Timestamp] = None, branch: Optional[str] = None, id: Optional[str] = None, @@ -404,7 +475,7 @@ async def get( fragment: bool = False, prefetch_relationships: bool = False, **kwargs: Any, - ) -> Union[InfrahubNode, SchemaType]: + ) -> Union[InfrahubNode, SchemaType, None]: branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch) @@ -437,8 +508,10 @@ async def get( **filters, ) - if len(results) == 0: + if len(results) == 0 and raise_when_missing: raise NodeNotFoundError(branch_name=branch, node_type=schema.kind, identifier=filters) + if len(results) == 0 and not raise_when_missing: + return None if len(results) > 1: raise IndexError("More than 1 node returned") @@ -1737,6 +1810,41 @@ def filters( def get( self, kind: type[SchemaTypeSync], + raise_when_missing: Literal[False], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> Optional[SchemaTypeSync]: ... + + @overload + def get( + self, + kind: type[SchemaTypeSync], + raise_when_missing: Literal[True], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> SchemaTypeSync: ... + + @overload + def get( + self, + kind: type[SchemaTypeSync], + raise_when_missing: bool = ..., at: Optional[Timestamp] = ..., branch: Optional[str] = ..., id: Optional[str] = ..., @@ -1753,6 +1861,41 @@ def get( def get( self, kind: str, + raise_when_missing: Literal[False], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> Optional[InfrahubNodeSync]: ... + + @overload + def get( + self, + kind: str, + raise_when_missing: Literal[True], + at: Optional[Timestamp] = ..., + branch: Optional[str] = ..., + id: Optional[str] = ..., + hfid: Optional[list[str]] = ..., + include: Optional[list[str]] = ..., + exclude: Optional[list[str]] = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + **kwargs: Any, + ) -> InfrahubNodeSync: ... + + @overload + def get( + self, + kind: str, + raise_when_missing: bool = ..., at: Optional[Timestamp] = ..., branch: Optional[str] = ..., id: Optional[str] = ..., @@ -1768,6 +1911,7 @@ def get( def get( self, kind: Union[str, type[SchemaTypeSync]], + raise_when_missing: bool = True, at: Optional[Timestamp] = None, branch: Optional[str] = None, id: Optional[str] = None, @@ -1778,7 +1922,7 @@ def get( fragment: bool = False, prefetch_relationships: bool = False, **kwargs: Any, - ) -> Union[InfrahubNodeSync, SchemaTypeSync]: + ) -> Union[InfrahubNodeSync, SchemaTypeSync, None]: branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) @@ -1811,8 +1955,10 @@ def get( **filters, ) - if len(results) == 0: + if len(results) == 0 and raise_when_missing: raise NodeNotFoundError(branch_name=branch, node_type=schema.kind, identifier=filters) + if len(results) == 0 and not raise_when_missing: + return None if len(results) > 1: raise IndexError("More than 1 node returned") diff --git a/tests/unit/sdk/conftest.py b/tests/unit/sdk/conftest.py index 63cf6356..5afe7ef2 100644 --- a/tests/unit/sdk/conftest.py +++ b/tests/unit/sdk/conftest.py @@ -65,6 +65,7 @@ def replace_annotation(annotation: str) -> str: "Optional[CoreNode]": "Optional[CoreNodeSync]", "Union[str, type[SchemaType]]": "Union[str, type[SchemaTypeSync]]", "Union[InfrahubNode, SchemaType]": "Union[InfrahubNodeSync, SchemaTypeSync]", + "Union[InfrahubNode, SchemaType, None]": "Union[InfrahubNodeSync, SchemaTypeSync, None]", "Union[list[InfrahubNode], list[SchemaType]]": "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]", "InfrahubClient": "InfrahubClientSync", "InfrahubNode": "InfrahubNodeSync", @@ -104,6 +105,7 @@ def replace_annotation(annotation: str) -> str: "Optional[CoreNodeSync]": "Optional[CoreNode]", "Union[str, type[SchemaTypeSync]]": "Union[str, type[SchemaType]]", "Union[InfrahubNodeSync, SchemaTypeSync]": "Union[InfrahubNode, SchemaType]", + "Union[InfrahubNodeSync, SchemaTypeSync, None]": "Union[InfrahubNode, SchemaType, None]", "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]": "Union[list[InfrahubNode], list[SchemaType]]", "InfrahubClientSync": "InfrahubClient", "InfrahubNodeSync": "InfrahubNode", diff --git a/tests/unit/sdk/test_client.py b/tests/unit/sdk/test_client.py index 50dc0b34..e80152d2 100644 --- a/tests/unit/sdk/test_client.py +++ b/tests/unit/sdk/test_client.py @@ -314,6 +314,20 @@ async def test_method_get_not_found(httpx_mock: HTTPXMock, clients, mock_query_r clients.sync.get(kind="CoreRepository", name__value="infrahub-demo-core") +@pytest.mark.parametrize("client_type", client_types) +async def test_method_get_not_found_none( + httpx_mock: HTTPXMock, clients, mock_query_repository_page1_empty, client_type +): # pylint: disable=unused-argument + if client_type == "standard": + response = await clients.standard.get( + kind="CoreRepository", name__value="infrahub-demo-core", raise_when_missing=False + ) + else: + response = clients.sync.get(kind="CoreRepository", name__value="infrahub-demo-core", raise_when_missing=False) + + assert response is None + + @pytest.mark.parametrize("client_type", client_types) async def test_method_get_found_many( httpx_mock: HTTPXMock,