Skip to content

Commit 2ea79fe

Browse files
authored
Merge pull request #296 from opsmill/pog-request-context-IFC-1340
Add support for request context
2 parents 826c188 + 03772f1 commit 2ea79fe

File tree

6 files changed

+133
-33
lines changed

6 files changed

+133
-33
lines changed

infrahub_sdk/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
if TYPE_CHECKING:
6161
from types import TracebackType
6262

63+
from .context import RequestContext
64+
6365

6466
SchemaType = TypeVar("SchemaType", bound=CoreNode)
6567
SchemaTypeSync = TypeVar("SchemaTypeSync", bound=CoreNodeSync)
@@ -141,6 +143,7 @@ def __init__(
141143
self.identifier = self.config.identifier
142144
self.group_context: InfrahubGroupContext | InfrahubGroupContextSync
143145
self._initialize()
146+
self._request_context: RequestContext | None = None
144147

145148
def _initialize(self) -> None:
146149
"""Sets the properties for each version of the client"""
@@ -155,6 +158,14 @@ def _echo(self, url: str, query: str, variables: dict | None = None) -> None:
155158
if variables:
156159
print(f"VARIABLES:\n{ujson.dumps(variables, indent=4)}\n")
157160

161+
@property
162+
def request_context(self) -> RequestContext | None:
163+
return self._request_context
164+
165+
@request_context.setter
166+
def request_context(self, request_context: RequestContext) -> None:
167+
self._request_context = request_context
168+
158169
def start_tracking(
159170
self,
160171
identifier: str | None = None,

infrahub_sdk/context.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from __future__ import annotations
2+
3+
from pydantic import BaseModel, Field
4+
5+
6+
class ContextAccount(BaseModel):
7+
id: str = Field(..., description="The ID of the account")
8+
9+
10+
class RequestContext(BaseModel):
11+
"""The context can be used to override settings such as the account within mutations."""
12+
13+
account: ContextAccount | None = Field(default=None, description="Account tied to the context")

infrahub_sdk/generator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
if TYPE_CHECKING:
1313
from .client import InfrahubClient
14+
from .context import RequestContext
1415
from .node import InfrahubNode
1516
from .store import NodeStore
1617

@@ -29,6 +30,7 @@ def __init__(
2930
params: dict | None = None,
3031
convert_query_response: bool = False,
3132
logger: logging.Logger | None = None,
33+
request_context: RequestContext | None = None,
3234
) -> None:
3335
self.query = query
3436
self.branch = branch
@@ -44,6 +46,7 @@ def __init__(
4446
self.infrahub_node = infrahub_node
4547
self.convert_query_response = convert_query_response
4648
self.logger = logger if logger else logging.getLogger("infrahub.tasks")
49+
self.request_context = request_context
4750

4851
@property
4952
def store(self) -> NodeStore:

infrahub_sdk/node.py

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing_extensions import Self
2323

2424
from .client import InfrahubClient, InfrahubClientSync
25+
from .context import RequestContext
2526
from .schema import AttributeSchemaAPI, MainSchemaTypesAPI, RelationshipSchemaAPI
2627
from .types import Order
2728

@@ -769,6 +770,16 @@ def _init_attributes(self, data: dict | None = None) -> None:
769770
Attribute(name=attr_name, schema=attr_schema, data=attr_data),
770771
)
771772

773+
def _get_request_context(self, request_context: RequestContext | None = None) -> dict[str, Any] | None:
774+
if request_context:
775+
return request_context.model_dump(exclude_none=True)
776+
777+
client: InfrahubClient | InfrahubClientSync | None = getattr(self, "_client", None)
778+
if not client or not client.request_context:
779+
return None
780+
781+
return client.request_context.model_dump(exclude_none=True)
782+
772783
def _init_relationships(self, data: dict | None = None) -> None:
773784
pass
774785

@@ -797,7 +808,12 @@ def is_resource_pool(self) -> bool:
797808
def get_raw_graphql_data(self) -> dict | None:
798809
return self._data
799810

800-
def _generate_input_data(self, exclude_unmodified: bool = False, exclude_hfid: bool = False) -> dict[str, dict]: # noqa: C901
811+
def _generate_input_data( # noqa: C901
812+
self,
813+
exclude_unmodified: bool = False,
814+
exclude_hfid: bool = False,
815+
request_context: RequestContext | None = None,
816+
) -> dict[str, dict]:
801817
"""Generate a dictionary that represent the input data required by a mutation.
802818
803819
Returns:
@@ -875,7 +891,15 @@ def _generate_input_data(self, exclude_unmodified: bool = False, exclude_hfid: b
875891
elif self.hfid is not None and not exclude_hfid:
876892
data["hfid"] = self.hfid
877893

878-
return {"data": {"data": data}, "variables": variables, "mutation_variables": mutation_variables}
894+
mutation_payload = {"data": data}
895+
if context_data := self._get_request_context(request_context=request_context):
896+
mutation_payload["context"] = context_data
897+
898+
return {
899+
"data": mutation_payload,
900+
"variables": variables,
901+
"mutation_variables": mutation_variables,
902+
}
879903

880904
@staticmethod
881905
def _strip_unmodified_dict(data: dict, original_data: dict, variables: dict, item: str) -> None:
@@ -1132,8 +1156,11 @@ async def artifact_fetch(self, name: str) -> str | dict[str, Any]:
11321156
content = await self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined]
11331157
return content
11341158

1135-
async def delete(self, timeout: int | None = None) -> None:
1159+
async def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None:
11361160
input_data = {"data": {"id": self.id}}
1161+
if context_data := self._get_request_context(request_context=request_context):
1162+
input_data["context"] = context_data
1163+
11371164
mutation_query = {"ok": None}
11381165
query = Mutation(
11391166
mutation=f"{self._schema.kind}Delete",
@@ -1148,12 +1175,16 @@ async def delete(self, timeout: int | None = None) -> None:
11481175
)
11491176

11501177
async def save(
1151-
self, allow_upsert: bool = False, update_group_context: bool | None = None, timeout: int | None = None
1178+
self,
1179+
allow_upsert: bool = False,
1180+
update_group_context: bool | None = None,
1181+
timeout: int | None = None,
1182+
request_context: RequestContext | None = None,
11521183
) -> None:
11531184
if self._existing is False or allow_upsert is True:
1154-
await self.create(allow_upsert=allow_upsert, timeout=timeout)
1185+
await self.create(allow_upsert=allow_upsert, timeout=timeout, request_context=request_context)
11551186
else:
1156-
await self.update(timeout=timeout)
1187+
await self.update(timeout=timeout, request_context=request_context)
11571188

11581189
if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING:
11591190
update_group_context = True
@@ -1382,15 +1413,17 @@ async def _process_mutation_result(
13821413
await related_node.fetch(timeout=timeout)
13831414
setattr(self, rel_name, related_node)
13841415

1385-
async def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None:
1416+
async def create(
1417+
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
1418+
) -> None:
13861419
mutation_query = self._generate_mutation_query()
13871420

13881421
if allow_upsert:
1389-
input_data = self._generate_input_data(exclude_hfid=False)
1422+
input_data = self._generate_input_data(exclude_hfid=False, request_context=request_context)
13901423
mutation_name = f"{self._schema.kind}Upsert"
13911424
tracker = f"mutation-{str(self._schema.kind).lower()}-upsert"
13921425
else:
1393-
input_data = self._generate_input_data(exclude_hfid=True)
1426+
input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context)
13941427
mutation_name = f"{self._schema.kind}Create"
13951428
tracker = f"mutation-{str(self._schema.kind).lower()}-create"
13961429
query = Mutation(
@@ -1408,8 +1441,10 @@ async def create(self, allow_upsert: bool = False, timeout: int | None = None) -
14081441
)
14091442
await self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)
14101443

1411-
async def update(self, do_full_update: bool = False, timeout: int | None = None) -> None:
1412-
input_data = self._generate_input_data(exclude_unmodified=not do_full_update)
1444+
async def update(
1445+
self, do_full_update: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
1446+
) -> None:
1447+
input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context)
14131448
mutation_query = self._generate_mutation_query()
14141449
mutation_name = f"{self._schema.kind}Update"
14151450

@@ -1648,8 +1683,11 @@ def artifact_fetch(self, name: str) -> str | dict[str, Any]:
16481683
content = self._client.object_store.get(identifier=artifact.storage_id.value) # type: ignore[attr-defined]
16491684
return content
16501685

1651-
def delete(self, timeout: int | None = None) -> None:
1686+
def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None:
16521687
input_data = {"data": {"id": self.id}}
1688+
if context_data := self._get_request_context(request_context=request_context):
1689+
input_data["context"] = context_data
1690+
16531691
mutation_query = {"ok": None}
16541692
query = Mutation(
16551693
mutation=f"{self._schema.kind}Delete",
@@ -1664,12 +1702,16 @@ def delete(self, timeout: int | None = None) -> None:
16641702
)
16651703

16661704
def save(
1667-
self, allow_upsert: bool = False, update_group_context: bool | None = None, timeout: int | None = None
1705+
self,
1706+
allow_upsert: bool = False,
1707+
update_group_context: bool | None = None,
1708+
timeout: int | None = None,
1709+
request_context: RequestContext | None = None,
16681710
) -> None:
16691711
if self._existing is False or allow_upsert is True:
1670-
self.create(allow_upsert=allow_upsert, timeout=timeout)
1712+
self.create(allow_upsert=allow_upsert, timeout=timeout, request_context=request_context)
16711713
else:
1672-
self.update(timeout=timeout)
1714+
self.update(timeout=timeout, request_context=request_context)
16731715

16741716
if update_group_context is None and self._client.mode == InfrahubClientMode.TRACKING:
16751717
update_group_context = True
@@ -1893,15 +1935,17 @@ def _process_mutation_result(
18931935
related_node.fetch(timeout=timeout)
18941936
setattr(self, rel_name, related_node)
18951937

1896-
def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None:
1938+
def create(
1939+
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
1940+
) -> None:
18971941
mutation_query = self._generate_mutation_query()
18981942

18991943
if allow_upsert:
1900-
input_data = self._generate_input_data(exclude_hfid=False)
1944+
input_data = self._generate_input_data(exclude_hfid=False, request_context=request_context)
19011945
mutation_name = f"{self._schema.kind}Upsert"
19021946
tracker = f"mutation-{str(self._schema.kind).lower()}-upsert"
19031947
else:
1904-
input_data = self._generate_input_data(exclude_hfid=True)
1948+
input_data = self._generate_input_data(exclude_hfid=True, request_context=request_context)
19051949
mutation_name = f"{self._schema.kind}Create"
19061950
tracker = f"mutation-{str(self._schema.kind).lower()}-create"
19071951
query = Mutation(
@@ -1920,8 +1964,10 @@ def create(self, allow_upsert: bool = False, timeout: int | None = None) -> None
19201964
)
19211965
self._process_mutation_result(mutation_name=mutation_name, response=response, timeout=timeout)
19221966

1923-
def update(self, do_full_update: bool = False, timeout: int | None = None) -> None:
1924-
input_data = self._generate_input_data(exclude_unmodified=not do_full_update)
1967+
def update(
1968+
self, do_full_update: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
1969+
) -> None:
1970+
input_data = self._generate_input_data(exclude_unmodified=not do_full_update, request_context=request_context)
19251971
mutation_query = self._generate_mutation_query()
19261972
mutation_name = f"{self._schema.kind}Update"
19271973

infrahub_sdk/protocols_base.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
if TYPE_CHECKING:
66
import ipaddress
77

8+
from .context import RequestContext
89
from .schema import MainSchemaTypes
910

1011

@@ -169,13 +170,23 @@ def extract(self, params: dict[str, str]) -> dict[str, Any]: ...
169170

170171
@runtime_checkable
171172
class CoreNode(CoreNodeBase, Protocol):
172-
async def save(self, allow_upsert: bool = False, update_group_context: bool | None = None) -> None: ...
173+
async def save(
174+
self,
175+
allow_upsert: bool = False,
176+
update_group_context: bool | None = None,
177+
timeout: int | None = None,
178+
request_context: RequestContext | None = None,
179+
) -> None: ...
173180

174-
async def delete(self) -> None: ...
181+
async def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: ...
175182

176-
async def update(self, do_full_update: bool) -> None: ...
183+
async def update(
184+
self, do_full_update: bool, timeout: int | None = None, request_context: RequestContext | None = None
185+
) -> None: ...
177186

178-
async def create(self, allow_upsert: bool = False) -> None: ...
187+
async def create(
188+
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
189+
) -> None: ...
179190

180191
async def add_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ...
181192

@@ -184,13 +195,23 @@ async def remove_relationships(self, relation_to_update: str, related_nodes: lis
184195

185196
@runtime_checkable
186197
class CoreNodeSync(CoreNodeBase, Protocol):
187-
def save(self, allow_upsert: bool = False, update_group_context: bool | None = None) -> None: ...
188-
189-
def delete(self) -> None: ...
190-
191-
def update(self, do_full_update: bool) -> None: ...
192-
193-
def create(self, allow_upsert: bool = False) -> None: ...
198+
def save(
199+
self,
200+
allow_upsert: bool = False,
201+
update_group_context: bool | None = None,
202+
timeout: int | None = None,
203+
request_context: RequestContext | None = None,
204+
) -> None: ...
205+
206+
def delete(self, timeout: int | None = None, request_context: RequestContext | None = None) -> None: ...
207+
208+
def update(
209+
self, do_full_update: bool, timeout: int | None = None, request_context: RequestContext | None = None
210+
) -> None: ...
211+
212+
def create(
213+
self, allow_upsert: bool = False, timeout: int | None = None, request_context: RequestContext | None = None
214+
) -> None: ...
194215

195216
def add_relationships(self, relation_to_update: str, related_nodes: list[str]) -> None: ...
196217

tests/unit/sdk/test_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77
from infrahub_sdk.exceptions import NodeNotFoundError
88
from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync
99

10-
async_client_methods = [method for method in dir(InfrahubClient) if not method.startswith("_")]
11-
sync_client_methods = [method for method in dir(InfrahubClientSync) if not method.startswith("_")]
10+
excluded_methods = ["request_context"]
11+
12+
async_client_methods = [
13+
method for method in dir(InfrahubClient) if not method.startswith("_") and method not in excluded_methods
14+
]
15+
sync_client_methods = [
16+
method for method in dir(InfrahubClientSync) if not method.startswith("_") and method not in excluded_methods
17+
]
1218

1319
batch_client_types = [
1420
("standard", False),

0 commit comments

Comments
 (0)