Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions changelog/11.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Method client.get() can now return `None` instead of raising an exception when `raise_when_missing` is set to False

```python
response = await clients.get(
kind="CoreRepository", name__value="infrahub-demo", raise_when_missing=False
)
```
154 changes: 150 additions & 4 deletions infrahub_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,41 @@ async def delete(self, kind: Union[str, type[SchemaType]], id: str, branch: Opti
async def get(
self,
kind: type[SchemaType],
raise_when_missing: Literal[False],
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
hfid: Optional[list[str]] = ...,
include: Optional[list[str]] = ...,
exclude: Optional[list[str]] = ...,
populate_store: bool = ...,
fragment: bool = ...,
prefetch_relationships: bool = ...,
**kwargs: Any,
) -> Optional[SchemaType]: ...

@overload
async def get(
self,
kind: type[SchemaType],
raise_when_missing: Literal[True],
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
hfid: Optional[list[str]] = ...,
include: Optional[list[str]] = ...,
exclude: Optional[list[str]] = ...,
populate_store: bool = ...,
fragment: bool = ...,
prefetch_relationships: bool = ...,
**kwargs: Any,
) -> SchemaType: ...

@overload
async def get(
self,
kind: type[SchemaType],
raise_when_missing: bool = ...,
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
Expand All @@ -379,6 +414,41 @@ async def get(
async def get(
self,
kind: str,
raise_when_missing: Literal[False],
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
hfid: Optional[list[str]] = ...,
include: Optional[list[str]] = ...,
exclude: Optional[list[str]] = ...,
populate_store: bool = ...,
fragment: bool = ...,
prefetch_relationships: bool = ...,
**kwargs: Any,
) -> Optional[InfrahubNode]: ...

@overload
async def get(
self,
kind: str,
raise_when_missing: Literal[True],
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
hfid: Optional[list[str]] = ...,
include: Optional[list[str]] = ...,
exclude: Optional[list[str]] = ...,
populate_store: bool = ...,
fragment: bool = ...,
prefetch_relationships: bool = ...,
**kwargs: Any,
) -> InfrahubNode: ...

@overload
async def get(
self,
kind: str,
raise_when_missing: bool = ...,
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
Expand All @@ -394,6 +464,7 @@ async def get(
async def get(
self,
kind: Union[str, type[SchemaType]],
raise_when_missing: bool = True,
at: Optional[Timestamp] = None,
branch: Optional[str] = None,
id: Optional[str] = None,
Expand All @@ -404,7 +475,7 @@ async def get(
fragment: bool = False,
prefetch_relationships: bool = False,
**kwargs: Any,
) -> Union[InfrahubNode, SchemaType]:
) -> Union[InfrahubNode, SchemaType, None]:
branch = branch or self.default_branch
schema = await self.schema.get(kind=kind, branch=branch)

Expand Down Expand Up @@ -437,8 +508,10 @@ async def get(
**filters,
)

if len(results) == 0:
if len(results) == 0 and raise_when_missing:
raise NodeNotFoundError(branch_name=branch, node_type=schema.kind, identifier=filters)
if len(results) == 0 and not raise_when_missing:
return None
if len(results) > 1:
raise IndexError("More than 1 node returned")

Expand Down Expand Up @@ -1737,6 +1810,41 @@ def filters(
def get(
self,
kind: type[SchemaTypeSync],
raise_when_missing: Literal[False],
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
hfid: Optional[list[str]] = ...,
include: Optional[list[str]] = ...,
exclude: Optional[list[str]] = ...,
populate_store: bool = ...,
fragment: bool = ...,
prefetch_relationships: bool = ...,
**kwargs: Any,
) -> Optional[SchemaTypeSync]: ...

@overload
def get(
self,
kind: type[SchemaTypeSync],
raise_when_missing: Literal[True],
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
hfid: Optional[list[str]] = ...,
include: Optional[list[str]] = ...,
exclude: Optional[list[str]] = ...,
populate_store: bool = ...,
fragment: bool = ...,
prefetch_relationships: bool = ...,
**kwargs: Any,
) -> SchemaTypeSync: ...

@overload
def get(
self,
kind: type[SchemaTypeSync],
raise_when_missing: bool = ...,
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
Expand All @@ -1753,6 +1861,41 @@ def get(
def get(
self,
kind: str,
raise_when_missing: Literal[False],
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
hfid: Optional[list[str]] = ...,
include: Optional[list[str]] = ...,
exclude: Optional[list[str]] = ...,
populate_store: bool = ...,
fragment: bool = ...,
prefetch_relationships: bool = ...,
**kwargs: Any,
) -> Optional[InfrahubNodeSync]: ...

@overload
def get(
self,
kind: str,
raise_when_missing: Literal[True],
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
hfid: Optional[list[str]] = ...,
include: Optional[list[str]] = ...,
exclude: Optional[list[str]] = ...,
populate_store: bool = ...,
fragment: bool = ...,
prefetch_relationships: bool = ...,
**kwargs: Any,
) -> InfrahubNodeSync: ...

@overload
def get(
self,
kind: str,
raise_when_missing: bool = ...,
at: Optional[Timestamp] = ...,
branch: Optional[str] = ...,
id: Optional[str] = ...,
Expand All @@ -1768,6 +1911,7 @@ def get(
def get(
self,
kind: Union[str, type[SchemaTypeSync]],
raise_when_missing: bool = True,
at: Optional[Timestamp] = None,
branch: Optional[str] = None,
id: Optional[str] = None,
Expand All @@ -1778,7 +1922,7 @@ def get(
fragment: bool = False,
prefetch_relationships: bool = False,
**kwargs: Any,
) -> Union[InfrahubNodeSync, SchemaTypeSync]:
) -> Union[InfrahubNodeSync, SchemaTypeSync, None]:
branch = branch or self.default_branch
schema = self.schema.get(kind=kind, branch=branch)

Expand Down Expand Up @@ -1811,8 +1955,10 @@ def get(
**filters,
)

if len(results) == 0:
if len(results) == 0 and raise_when_missing:
raise NodeNotFoundError(branch_name=branch, node_type=schema.kind, identifier=filters)
if len(results) == 0 and not raise_when_missing:
return None
if len(results) > 1:
raise IndexError("More than 1 node returned")

Expand Down
2 changes: 2 additions & 0 deletions tests/unit/sdk/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def replace_annotation(annotation: str) -> str:
"Optional[CoreNode]": "Optional[CoreNodeSync]",
"Union[str, type[SchemaType]]": "Union[str, type[SchemaTypeSync]]",
"Union[InfrahubNode, SchemaType]": "Union[InfrahubNodeSync, SchemaTypeSync]",
"Union[InfrahubNode, SchemaType, None]": "Union[InfrahubNodeSync, SchemaTypeSync, None]",
"Union[list[InfrahubNode], list[SchemaType]]": "Union[list[InfrahubNodeSync], list[SchemaTypeSync]]",
"InfrahubClient": "InfrahubClientSync",
"InfrahubNode": "InfrahubNodeSync",
Expand Down Expand Up @@ -104,6 +105,7 @@ def replace_annotation(annotation: str) -> str:
"Optional[CoreNodeSync]": "Optional[CoreNode]",
"Union[str, type[SchemaTypeSync]]": "Union[str, type[SchemaType]]",
"Union[InfrahubNodeSync, SchemaTypeSync]": "Union[InfrahubNode, SchemaType]",
"Union[InfrahubNodeSync, SchemaTypeSync, None]": "Union[InfrahubNode, SchemaType, None]",
"Union[list[InfrahubNodeSync], list[SchemaTypeSync]]": "Union[list[InfrahubNode], list[SchemaType]]",
"InfrahubClientSync": "InfrahubClient",
"InfrahubNodeSync": "InfrahubNode",
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/sdk/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,20 @@ async def test_method_get_not_found(httpx_mock: HTTPXMock, clients, mock_query_r
clients.sync.get(kind="CoreRepository", name__value="infrahub-demo-core")


@pytest.mark.parametrize("client_type", client_types)
async def test_method_get_not_found_none(
httpx_mock: HTTPXMock, clients, mock_query_repository_page1_empty, client_type
): # pylint: disable=unused-argument
if client_type == "standard":
response = await clients.standard.get(
kind="CoreRepository", name__value="infrahub-demo-core", raise_when_missing=False
)
else:
response = clients.sync.get(kind="CoreRepository", name__value="infrahub-demo-core", raise_when_missing=False)

assert response is None


@pytest.mark.parametrize("client_type", client_types)
async def test_method_get_found_many(
httpx_mock: HTTPXMock,
Expand Down