@@ -163,7 +163,7 @@ async def run(
163163 if not hasattr (module , method ):
164164 raise typer .Abort (f"Unable to Load the method { method } in the Python script at { script } " )
165165
166- client = await initialize_client (
166+ client = initialize_client (
167167 branch = branch , timeout = timeout , max_concurrent_execution = concurrent , identifier = module_name
168168 )
169169 func = getattr (module , method )
@@ -201,19 +201,35 @@ def render_jinja2_template(template_path: Path, variables: dict[str, str], data:
201201
202202
203203def _run_transform (
204- query : str ,
204+ query_name : str ,
205205 variables : dict [str , Any ],
206- transformer : Callable ,
206+ transform_func : Callable ,
207207 branch : str ,
208208 debug : bool ,
209209 repository_config : InfrahubRepositoryConfig ,
210210):
211+ """
212+ Query GraphQL for the required data then run a transform on that data.
213+
214+ Args:
215+ query_name: Name of the query to load (e.g. tags_query)
216+ variables: Dictionary of variables used for graphql query
217+ transform_func: The function responsible for transforming data received from graphql
218+ branch: Name of the *infrahub* branch that should be queried for data
219+ debug: Prints debug info to the command line
220+ repository_config: Repository config object. This is used to load the graphql query from the repository.
221+ """
211222 branch = get_branch (branch )
212223
213224 try :
214225 response = execute_graphql_query (
215- query = query , variables_dict = variables , branch = branch , debug = debug , repository_config = repository_config
226+ query = query_name , variables_dict = variables , branch = branch , debug = debug , repository_config = repository_config
216227 )
228+
229+ # TODO: response is a dict and can't be printed to the console in this way.
230+ # if debug:
231+ # message = ("-" * 40, f"Response for GraphQL Query {query_name}", response, "-" * 40)
232+ # console.print("\n".join(message))
217233 except QueryNotFoundError as exc :
218234 console .print (f"[red]Unable to find query : { exc } " )
219235 raise typer .Exit (1 ) from exc
@@ -228,10 +244,10 @@ def _run_transform(
228244 console .print ("[yellow] you can specify a different branch with --branch" )
229245 raise typer .Abort ()
230246
231- if asyncio .iscoroutinefunction (transformer . func ):
232- output = asyncio .run (transformer (response ))
247+ if asyncio .iscoroutinefunction (transform_func ):
248+ output = asyncio .run (transform_func (response ))
233249 else :
234- output = transformer (response )
250+ output = transform_func (response )
235251 return output
236252
237253
@@ -257,23 +273,28 @@ def render(
257273 list_jinja2_transforms (config = repository_config )
258274 return
259275
276+ # Load transform config
260277 try :
261278 transform_config = repository_config .get_jinja2_transform (name = transform_name )
262279 except KeyError as exc :
263280 console .print (f'[red]Unable to find "{ transform_name } " in { config .INFRAHUB_REPO_CONFIG_FILE } ' )
264281 list_jinja2_transforms (config = repository_config )
265282 raise typer .Exit (1 ) from exc
266283
267- transformer = functools .partial (render_jinja2_template , transform_config .template_path , variables_dict )
284+ # Construct transform function used to transform data returned from the API
285+ transform_func = functools .partial (render_jinja2_template , transform_config .template_path , variables_dict )
286+
287+ # Query GQL and run the transform
268288 result = _run_transform (
269- query = transform_config .query ,
289+ query_name = transform_config .query ,
270290 variables = variables_dict ,
271- transformer = transformer ,
291+ transform_func = transform_func ,
272292 branch = branch ,
273293 debug = debug ,
274294 repository_config = repository_config ,
275295 )
276296
297+ # Output data
277298 if out :
278299 write_to_file (Path (out ), result )
279300 else :
@@ -302,31 +323,41 @@ def transform(
302323 list_transforms (config = repository_config )
303324 return
304325
305- matched = [transform for transform in repository_config .python_transforms if transform .name == transform_name ] # pylint: disable=not-an-iterable
306-
307- if not matched :
326+ # Load transform config
327+ try :
328+ matched = [transform for transform in repository_config .python_transforms if transform .name == transform_name ] # pylint: disable=not-an-iterable
329+ if not matched :
330+ raise ValueError (f"{ transform_name } does not exist" )
331+ except ValueError as exc :
308332 console .print (f"[red]Unable to find requested transform: { transform_name } " )
309333 list_transforms (config = repository_config )
310- return
334+ raise typer . Exit ( 1 ) from exc
311335
312336 transform_config = matched [0 ]
313337
338+ # Get client
339+ client = initialize_client ()
340+
341+ # Get python transform class instance
314342 try :
315- transform_instance = get_transform_class_instance (transform_config = transform_config )
343+ transform = get_transform_class_instance (
344+ transform_config = transform_config ,
345+ branch = branch ,
346+ client = client ,
347+ )
316348 except InfrahubTransformNotFoundError as exc :
317349 console .print (f"Unable to load { transform_name } from python_transforms" )
318350 raise typer .Exit (1 ) from exc
319351
320- transformer = functools .partial (transform_instance .transform )
321- result = _run_transform (
322- query = transform_instance .query ,
323- variables = variables_dict ,
324- transformer = transformer ,
325- branch = branch ,
326- debug = debug ,
327- repository_config = repository_config ,
352+ # Get data
353+ query_str = repository_config .get_query (name = transform .query ).load_query ()
354+ data = asyncio .run (
355+ transform .client .execute_graphql (query = query_str , variables = variables_dict , branch_name = transform .branch_name )
328356 )
329357
358+ # Run Transform
359+ result = asyncio .run (transform .run (data = data ))
360+
330361 json_string = ujson .dumps (result , indent = 2 , sort_keys = True )
331362 if out :
332363 write_to_file (Path (out ), json_string )
0 commit comments