Skip to content

Commit df19431

Browse files
committed
Add support for request context
1 parent e59e6a5 commit df19431

File tree

5 files changed

+101
-22
lines changed

5 files changed

+101
-22
lines changed

infrahub_sdk/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@
5858
if TYPE_CHECKING:
5959
from types import TracebackType
6060

61+
from .context import RequestContext
62+
6163

6264
SchemaType = TypeVar("SchemaType", bound=CoreNode)
6365
SchemaTypeSync = TypeVar("SchemaTypeSync", bound=CoreNodeSync)
@@ -139,6 +141,7 @@ def __init__(
139141
self.identifier = self.config.identifier
140142
self.group_context: InfrahubGroupContext | InfrahubGroupContextSync
141143
self._initialize()
144+
self._request_context: RequestContext | None = None
142145

143146
def _initialize(self) -> None:
144147
"""Sets the properties for each version of the client"""
@@ -153,6 +156,14 @@ def _echo(self, url: str, query: str, variables: dict | None = None) -> None:
153156
if variables:
154157
print(f"VARIABLES:\n{ujson.dumps(variables, indent=4)}\n")
155158

159+
@property
160+
def request_context(self) -> RequestContext | None:
161+
return self._request_context
162+
163+
@request_context.setter
164+
def request_context(self, request_context: RequestContext) -> None:
165+
self._request_context = request_context
166+
156167
def start_tracking(
157168
self,
158169
identifier: str | None = None,

infrahub_sdk/context.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from __future__ import annotations
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class ContextAccount(BaseModel):
7+
id: str = Field(..., description="The ID of the account")
8+
9+
10+
class RequestContext(BaseModel):
11+
"""The context can be used to override settings such as the account within mutations."""
12+
13+
account: ContextAccount | None = Field(default=None, description="Account tied to the context")

infrahub_sdk/generator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
if TYPE_CHECKING:
1313
from .client import InfrahubClient
14+
from .context import RequestContext
1415
from .node import InfrahubNode
1516
from .store import NodeStore
1617

@@ -29,6 +30,7 @@ def __init__(
2930
params: dict | None = None,
3031
convert_query_response: bool = False,
3132
logger: logging.Logger | None = None,
33+
request_context: RequestContext | None = None,
3234
) -> None:
3335
self.query = query
3436
self.branch = branch
@@ -44,6 +46,7 @@ def __init__(
4446
self.infrahub_node = infrahub_node
4547
self.convert_query_response = convert_query_response
4648
self.logger = logger if logger else logging.getLogger("infrahub.tasks")
49+
self.request_context = request_context
4750

4851
@property
4952
def store(self) -> NodeStore:

infrahub_sdk/node.py

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing_extensions import Self
2323

2424
from .client import InfrahubClient, InfrahubClientSync
25+
from .context import RequestContext
2526
from .schema import AttributeSchemaAPI, MainSchemaTypesAPI, RelationshipSchemaAPI
2627
from .types import Order
2728

@@ -766,6 +767,16 @@ def _init_attributes(self, data: dict | None = None) -> None:
766767
Attribute(name=attr_name, schema=attr_schema, data=attr_data),
767768
)
768769

770+
def _get_request_context(self, request_context: RequestContext | None = None) -> dict[str, Any] | None:
771+
if request_context:
772+
return request_context.model_dump(exclude_none=True)
773+
774+
client: InfrahubClient | InfrahubClientSync | None = getattr(self, "_client", None)
775+
if not client or not client.request_context:
776+
return None
777+
778+
return client.request_context.model_dump(exclude_none=True)
779+
769780
def _init_relationships(self, data: dict | None = None) -> None:
770781
pass
771782

@@ -794,7 +805,12 @@ def is_resource_pool(self) -> bool:
794805
def get_raw_graphql_data(self) -> dict | None:
795806
return self._data
796807

797-
def _generate_input_data(self, exclude_unmodified: bool = False, exclude_hfid: bool = False) -> dict[str, dict]: # noqa: C901
808+
def _generate_input_data( # noqa: C901
809+
self,
810+
exclude_unmodified: bool = False,
811+
exclude_hfid: bool = False,
812+
request_context: RequestContext | None = None,
813+
) -> dict[str, dict]:
798814
"""Generate a dictionary that represent the input data required by a mutation.
799815
800816
Returns:
@@ -872,7 +888,15 @@ def _generate_input_data(self, exclude_unmodified: bool = False, exclude_hfid: b
872888
elif self.hfid is not None and not exclude_hfid:
873889
data["hfid"] = self.hfid
874890

875-
return {"data": {"data": data}, "variables": variables, "mutation_variables": mutation_variables}
891+
mutation_payload = {"data": data}
892+
if context_data := self._get_request_context(request_context=request_context):
893+
mutation_payload["context"] = context_data
894+
895+
return {
896+
"data": mutation_payload,
897+
"variables": variables,
898+
"mutation_variables": mutation_variables,
899+
}
876900

877901
@staticmethod
878902
def _strip_unmodified_dict(data: dict, original_data: dict, variables: dict, item: str) -> None:
@@ -1129,8 +1153,11 @@ async def artifact_fetch(self, name: str) -> str | dict[str, Any]:
11291153
content = await self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined]
11301154
return content
11311155

1132-
async def delete(self, timeout: int | None = None) -> None:
1156+
async def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None:
11331157
input_data = {"data": {"id": self.id}}
1158+
if context_data := self._get_request_context(request_context=request_context):
1159+
input_data["context"] = context_data
1160+
11341161
mutation_query = {"ok": None}
11351162
query = Mutation(
11361163
mutation=f"{self._schema.kind}Delete",
@@ -1145,12 +1172,16 @@ async def delete(self, timeout: int | None = None) -> None:
11451172
)
11461173

11471174
async def save(
1148-
self, allow_upsert: bool = False, update_group_context: bool | None = None, timeout: int | None = None
1175+
self,
1176+
allow_upsert: bool = False,
1177+
update_group_context: bool | None = None,
1178+
timeout: int | None = None,
1179+
request_context: RequestContext | None = None,
11491180
) -> None:
11501181
if self._existing is False or allow_upsert is True:
1151-
await self.create(allow_upsert=allow_upsert, timeout=timeout)
1182+
await self.create(allow_upsert=allow_upsert, timeout=timeout, request_context=request_context)
11521183
else:
1153-
await self.update(timeout=timeout)
1184+
await self.update(timeout=timeout, request_context=request_context)
11541185

11551186
if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING:
11561187
update_group_context = True
@@ -1379,15 +1410,17 @@ async def _process_mutation_result(
13791410
await related_node.fetch(timeout=timeout)
13801411
setattr(self, rel_name, related_node)
13811412

1382-
async def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None:
1413+
async def create(
1414+
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
1415+
) -> None:
13831416
mutation_query = self._generate_mutation_query()
13841417

13851418
if allow_upsert:
1386-
input_data = self._generate_input_data(exclude_hfid=False)
1419+
input_data = self._generate_input_data(exclude_hfid=False, request_context=request_context)
13871420
mutation_name = f"{self._schema.kind}Upsert"
13881421
tracker = f"mutation-{str(self._schema.kind).lower()}-upsert"
13891422
else:
1390-
input_data = self._generate_input_data(exclude_hfid=True)
1423+
input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context)
13911424
mutation_name = f"{self._schema.kind}Create"
13921425
tracker = f"mutation-{str(self._schema.kind).lower()}-create"
13931426
query = Mutation(
@@ -1405,8 +1438,10 @@ async def create(self, allow_upsert: bool = False, timeout: int | None = None) -
14051438
)
14061439
await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)
14071440

1408-
async def update(self, do_full_update: bool = False, timeout: int | None = None) -> None:
1409-
input_data = self._generate_input_data(exclude_unmodified=not do_full_update)
1441+
async def update(
1442+
self, do_full_update: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
1443+
) -> None:
1444+
input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context)
14101445
mutation_query = self._generate_mutation_query()
14111446
mutation_name = f"{self._schema.kind}Update"
14121447

@@ -1645,8 +1680,11 @@ def artifact_fetch(self, name: str) -> str | dict[str, Any]:
16451680
content = self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined]
16461681
return content
16471682

1648-
def delete(self, timeout: int | None = None) -> None:
1683+
def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None:
16491684
input_data = {"data": {"id": self.id}}
1685+
if context_data := self._get_request_context(request_context=request_context):
1686+
input_data["context"] = context_data
1687+
16501688
mutation_query = {"ok": None}
16511689
query = Mutation(
16521690
mutation=f"{self._schema.kind}Delete",
@@ -1661,12 +1699,16 @@ def delete(self, timeout: int | None = None) -> None:
16611699
)
16621700

16631701
def save(
1664-
self, allow_upsert: bool = False, update_group_context: bool | None = None, timeout: int | None = None
1702+
self,
1703+
allow_upsert: bool = False,
1704+
update_group_context: bool | None = None,
1705+
timeout: int | None = None,
1706+
request_context: RequestContext | None = None,
16651707
) -> None:
16661708
if self._existing is False or allow_upsert is True:
1667-
self.create(allow_upsert=allow_upsert, timeout=timeout)
1709+
self.create(allow_upsert=allow_upsert, timeout=timeout, request_context=request_context)
16681710
else:
1669-
self.update(timeout=timeout)
1711+
self.update(timeout=timeout, request_context=request_context)
16701712

16711713
if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING:
16721714
update_group_context = True
@@ -1890,15 +1932,17 @@ def _process_mutation_result(
18901932
related_node.fetch(timeout=timeout)
18911933
setattr(self, rel_name, related_node)
18921934

1893-
def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None:
1935+
def create(
1936+
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
1937+
) -> None:
18941938
mutation_query = self._generate_mutation_query()
18951939

18961940
if allow_upsert:
1897-
input_data = self._generate_input_data(exclude_hfid=False)
1941+
input_data = self._generate_input_data(exclude_hfid=False, request_context=request_context)
18981942
mutation_name = f"{self._schema.kind}Upsert"
18991943
tracker = f"mutation-{str(self._schema.kind).lower()}-upsert"
19001944
else:
1901-
input_data = self._generate_input_data(exclude_hfid=True)
1945+
input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context)
19021946
mutation_name = f"{self._schema.kind}Create"
19031947
tracker = f"mutation-{str(self._schema.kind).lower()}-create"
19041948
query = Mutation(
@@ -1917,8 +1961,10 @@ def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None
19171961
)
19181962
self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)
19191963

1920-
def update(self, do_full_update: bool = False, timeout: int | None = None) -> None:
1921-
input_data = self._generate_input_data(exclude_unmodified=not do_full_update)
1964+
def update(
1965+
self, do_full_update: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
1966+
) -> None:
1967+
input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context)
19221968
mutation_query = self._generate_mutation_query()
19231969
mutation_name = f"{self._schema.kind}Update"
19241970

tests/unit/sdk/test_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
from infrahub_sdk.exceptions import NodeNotFoundError
88
from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync
99

10-
async_client_methods = [method for method in dir(InfrahubClient) if not method.startswith("_")]
11-
sync_client_methods = [method for method in dir(InfrahubClientSync) if not method.startswith("_")]
10+
excluded_methods = ["request_context"]
11+
12+
async_client_methods = [
13+
method for method in dir(InfrahubClient) if not method.startswith("_") and method not in excluded_methods
14+
]
15+
sync_client_methods = [
16+
method for method in dir(InfrahubClientSync) if not method.startswith("_") and method not in excluded_methods
17+
]
1218

1319
batch_client_types = [
1420
("standard", False),

0 commit comments

Comments
 (0)