Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the chang

<!-- towncrier release notes start -->

## [1.12.0](https://github.com/opsmill/infrahub-sdk-python/tree/v1.12.0) - 2025-04-29

### Added

- Added the ability to convert the query response to InfrahubNode objects when using Python Transforms in the same way you can with Generators. ([#281](https://github.com/opsmill/infrahub-sdk-python/issues/281))
- Added a "branch" parameter to the client.clone() method to allow properly cloning a client that targets another branch.

## [1.11.1](https://github.com/opsmill/infrahub-sdk-python/tree/v1.11.1) - 2025-04-28

### Changed
Expand Down
8 changes: 4 additions & 4 deletions infrahub_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,9 +847,9 @@ async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]:
self.store.set(node=node)
return nodes

def clone(self) -> InfrahubClient:
def clone(self, branch: str | None = None) -> InfrahubClient:
"""Return a cloned version of the client using the same configuration"""
return InfrahubClient(config=self.config)
return InfrahubClient(config=self.config.clone(branch=branch))

async def execute_graphql(
self,
Expand Down Expand Up @@ -1591,9 +1591,9 @@ def delete(self, kind: str | type[SchemaTypeSync], id: str, branch: str | None =
node = InfrahubNodeSync(client=self, schema=schema, branch=branch, data={"id": id})
node.delete()

def clone(self) -> InfrahubClientSync:
def clone(self, branch: str | None = None) -> InfrahubClientSync:
"""Return a cloned version of the client using the same configuration"""
return InfrahubClientSync(config=self.config)
return InfrahubClientSync(config=self.config.clone(branch=branch))

def execute_graphql(
self,
Expand Down
17 changes: 17 additions & 0 deletions infrahub_sdk/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from copy import deepcopy
from typing import Any

from pydantic import Field, field_validator, model_validator
Expand Down Expand Up @@ -158,3 +159,19 @@ def set_custom_recorder(cls, values: dict[str, Any]) -> dict[str, Any]:
elif values.get("recorder") == RecorderType.JSON and "custom_recorder" not in values:
values["custom_recorder"] = JSONRecorder()
return values

def clone(self, branch: str | None = None) -> Config:
config: dict[str, Any] = {
"default_branch": branch or self.default_branch,
"recorder": self.recorder,
"custom_recorder": self.custom_recorder,
"requester": self.requester,
"sync_requester": self.sync_requester,
"log": self.log,
}
covered_keys = list(config.keys())
for field in Config.model_fields.keys():
if field not in covered_keys:
config[field] = deepcopy(getattr(self, field))

return Config(**config)
8 changes: 7 additions & 1 deletion infrahub_sdk/ctl/cli_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from ..ctl.validate import app as validate_app
from ..exceptions import GraphQLError, ModuleImportError
from ..node import InfrahubNode
from ..protocols_generator.generator import CodeGenerator
from ..schema import MainSchemaTypesAll, SchemaRoot
from ..template import Jinja2Template
Expand Down Expand Up @@ -330,7 +331,12 @@ def transform(
console.print(f"[red]{exc.message}")
raise typer.Exit(1) from exc

transform = transform_class(client=client, branch=branch)
transform = transform_class(
client=client,
branch=branch,
infrahub_node=InfrahubNode,
convert_query_response=transform_config.convert_query_response,
)
# Get data
query_str = repository_config.get_query(name=transform.query).load_query()
data = asyncio.run(
Expand Down
4 changes: 2 additions & 2 deletions infrahub_sdk/ctl/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def run(
generator = generator_class(
query=generator_config.query,
client=client,
branch=branch,
branch=branch or "",
params=variables_dict,
convert_query_response=generator_config.convert_query_response,
infrahub_node=InfrahubNode,
Expand Down Expand Up @@ -91,7 +91,7 @@ async def run(
generator = generator_class(
query=generator_config.query,
client=client,
branch=branch,
branch=branch or "",
params=params,
convert_query_response=generator_config.convert_query_response,
infrahub_node=InfrahubNode,
Expand Down
78 changes: 12 additions & 66 deletions infrahub_sdk/generator.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,27 @@
from __future__ import annotations

import logging
import os
from abc import abstractmethod
from typing import TYPE_CHECKING

from infrahub_sdk.repository import GitRepoManager

from .exceptions import UninitializedError
from .operation import InfrahubOperation

if TYPE_CHECKING:
from .client import InfrahubClient
from .context import RequestContext
from .node import InfrahubNode
from .store import NodeStore


class InfrahubGenerator:
class InfrahubGenerator(InfrahubOperation):
"""Infrahub Generator class"""

def __init__(
self,
query: str,
client: InfrahubClient,
infrahub_node: type[InfrahubNode],
branch: str | None = None,
branch: str = "",
root_directory: str = "",
generator_instance: str = "",
params: dict | None = None,
Expand All @@ -33,37 +30,21 @@ def __init__(
request_context: RequestContext | None = None,
) -> None:
self.query = query
self.branch = branch
self.git: GitRepoManager | None = None

super().__init__(
client=client,
infrahub_node=infrahub_node,
convert_query_response=convert_query_response,
branch=branch,
root_directory=root_directory,
)

self.params = params or {}
self.root_directory = root_directory or os.getcwd()
self.generator_instance = generator_instance
self._init_client = client.clone()
self._init_client.config.default_branch = self._init_client.default_branch = self.branch_name
self._init_client.store._default_branch = self.branch_name
self._client: InfrahubClient | None = None
self._nodes: list[InfrahubNode] = []
self._related_nodes: list[InfrahubNode] = []
self.infrahub_node = infrahub_node
self.convert_query_response = convert_query_response
self.logger = logger if logger else logging.getLogger("infrahub.tasks")
self.request_context = request_context

@property
def store(self) -> NodeStore:
"""The store will be populated with nodes based on the query during the collection of data if activated"""
return self._init_client.store

@property
def nodes(self) -> list[InfrahubNode]:
"""Returns nodes collected and parsed during the data collection process if this feature is enables"""
return self._nodes

@property
def related_nodes(self) -> list[InfrahubNode]:
"""Returns nodes collected and parsed during the data collection process if this feature is enables"""
return self._related_nodes

@property
def subscribers(self) -> list[str] | None:
if self.generator_instance:
Expand All @@ -80,20 +61,6 @@ def client(self) -> InfrahubClient:
def client(self, value: InfrahubClient) -> None:
self._client = value

@property
def branch_name(self) -> str:
"""Return the name of the current git branch."""

if self.branch:
return self.branch

if not self.git:
self.git = GitRepoManager(self.root_directory)

self.branch = str(self.git.active_branch)

return self.branch

async def collect_data(self) -> dict:
"""Query the result of the GraphQL Query defined in self.query and return the result"""

Expand All @@ -119,27 +86,6 @@ async def run(self, identifier: str, data: dict | None = None) -> None:
) as self.client:
await self.generate(data=unpacked)

async def process_nodes(self, data: dict) -> None:
if not self.convert_query_response:
return

await self._init_client.schema.all(branch=self.branch_name)

for kind in data:
if kind in self._init_client.schema.cache[self.branch_name].nodes.keys():
for result in data[kind].get("edges", []):
node = await self.infrahub_node.from_graphql(
client=self._init_client, branch=self.branch_name, data=result
)
self._nodes.append(node)
await node._process_relationships(
node_data=result, branch=self.branch_name, related_nodes=self._related_nodes
)

for node in self._nodes + self._related_nodes:
if node.id:
self._init_client.store.set(node=node)

@abstractmethod
async def generate(self, data: dict) -> None:
"""Code to run the generator
Expand Down
80 changes: 80 additions & 0 deletions infrahub_sdk/operation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING

from .repository import GitRepoManager

if TYPE_CHECKING:
from . import InfrahubClient
from .node import InfrahubNode
from .store import NodeStore


class InfrahubOperation:
def __init__(
self,
client: InfrahubClient,
infrahub_node: type[InfrahubNode],
convert_query_response: bool,
branch: str,
root_directory: str,
):
self.branch = branch
self.convert_query_response = convert_query_response
self.root_directory = root_directory or os.getcwd()
self.infrahub_node = infrahub_node
self._nodes: list[InfrahubNode] = []
self._related_nodes: list[InfrahubNode] = []
self._init_client = client.clone(branch=self.branch_name)
self.git: GitRepoManager | None = None

@property
def branch_name(self) -> str:
"""Return the name of the current git branch."""

if self.branch:
return self.branch

if not hasattr(self, "git") or not self.git:
self.git = GitRepoManager(self.root_directory)

self.branch = str(self.git.active_branch)

return self.branch

@property
def store(self) -> NodeStore:
"""The store will be populated with nodes based on the query during the collection of data if activated"""
return self._init_client.store

@property
def nodes(self) -> list[InfrahubNode]:
"""Returns nodes collected and parsed during the data collection process if this feature is enabled"""
return self._nodes

Check warning on line 54 in infrahub_sdk/operation.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/operation.py#L54

Added line #L54 was not covered by tests

@property
def related_nodes(self) -> list[InfrahubNode]:
"""Returns nodes collected and parsed during the data collection process if this feature is enabled"""
return self._related_nodes

Check warning on line 59 in infrahub_sdk/operation.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/operation.py#L59

Added line #L59 was not covered by tests

async def process_nodes(self, data: dict) -> None:
if not self.convert_query_response:
return

await self._init_client.schema.all(branch=self.branch_name)

for kind in data:
if kind in self._init_client.schema.cache[self.branch_name].nodes.keys():
for result in data[kind].get("edges", []):
node = await self.infrahub_node.from_graphql(
client=self._init_client, branch=self.branch_name, data=result
)
self._nodes.append(node)
await node._process_relationships(
node_data=result, branch=self.branch_name, related_nodes=self._related_nodes
)

for node in self._nodes + self._related_nodes:
if node.id:
self._init_client.store.set(node=node)
12 changes: 12 additions & 0 deletions infrahub_sdk/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class CoreMenu(CoreNode):
children: RelationshipManager


class CoreObjectComponentTemplate(CoreNode):
template_name: String


class CoreObjectTemplate(CoreNode):
template_name: String

Expand Down Expand Up @@ -205,6 +209,7 @@ class CoreWebhook(CoreNode):
name: String
event_type: Enum
branch_scope: Dropdown
node_kind: StringOptional
description: StringOptional
url: URL
validate_certificates: BooleanOptional
Expand Down Expand Up @@ -479,6 +484,7 @@ class CoreTransformJinja2(CoreTransformation):
class CoreTransformPython(CoreTransformation):
file_path: String
class_name: String
convert_query_response: BooleanOptional


class CoreUserValidator(CoreValidator):
Expand Down Expand Up @@ -625,6 +631,10 @@ class CoreMenuSync(CoreNodeSync):
children: RelationshipManagerSync


class CoreObjectComponentTemplateSync(CoreNodeSync):
template_name: String


class CoreObjectTemplateSync(CoreNodeSync):
template_name: String

Expand Down Expand Up @@ -676,6 +686,7 @@ class CoreWebhookSync(CoreNodeSync):
name: String
event_type: Enum
branch_scope: Dropdown
node_kind: StringOptional
description: StringOptional
url: URL
validate_certificates: BooleanOptional
Expand Down Expand Up @@ -950,6 +961,7 @@ class CoreTransformJinja2Sync(CoreTransformationSync):
class CoreTransformPythonSync(CoreTransformationSync):
file_path: String
class_name: String
convert_query_response: BooleanOptional


class CoreUserValidatorSync(CoreValidatorSync):
Expand Down
3 changes: 3 additions & 0 deletions infrahub_sdk/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
def default(cls) -> NoRecorder:
return cls()

def __eq__(self, other: object) -> bool:
return isinstance(other, NoRecorder)

Check warning on line 35 in infrahub_sdk/recorder.py

View check run for this annotation

Codecov / codecov/patch

infrahub_sdk/recorder.py#L35

Added line #L35 was not covered by tests


class JSONRecorder(BaseSettings):
model_config = SettingsConfigDict(env_prefix="INFRAHUB_JSON_RECORDER_")
Expand Down
4 changes: 4 additions & 0 deletions infrahub_sdk/schema/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ class InfrahubPythonTransformConfig(InfrahubRepositoryConfigElement):
name: str = Field(..., description="The name of the Transform")
file_path: Path = Field(..., description="The file within the repository with the transform code.")
class_name: str = Field(default="Transform", description="The name of the transform class to run.")
convert_query_response: bool = Field(
default=False,
description="Decide if the transform should convert the result of the GraphQL query to SDK InfrahubNode objects.",
)

def load_class(self, import_root: str | None = None, relative_path: str | None = None) -> type[InfrahubTransform]:
module = import_module(module_path=self.file_path, import_root=import_root, relative_path=relative_path)
Expand Down
Loading