1616from infrahub_sdk import __version__ as sdk_version
1717from infrahub_sdk import protocols as sdk_protocols
1818from infrahub_sdk .async_typer import AsyncTyper
19- from infrahub_sdk .client import InfrahubClient
19+ from infrahub_sdk .client import InfrahubClient , InfrahubClientSync
2020from infrahub_sdk .ctl import config
2121from infrahub_sdk .ctl .branch import app as branch_app
2222from infrahub_sdk .ctl .check import run as run_check
@@ -196,7 +196,7 @@ def render_jinja2_template(template_path: Path, variables: dict[str, str], data:
196196
197197def _run_transform (
198198 query_name : str ,
199- client : InfrahubClient ,
199+ client : InfrahubClient | InfrahubClientSync ,
200200 variables : dict [str , Any ],
201201 transform_func : Callable ,
202202 branch : str ,
@@ -208,7 +208,7 @@ def _run_transform(
208208
209209 Args:
210210 query_name: Name of the query to load.
211- client: InfrahubClient object used to execute a graphql query against the infrahub API
211+ client: client object used to execute a graphql query against the infrahub API
212212 variables: Dictionary of variables used for graphql query
213213 transform_func: A function used to transform the return from the graphql query into a different form
214214 branch: Name of the *infrahub* branch that should be queried for data
@@ -217,9 +217,14 @@ def _run_transform(
217217 """
218218 branch = get_branch (branch )
219219 query_str = repository_config .get_query (name = query_name ).load_query ()
220+ query_dict = dict (query = query_str , variables = variables , branch_name = branch )
220221
221222 try :
222- response = client .execute_graphql (query = query_str , variables = variables , branch_name = branch )
223+ if isinstance (client , InfrahubClient ):
224+ response = asyncio .run (client .execute_graphql (** query_dict ))
225+ else :
226+ response = client .execute_graphql (** query_dict )
227+
223228 if debug :
224229 message = ("-" * 40 , f"Response for GraphQL Query { query_name } " , response , "-" * 40 )
225230 console .print ("\n " .join (message ))
@@ -338,12 +343,9 @@ def transform(
338343
339344 transform_config = matched [0 ]
340345
341- # Get Infrahub Client
342- client = initialize_client_sync ()
343-
344346 # Get python transform class instance
345347 try :
346- transform = get_transform_class_instance (transform_config = transform_config , branch = branch , client = client )
348+ transform = get_transform_class_instance (transform_config = transform_config , branch = branch )
347349 except InfrahubTransformNotFoundError as exc :
348350 console .print (f"Unable to load { transform_name } from python_transforms" )
349351 raise typer .Exit (1 ) from exc
0 commit comments