Skip to content

Commit 13bd900

Browse files
committed
Merge stable into develop and fix merge conflict
2 parents b889934 + e91bd2c commit 13bd900

File tree

18 files changed

+338
-47
lines changed

18 files changed

+338
-47
lines changed

changelog/8.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make `infrahubctl transform` command set up the InfrahubTransform class with an InfrahubClient instance

infrahub_sdk/ctl/branch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def list_branch(_: str = CONFIG_PARAM) -> None:
3434

3535
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
3636

37-
client = await initialize_client()
37+
client = initialize_client()
3838
branches = await client.branch.all()
3939

4040
table = Table(title="List of all branches")
@@ -91,7 +91,7 @@ async def create(
9191

9292
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
9393

94-
client = await initialize_client()
94+
client = initialize_client()
9595
branch = await client.branch.create(branch_name=branch_name, description=description, sync_with_git=sync_with_git)
9696
console.print(f"Branch {branch_name!r} created successfully ({branch.id}).")
9797

@@ -103,7 +103,7 @@ async def delete(branch_name: str, _: str = CONFIG_PARAM) -> None:
103103

104104
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
105105

106-
client = await initialize_client()
106+
client = initialize_client()
107107
await client.branch.delete(branch_name=branch_name)
108108
console.print(f"Branch '{branch_name}' deleted successfully.")
109109

@@ -115,7 +115,7 @@ async def rebase(branch_name: str, _: str = CONFIG_PARAM) -> None:
115115

116116
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
117117

118-
client = await initialize_client()
118+
client = initialize_client()
119119
await client.branch.rebase(branch_name=branch_name)
120120
console.print(f"Branch '{branch_name}' rebased successfully.")
121121

@@ -127,7 +127,7 @@ async def merge(branch_name: str, _: str = CONFIG_PARAM) -> None:
127127

128128
logging.getLogger("infrahub_sdk").setLevel(logging.CRITICAL)
129129

130-
client = await initialize_client()
130+
client = initialize_client()
131131
await client.branch.merge(branch_name=branch_name)
132132
console.print(f"Branch '{branch_name}' merged successfully.")
133133

@@ -137,6 +137,6 @@ async def merge(branch_name: str, _: str = CONFIG_PARAM) -> None:
137137
async def validate(branch_name: str, _: str = CONFIG_PARAM) -> None:
138138
"""Validate if a branch has some conflict and is passing all the tests (NOT IMPLEMENTED YET)."""
139139

140-
client = await initialize_client()
140+
client = initialize_client()
141141
await client.branch.validate(branch_name=branch_name)
142142
console.print(f"Branch '{branch_name}' is valid.")

infrahub_sdk/ctl/cli_commands.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ async def run(
163163
if not hasattr(module, method):
164164
raise typer.Abort(f"Unable to Load the method {method} in the Python script at {script}")
165165

166-
client = await initialize_client(
166+
client = initialize_client(
167167
branch=branch, timeout=timeout, max_concurrent_execution=concurrent, identifier=module_name
168168
)
169169
func = getattr(module, method)
@@ -201,19 +201,35 @@ def render_jinja2_template(template_path: Path, variables: dict[str, str], data:
201201

202202

203203
def _run_transform(
204-
query: str,
204+
query_name: str,
205205
variables: dict[str, Any],
206-
transformer: Callable,
206+
transform_func: Callable,
207207
branch: str,
208208
debug: bool,
209209
repository_config: InfrahubRepositoryConfig,
210210
):
211+
"""
212+
Query GraphQL for the required data then run a transform on that data.
213+
214+
Args:
215+
query_name: Name of the query to load (e.g. tags_query)
216+
variables: Dictionary of variables used for graphql query
217+
transform_func: The function responsible for transforming data received from graphql
218+
branch: Name of the *infrahub* branch that should be queried for data
219+
debug: Prints debug info to the command line
220+
repository_config: Repository config object. This is used to load the graphql query from the repository.
221+
"""
211222
branch = get_branch(branch)
212223

213224
try:
214225
response = execute_graphql_query(
215-
query=query, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
226+
query=query_name, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
216227
)
228+
229+
# TODO: response is a dict and can't be printed to the console in this way.
230+
# if debug:
231+
# message = ("-" * 40, f"Response for GraphQL Query {query_name}", response, "-" * 40)
232+
# console.print("\n".join(message))
217233
except QueryNotFoundError as exc:
218234
console.print(f"[red]Unable to find query : {exc}")
219235
raise typer.Exit(1) from exc
@@ -228,10 +244,10 @@ def _run_transform(
228244
console.print("[yellow] you can specify a different branch with --branch")
229245
raise typer.Abort()
230246

231-
if asyncio.iscoroutinefunction(transformer.func):
232-
output = asyncio.run(transformer(response))
247+
if asyncio.iscoroutinefunction(transform_func):
248+
output = asyncio.run(transform_func(response))
233249
else:
234-
output = transformer(response)
250+
output = transform_func(response)
235251
return output
236252

237253

@@ -257,23 +273,28 @@ def render(
257273
list_jinja2_transforms(config=repository_config)
258274
return
259275

276+
# Load transform config
260277
try:
261278
transform_config = repository_config.get_jinja2_transform(name=transform_name)
262279
except KeyError as exc:
263280
console.print(f'[red]Unable to find "{transform_name}" in {config.INFRAHUB_REPO_CONFIG_FILE}')
264281
list_jinja2_transforms(config=repository_config)
265282
raise typer.Exit(1) from exc
266283

267-
transformer = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)
284+
# Construct transform function used to transform data returned from the API
285+
transform_func = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict)
286+
287+
# Query GQL and run the transform
268288
result = _run_transform(
269-
query=transform_config.query,
289+
query_name=transform_config.query,
270290
variables=variables_dict,
271-
transformer=transformer,
291+
transform_func=transform_func,
272292
branch=branch,
273293
debug=debug,
274294
repository_config=repository_config,
275295
)
276296

297+
# Output data
277298
if out:
278299
write_to_file(Path(out), result)
279300
else:
@@ -302,31 +323,41 @@ def transform(
302323
list_transforms(config=repository_config)
303324
return
304325

305-
matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable
306-
307-
if not matched:
326+
# Load transform config
327+
try:
328+
matched = [transform for transform in repository_config.python_transforms if transform.name == transform_name] # pylint: disable=not-an-iterable
329+
if not matched:
330+
raise ValueError(f"{transform_name} does not exist")
331+
except ValueError as exc:
308332
console.print(f"[red]Unable to find requested transform: {transform_name}")
309333
list_transforms(config=repository_config)
310-
return
334+
raise typer.Exit(1) from exc
311335

312336
transform_config = matched[0]
313337

338+
# Get client
339+
client = initialize_client()
340+
341+
# Get python transform class instance
314342
try:
315-
transform_instance = get_transform_class_instance(transform_config=transform_config)
343+
transform = get_transform_class_instance(
344+
transform_config=transform_config,
345+
branch=branch,
346+
client=client,
347+
)
316348
except InfrahubTransformNotFoundError as exc:
317349
console.print(f"Unable to load {transform_name} from python_transforms")
318350
raise typer.Exit(1) from exc
319351

320-
transformer = functools.partial(transform_instance.transform)
321-
result = _run_transform(
322-
query=transform_instance.query,
323-
variables=variables_dict,
324-
transformer=transformer,
325-
branch=branch,
326-
debug=debug,
327-
repository_config=repository_config,
352+
# Get data
353+
query_str = repository_config.get_query(name=transform.query).load_query()
354+
data = asyncio.run(
355+
transform.client.execute_graphql(query=query_str, variables=variables_dict, branch_name=transform.branch_name)
328356
)
329357

358+
# Run Transform
359+
result = asyncio.run(transform.run(data=data))
360+
330361
json_string = ujson.dumps(result, indent=2, sort_keys=True)
331362
if out:
332363
write_to_file(Path(out), json_string)

infrahub_sdk/ctl/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from infrahub_sdk.ctl import config
66

77

8-
async def initialize_client(
8+
def initialize_client(
99
branch: Optional[str] = None,
1010
identifier: Optional[str] = None,
1111
timeout: Optional[int] = None,

infrahub_sdk/ctl/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ async def run(
4343
if param_key:
4444
identifier = param_key[0]
4545

46-
client = await initialize_client()
46+
client = initialize_client()
4747
if variables_dict:
4848
data = execute_graphql_query(
4949
query=generator_config.query,

infrahub_sdk/ctl/repository.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def add(
8888
},
8989
}
9090

91-
client = await initialize_client()
91+
client = initialize_client()
9292

9393
if username:
9494
credential = await client.create(kind="CorePasswordCredential", name=name, username=username, password=password)

infrahub_sdk/ctl/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ async def load(
115115

116116
schemas_data = load_yamlfile_from_disk_and_exit(paths=schemas, file_type=SchemaFile, console=console)
117117
schema_definition = "schema" if len(schemas_data) == 1 else "schemas"
118-
client = await initialize_client()
118+
client = initialize_client()
119119
validate_schema_content_and_exit(client=client, schemas=schemas_data)
120120

121121
start_time = time.time()
@@ -164,7 +164,7 @@ async def check(
164164
init_logging(debug=debug)
165165

166166
schemas_data = load_yamlfile_from_disk_and_exit(paths=schemas, file_type=SchemaFile, console=console)
167-
client = await initialize_client()
167+
client = initialize_client()
168168
validate_schema_content_and_exit(client=client, schemas=schemas_data)
169169

170170
success, response = await client.schema.check(schemas=[item.content for item in schemas_data], branch=branch)

infrahub_sdk/ctl/validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async def validate_schema(schema: Path, _: str = CONFIG_PARAM) -> None:
4040
console.print("[red]Invalid JSON file")
4141
raise typer.Exit(1) from exc
4242

43-
client = await initialize_client()
43+
client = initialize_client()
4444

4545
try:
4646
client.schema.validate(schema_data)

infrahub_sdk/transforms.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import importlib
55
import os
6+
import warnings
67
from abc import abstractmethod
78
from typing import TYPE_CHECKING, Any, Optional
89

@@ -25,33 +26,47 @@ class InfrahubTransform:
2526
query: str
2627
timeout: int = 10
2728

28-
def __init__(self, branch: str = "", root_directory: str = "", server_url: str = ""):
29+
def __init__(
30+
self,
31+
branch: str = "",
32+
root_directory: str = "",
33+
server_url: str = "",
34+
client: Optional[InfrahubClient] = None,
35+
):
2936
self.git: Repo
3037

3138
self.branch = branch
32-
3339
self.server_url = server_url or os.environ.get("INFRAHUB_URL", "http://127.0.0.1:8000")
3440
self.root_directory = root_directory or os.getcwd()
3541

36-
self.client: InfrahubClient
42+
self._client = client
3743

3844
if not self.name:
3945
self.name = self.__class__.__name__
4046

4147
if not self.query:
4248
raise ValueError("A query must be provided")
4349

50+
@property
51+
def client(self) -> InfrahubClient:
52+
if not self._client:
53+
self._client = InfrahubClient(address=self.server_url)
54+
55+
return self._client
56+
4457
@classmethod
4558
async def init(cls, client: Optional[InfrahubClient] = None, *args: Any, **kwargs: Any) -> InfrahubTransform:
4659
"""Async init method, If an existing InfrahubClient client hasn't been provided, one will be created automatically."""
60+
warnings.warn(
61+
f"{cls.__class__.__name__}.init has been deprecated and will be removed in Infrahub SDK 0.15.0 or the next major version",
62+
DeprecationWarning,
63+
stacklevel=1,
64+
)
65+
if client:
66+
kwargs["client"] = client
4767

4868
item = cls(*args, **kwargs)
4969

50-
if client:
51-
item.client = client
52-
else:
53-
item.client = InfrahubClient(address=item.server_url)
54-
5570
return item
5671

5772
@property
@@ -61,7 +76,7 @@ def branch_name(self) -> str:
6176
if self.branch:
6277
return self.branch
6378

64-
if not self.git:
79+
if not hasattr(self, "git") or not self.git:
6580
self.git = Repo(self.root_directory)
6681

6782
self.branch = str(self.git.active_branch)
@@ -79,10 +94,18 @@ async def collect_data(self) -> dict:
7994

8095
async def run(self, data: Optional[dict] = None) -> Any:
8196
"""Execute the transformation after collecting the data from the GraphQL query.
82-
The result of the check is determined based on the presence or not of ERROR log messages."""
97+
98+
The result of the check is determined based on the presence or not of ERROR log messages.
99+
100+
Args:
101+
data: The data on which to run the transform. Data will be queried from the API if not provided
102+
103+
Returns: Transformed data
104+
"""
83105

84106
if not data:
85107
data = await self.collect_data()
108+
86109
unpacked = data.get("data") or data
87110

88111
if asyncio.iscoroutinefunction(self.transform):
@@ -92,8 +115,20 @@ async def run(self, data: Optional[dict] = None) -> Any:
92115

93116

94117
def get_transform_class_instance(
95-
transform_config: InfrahubPythonTransformConfig, search_path: Optional[Path] = None
118+
transform_config: InfrahubPythonTransformConfig,
119+
search_path: Optional[Path] = None,
120+
branch: str = "",
121+
client: Optional[InfrahubClient] = None,
96122
) -> InfrahubTransform:
123+
"""Gets an instance of the InfrahubTransform class.
124+
125+
Args:
126+
transform_config: A config object with information required to find and load the transform.
127+
search_path: The path in which to search for a python file containing the transform. The current directory is
128+
assumed if not speicifed.
129+
branch: Infrahub branch which will be targeted in graphql query used to acquire data for transformation.
130+
client: InfrahubClient used to interact with infrahub API.
131+
"""
97132
if transform_config.file_path.is_absolute() or search_path is None:
98133
search_location = transform_config.file_path
99134
else:
@@ -108,7 +143,8 @@ def get_transform_class_instance(
108143
transform_class = getattr(module, transform_config.class_name)
109144

110145
# Create an instance of the class
111-
transform_instance = transform_class()
146+
transform_instance = transform_class(branch=branch, client=client)
147+
112148
except (FileNotFoundError, AttributeError) as exc:
113149
raise InfrahubTransformNotFoundError(name=transform_config.name) from exc
114150

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
---
2+
python_transforms:
3+
- name: tags_transform
4+
class_name: TagsTransform
5+
file_path: "tags_transform.py"
6+
7+
queries:
8+
- name: "tags_query"
9+
file_path: "tags_query.gql"
10+
11+
jinja2_transforms:
12+
- name: my-jinja2-transform # Unique name for your transform
13+
description: "short description" # (optional)
14+
query: "tags_query" # Name or ID of the GraphQLQuery
15+
template_path: "tags_tpl.j2" # Path to the main Jinja2 template

0 commit comments

Comments
 (0)