Skip to content

Commit 5453e51

Browse files
author
Phillip Simonds
committed
Updates per peer review
1 parent 5ef90eb commit 5453e51

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed
File renamed without changes.

infrahub_sdk/ctl/cli_commands.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from infrahub_sdk import __version__ as sdk_version
1717
from infrahub_sdk import protocols as sdk_protocols
1818
from infrahub_sdk.async_typer import AsyncTyper
19-
from infrahub_sdk.client import InfrahubClient
19+
from infrahub_sdk.client import InfrahubClient, InfrahubClientSync
2020
from infrahub_sdk.ctl import config
2121
from infrahub_sdk.ctl.branch import app as branch_app
2222
from infrahub_sdk.ctl.check import run as run_check
@@ -196,7 +196,7 @@ def render_jinja2_template(template_path: Path, variables: dict[str, str], data:
196196

197197
def _run_transform(
198198
query_name: str,
199-
client: InfrahubClient,
199+
client: InfrahubClient | InfrahubClientSync,
200200
variables: dict[str, Any],
201201
transform_func: Callable,
202202
branch: str,
@@ -208,7 +208,7 @@ def _run_transform(
208208
209209
Args:
210210
query_name: Name of the query to load.
211-
client: InfrahubClient object used to execute a graphql query against the infrahub API
211+
client: client object used to execute a graphql query against the infrahub API
212212
variables: Dictionary of variables used for graphql query
213213
transform_func: A function used to transform the return from the graphql query into a different form
214214
branch: Name of the *infrahub* branch that should be queried for data
@@ -217,9 +217,14 @@ def _run_transform(
217217
"""
218218
branch = get_branch(branch)
219219
query_str = repository_config.get_query(name=query_name).load_query()
220+
query_dict = dict(query=query_str, variables=variables, branch_name=branch)
220221

221222
try:
222-
response = client.execute_graphql(query=query_str, variables=variables, branch_name=branch)
223+
if isinstance(client, InfrahubClient):
224+
response = asyncio.run(client.execute_graphql(**query_dict))
225+
else:
226+
response = client.execute_graphql(**query_dict)
227+
223228
if debug:
224229
message = ("-" * 40, f"Response for GraphQL Query {query_name}", response, "-" * 40)
225230
console.print("\n".join(message))
@@ -338,12 +343,9 @@ def transform(
338343

339344
transform_config = matched[0]
340345

341-
# Get Infrahub Client
342-
client = initialize_client_sync()
343-
344346
# Get python transform class instance
345347
try:
346-
transform = get_transform_class_instance(transform_config=transform_config, branch=branch, client=client)
348+
transform = get_transform_class_instance(transform_config=transform_config, branch=branch)
347349
except InfrahubTransformNotFoundError as exc:
348350
console.print(f"Unable to load {transform_name} from python_transforms")
349351
raise typer.Exit(1) from exc

infrahub_sdk/transforms.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import asyncio
44
import importlib
55
import os
6+
import warnings
67
from abc import abstractmethod
8+
from functools import cached_property
79
from typing import TYPE_CHECKING, Any, Optional
810

911
from git import Repo
@@ -41,9 +43,18 @@ def __init__(self, branch: str = "", root_directory: str = "", server_url: str =
4143
if not self.query:
4244
raise ValueError("A query must be provided")
4345

46+
@cached_property
47+
def client(self):
48+
return InfrahubClient(address=self.server_url)
49+
4450
@classmethod
4551
async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubTransform:
4652
"""Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically."""
53+
warnings.warn(
54+
"InfrahubClient.init has been deprecated and will be removed in Infrahub SDK 0.14.0 or the next major version",
55+
DeprecationWarning,
56+
stacklevel=1,
57+
)
4758

4859
item = cls(*args, **kwargs)
4960

@@ -94,7 +105,7 @@ async def run(self, data: Optional[dict] = None) -> Any:
94105
def get_transform_class_instance(
95106
transform_config: InfrahubPythonTransformConfig,
96107
search_path: Optional[Path] = None,
97-
client: Optional[InfrahubClient] = None,
108+
# client: Optional[InfrahubClient] = None,
98109
branch: str = "",
99110
) -> InfrahubTransform:
100111
"""Gets an uninstantiated InfrahubTransform class.
@@ -120,7 +131,7 @@ def get_transform_class_instance(
120131
transform_class = getattr(module, transform_config.class_name)
121132

122133
# Create an instance of the class
123-
transform_instance = asyncio.run(transform_class.init(client=client, branch=branch))
134+
transform_instance = transform_class(branch=branch)
124135

125136
except (FileNotFoundError, AttributeError) as exc:
126137
raise InfrahubTransformNotFoundError(name=transform_config.name) from exc

0 commit comments

Comments
 (0)