diff --git a/.vale/styles/spelling-exceptions.txt b/.vale/styles/spelling-exceptions.txt index d018438b..904fc251 100644 --- a/.vale/styles/spelling-exceptions.txt +++ b/.vale/styles/spelling-exceptions.txt @@ -85,6 +85,7 @@ namespace namespaces Nautobot Netbox +Netutils Newsfragment Nornir npm diff --git a/CHANGELOG.md b/CHANGELOG.md index b8e9608e..d4f8a7d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,23 @@ This project uses [*towncrier*](https://towncrier.readthedocs.io/) and the chang +## [1.10.0](https://github.com/opsmill/infrahub-sdk-python/tree/v1.10.0) - 2025-04-01 + +### Deprecated + +- The method `get_by_hfid` on the object Store has been deprecated, use `get(key=[hfid])` instead +- Using a Store without specifying a default branch is now deprecated and will be removed in a future version. + +### Added + +- All nodes generated by the SDK will now be assigned an `internal_id` (`_internal_id`). This ID has no significance outside of the SDK. +- Jinja2 templating has been refactored to allow for filters within Infrahub. Builtin filters as well as those from Netutils are available. +- The object store has been refactored to support more use cases in the future and it now properly support branches. + +### Fixed + +- Fix node processing, when using fragment with `prefetch_relationships`. ([#331](https://github.com/opsmill/infrahub-sdk-python/issues/331)) + ## [1.9.2](https://github.com/opsmill/infrahub-sdk-python/tree/v1.9.2) - 2025-03-26 ### Changed diff --git a/changelog/331.fixed.md b/changelog/331.fixed.md deleted file mode 100644 index e67d6182..00000000 --- a/changelog/331.fixed.md +++ /dev/null @@ -1 +0,0 @@ -Fix node processing, when using fragment with `prefetch_relationships`. \ No newline at end of file diff --git a/docs/_templates/sdk_template_reference.j2 b/docs/_templates/sdk_template_reference.j2 new file mode 100644 index 00000000..dcd59d78 --- /dev/null +++ b/docs/_templates/sdk_template_reference.j2 @@ -0,0 +1,27 @@ +--- +title: Python SDK Templating +--- +Filters can be used when defining [computed attributes](https://docs.infrahub.app/guides/computed-attributes) or [Jinja2 Transforms](https://docs.infrahub.app/guides/jinja2-transform) within Infrahub. + +## Builtin Jinja2 filters + +The following filters are those that are [shipped with Jinja2](https://jinja.palletsprojects.com/en/stable/templates/#list-of-builtin-filters) and enabled within Infrahub. The trusted column indicates if the filter is allowed for use with Infrahub's computed attributes when the server is configured in strict mode. + + +| Name | Trusted | +|----------|----------| +{% for filter in builtin %} +| {{ filter.name }} | {% if filter.trusted %}✅{% else %}❌{% endif %} | +{% endfor %} + + +## Netutils filters + +The following Jinja2 filters from Netutils are included within Infrahub. + +| Name | Trusted | +|----------|----------| +{% for filter in netutils %} +| {{ filter.name }} | {% if filter.trusted %}✅{% else %}❌{% endif %} | +{% endfor %} + diff --git a/docs/docs/python-sdk/introduction.mdx b/docs/docs/python-sdk/introduction.mdx index 9a9b3bb9..699eef8c 100644 --- a/docs/docs/python-sdk/introduction.mdx +++ b/docs/docs/python-sdk/introduction.mdx @@ -4,7 +4,7 @@ title: Python SDK The Infrahub Python SDK greatly simplifies how you can interact with Infrahub programmatically. -## Blog Posts +## Blog posts - [Querying Data in Infrahub via the Python SDK](https://www.opsmill.com/querying-data-in-infrahub-via-the-python-sdk/) diff --git a/docs/docs/python-sdk/reference/templating.mdx b/docs/docs/python-sdk/reference/templating.mdx new file mode 100644 index 00000000..62f1b8aa --- /dev/null +++ b/docs/docs/python-sdk/reference/templating.mdx @@ -0,0 +1,153 @@ +--- +title: Python SDK Templating +--- +Filters can be used when defining [computed attributes](https://docs.infrahub.app/guides/computed-attributes) or [Jinja2 Transforms](https://docs.infrahub.app/guides/jinja2-transform) within Infrahub. + +## Builtin Jinja2 filters + +The following filters are those that are [shipped with Jinja2](https://jinja.palletsprojects.com/en/stable/templates/#list-of-builtin-filters) and enabled within Infrahub. The trusted column indicates if the filter is allowed for use with Infrahub's computed attributes when the server is configured in strict mode. + + +| Name | Trusted | +|----------|----------| +| abs | ✅ | +| attr | ❌ | +| batch | ❌ | +| capitalize | ✅ | +| center | ✅ | +| count | ✅ | +| d | ✅ | +| default | ✅ | +| dictsort | ❌ | +| e | ✅ | +| escape | ✅ | +| filesizeformat | ✅ | +| first | ✅ | +| float | ✅ | +| forceescape | ✅ | +| format | ✅ | +| groupby | ❌ | +| indent | ✅ | +| int | ✅ | +| items | ❌ | +| join | ✅ | +| last | ✅ | +| length | ✅ | +| list | ✅ | +| lower | ✅ | +| map | ❌ | +| max | ✅ | +| min | ✅ | +| pprint | ❌ | +| random | ❌ | +| reject | ❌ | +| rejectattr | ❌ | +| replace | ✅ | +| reverse | ✅ | +| round | ✅ | +| safe | ❌ | +| select | ❌ | +| selectattr | ❌ | +| slice | ✅ | +| sort | ❌ | +| string | ✅ | +| striptags | ✅ | +| sum | ✅ | +| title | ✅ | +| tojson | ❌ | +| trim | ✅ | +| truncate | ✅ | +| unique | ❌ | +| upper | ✅ | +| urlencode | ✅ | +| urlize | ❌ | +| wordcount | ✅ | +| wordwrap | ✅ | +| xmlattr | ❌ | + + +## Netutils filters + +The following Jinja2 filters from Netutils are included within Infrahub. + +| Name | Trusted | +|----------|----------| +| abbreviated_interface_name | ✅ | +| abbreviated_interface_name_list | ✅ | +| asn_to_int | ✅ | +| bits_to_name | ✅ | +| bytes_to_name | ✅ | +| canonical_interface_name | ✅ | +| canonical_interface_name_list | ✅ | +| cidr_to_netmask | ✅ | +| cidr_to_netmaskv6 | ✅ | +| clean_config | ✅ | +| compare_version_loose | ✅ | +| compare_version_strict | ✅ | +| config_compliance | ✅ | +| config_section_not_parsed | ✅ | +| delimiter_change | ✅ | +| diff_network_config | ✅ | +| feature_compliance | ✅ | +| find_unordered_cfg_lines | ✅ | +| fqdn_to_ip | ❌ | +| get_all_host | ❌ | +| get_broadcast_address | ✅ | +| get_first_usable | ✅ | +| get_ips_sorted | ✅ | +| get_nist_urls | ✅ | +| get_nist_vendor_platform_urls | ✅ | +| get_oui | ✅ | +| get_peer_ip | ✅ | +| get_range_ips | ✅ | +| get_upgrade_path | ✅ | +| get_usable_range | ✅ | +| hash_data | ✅ | +| int_to_asdot | ✅ | +| interface_range_compress | ✅ | +| interface_range_expansion | ✅ | +| ip_addition | ✅ | +| ip_subtract | ✅ | +| ip_to_bin | ✅ | +| ip_to_hex | ✅ | +| ipaddress_address | ✅ | +| ipaddress_interface | ✅ | +| ipaddress_network | ✅ | +| is_classful | ✅ | +| is_fqdn_resolvable | ❌ | +| is_ip | ✅ | +| is_ip_range | ✅ | +| is_ip_within | ✅ | +| is_netmask | ✅ | +| is_network | ✅ | +| is_reversible_wildcardmask | ✅ | +| is_valid_mac | ✅ | +| longest_prefix_match | ✅ | +| mac_normalize | ✅ | +| mac_to_format | ✅ | +| mac_to_int | ✅ | +| mac_type | ✅ | +| name_to_bits | ✅ | +| name_to_bytes | ✅ | +| name_to_name | ✅ | +| netmask_to_cidr | ✅ | +| netmask_to_wildcardmask | ✅ | +| normalise_delimiter_caret_c | ✅ | +| paloalto_panos_brace_to_set | ✅ | +| paloalto_panos_clean_newlines | ✅ | +| regex_findall | ❌ | +| regex_match | ❌ | +| regex_search | ❌ | +| regex_split | ❌ | +| regex_sub | ❌ | +| sanitize_config | ✅ | +| section_config | ✅ | +| sort_interface_list | ✅ | +| split_interface | ✅ | +| uptime_seconds_to_string | ✅ | +| uptime_string_to_seconds | ✅ | +| version_metadata | ✅ | +| vlanconfig_to_list | ✅ | +| vlanlist_to_config | ✅ | +| wildcardmask_to_netmask | ✅ | + \ No newline at end of file diff --git a/docs/sidebars-python-sdk.ts b/docs/sidebars-python-sdk.ts index 8da5a81b..7cde4058 100644 --- a/docs/sidebars-python-sdk.ts +++ b/docs/sidebars-python-sdk.ts @@ -38,6 +38,7 @@ const sidebars: SidebarsConfig = { label: 'Reference', items: [ 'reference/config', + 'reference/templating', ], }, ], diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 1835ff00..fffa8164 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -281,7 +281,7 @@ def _initialize(self) -> None: self.schema = InfrahubSchema(self) self.branch = InfrahubBranchManager(self) self.object_store = ObjectStore(self) - self.store = NodeStore() + self.store = NodeStore(default_branch=self.default_branch) self.task = InfrahubTaskManager(self) self.concurrent_execution_limit = asyncio.Semaphore(self.max_concurrent_execution) self._request_method: AsyncRequester = self.config.requester or self._default_request_method @@ -840,11 +840,11 @@ async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]: if populate_store: for node in nodes: if node.id: - self.store.set(key=node.id, node=node) + self.store.set(node=node) related_nodes = list(set(related_nodes)) for node in related_nodes: if node.id: - self.store.set(key=node.id, node=node) + self.store.set(node=node) return nodes def clone(self) -> InfrahubClient: @@ -1529,7 +1529,7 @@ def _initialize(self) -> None: self.schema = InfrahubSchemaSync(self) self.branch = InfrahubBranchManagerSync(self) self.object_store = ObjectStoreSync(self) - self.store = NodeStoreSync() + self.store = NodeStoreSync(default_branch=self.default_branch) self.task = InfrahubTaskManagerSync(self) self._request_method: SyncRequester = self.config.sync_requester or self._default_request_method self.group_context = InfrahubGroupContextSync(self) @@ -1997,11 +1997,11 @@ def process_non_batch() -> tuple[list[InfrahubNodeSync], list[InfrahubNodeSync]] if populate_store: for node in nodes: if node.id: - self.store.set(key=node.id, node=node) + self.store.set(node=node) related_nodes = list(set(related_nodes)) for node in related_nodes: if node.id: - self.store.set(key=node.id, node=node) + self.store.set(node=node) return nodes @overload diff --git a/infrahub_sdk/ctl/cli_commands.py b/infrahub_sdk/ctl/cli_commands.py index 633ccd2c..0d9a850f 100644 --- a/infrahub_sdk/ctl/cli_commands.py +++ b/infrahub_sdk/ctl/cli_commands.py @@ -9,7 +9,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional -import jinja2 import typer import ujson from rich.console import Console @@ -18,7 +17,6 @@ from rich.panel import Panel from rich.pretty import Pretty from rich.table import Table -from rich.traceback import Traceback from .. import __version__ as sdk_version from ..async_typer import AsyncTyper @@ -31,7 +29,7 @@ from ..ctl.generator import run as run_generator from ..ctl.menu import app as menu_app from ..ctl.object import app as object_app -from ..ctl.render import list_jinja2_transforms +from ..ctl.render import list_jinja2_transforms, print_template_errors from ..ctl.repository import app as repository_app from ..ctl.repository import get_repository_config from ..ctl.schema import app as schema_app @@ -44,8 +42,9 @@ ) from ..ctl.validate import app as validate_app from ..exceptions import GraphQLError, ModuleImportError -from ..jinja2 import identify_faulty_jinja_code from ..schema import MainSchemaTypesAll, SchemaRoot +from ..template import Jinja2Template +from ..template.exceptions import JinjaTemplateError from ..utils import get_branch, write_to_file from ..yaml import SchemaFile from .exporter import dump @@ -168,43 +167,28 @@ async def run( raise typer.Abort(f"Unable to Load the method {method} in the Python script at {script}") client = initialize_client( - branch=branch, timeout=timeout, max_concurrent_execution=concurrent, identifier=module_name + branch=branch, + timeout=timeout, + max_concurrent_execution=concurrent, + identifier=module_name, ) func = getattr(module, method) await func(client=client, log=log, branch=branch, **variables_dict) -def render_jinja2_template(template_path: Path, variables: dict[str, str], data: dict[str, Any]) -> str: - if not template_path.is_file(): - console.print(f"[red]Unable to locate the template at {template_path}") - raise typer.Exit(1) - - templateLoader = jinja2.FileSystemLoader(searchpath=".") - templateEnv = jinja2.Environment(loader=templateLoader, trim_blocks=True, lstrip_blocks=True) - template = templateEnv.get_template(str(template_path)) - +async def render_jinja2_template(template_path: Path, variables: dict[str, Any], data: dict[str, Any]) -> str: + variables["data"] = data + jinja_template = Jinja2Template(template=Path(template_path), template_directory=Path()) try: - rendered_tpl = template.render(**variables, data=data) # type: ignore[arg-type] - except jinja2.TemplateSyntaxError as exc: - console.print("[red]Syntax Error detected on the template") - console.print(f"[yellow] {exc}") - raise typer.Exit(1) from exc - - except jinja2.UndefinedError as exc: - console.print("[red]An error occurred while rendering the jinja template") - traceback = Traceback(show_locals=False) - errors = identify_faulty_jinja_code(traceback=traceback) - for frame, syntax in errors: - console.print(f"[yellow]{frame.filename} on line {frame.lineno}\n") - console.print(syntax) - console.print("") - console.print(traceback.trace.stacks[0].exc_value) + rendered_tpl = await jinja_template.render(variables=variables) + except JinjaTemplateError as exc: + print_template_errors(error=exc, console=console) raise typer.Exit(1) from exc return rendered_tpl -def _run_transform( +async def _run_transform( query_name: str, variables: dict[str, Any], transform_func: Callable, @@ -227,7 +211,11 @@ def _run_transform( try: response = execute_graphql_query( - query=query_name, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config + query=query_name, + variables_dict=variables, + branch=branch, + debug=debug, + repository_config=repository_config, ) # TODO: response is a dict and can't be printed to the console in this way. @@ -249,7 +237,7 @@ def _run_transform( raise typer.Abort() if asyncio.iscoroutinefunction(transform_func): - output = asyncio.run(transform_func(response)) + output = await transform_func(response) else: output = transform_func(response) return output @@ -257,7 +245,7 @@ def _run_transform( @app.command(name="render") @catch_exception(console=console) -def render( +async def render( transform_name: str = typer.Argument(default="", help="Name of the Python transformation", show_default=False), variables: Optional[list[str]] = typer.Argument( None, help="Variables to pass along with the query. Format key=value key=value." @@ -289,7 +277,7 @@ def render( transform_func = functools.partial(render_jinja2_template, transform_config.template_path, variables_dict) # Query GQL and run the transform - result = _run_transform( + result = await _run_transform( query_name=transform_config.query, variables=variables_dict, transform_func=transform_func, @@ -410,7 +398,10 @@ def version() -> None: @app.command(name="info") @catch_exception(console=console) -def info(detail: bool = typer.Option(False, help="Display detailed information."), _: str = CONFIG_PARAM) -> None: # noqa: PLR0915 +def info( # noqa: PLR0915 + detail: bool = typer.Option(False, help="Display detailed information."), + _: str = CONFIG_PARAM, +) -> None: """Display the status of the Python SDK.""" info: dict[str, Any] = { @@ -476,10 +467,14 @@ def info(detail: bool = typer.Option(False, help="Display detailed information." infrahub_info = Table(show_header=False, box=None) if info["user_info"]: infrahub_info.add_row("User:", info["user_info"]["AccountProfile"]["display_label"]) - infrahub_info.add_row("Description:", info["user_info"]["AccountProfile"]["description"]["value"]) + infrahub_info.add_row( + "Description:", + info["user_info"]["AccountProfile"]["description"]["value"], + ) infrahub_info.add_row("Status:", info["user_info"]["AccountProfile"]["status"]["label"]) infrahub_info.add_row( - "Number of Groups:", str(info["user_info"]["AccountProfile"]["member_of_groups"]["count"]) + "Number of Groups:", + str(info["user_info"]["AccountProfile"]["member_of_groups"]["count"]), ) if groups := info["groups"]: diff --git a/infrahub_sdk/ctl/render.py b/infrahub_sdk/ctl/render.py index 05122102..cb1c962e 100644 --- a/infrahub_sdk/ctl/render.py +++ b/infrahub_sdk/ctl/render.py @@ -1,6 +1,12 @@ from rich.console import Console from ..schema.repository import InfrahubRepositoryConfig +from ..template.exceptions import ( + JinjaTemplateError, + JinjaTemplateNotFoundError, + JinjaTemplateSyntaxError, + JinjaTemplateUndefinedError, +) def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None: @@ -9,3 +15,36 @@ def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None: for transform in config.jinja2_transforms: console.print(f"{transform.name} ({transform.template_path})") + + +def print_template_errors(error: JinjaTemplateError, console: Console) -> None: + if isinstance(error, JinjaTemplateNotFoundError): + console.print("[red]An error occurred while rendering the jinja template") + console.print("") + if error.base_template: + console.print(f"Base template: [yellow]{error.base_template}") + console.print(f"Missing template: [yellow]{error.filename}") + return + + if isinstance(error, JinjaTemplateUndefinedError): + console.print("[red]An error occurred while rendering the jinja template") + for current_error in error.errors: + console.print(f"[yellow]{current_error.frame.filename} on line {current_error.frame.lineno}\n") + console.print(current_error.syntax) + console.print("") + console.print(error.message) + return + + if isinstance(error, JinjaTemplateSyntaxError): + console.print("[red]A syntax error was encountered within the template") + console.print("") + if error.filename: + console.print(f"Filename: [yellow]{error.filename}") + console.print(f"Line number: [yellow]{error.lineno}") + console.print() + console.print(error.message) + return + + console.print("[red]An error occurred while rendering the jinja template") + console.print("") + console.print(f"[yellow]{error.message}") diff --git a/infrahub_sdk/exceptions.py b/infrahub_sdk/exceptions.py index 257ce6b4..f8a5b541 100644 --- a/infrahub_sdk/exceptions.py +++ b/infrahub_sdk/exceptions.py @@ -69,12 +69,12 @@ def __init__(self, message: str | None = None): class NodeNotFoundError(Error): def __init__( self, - node_type: str, identifier: Mapping[str, list[str]], message: str = "Unable to find the node in the database.", branch_name: str | None = None, + node_type: str | None = None, ): - self.node_type = node_type + self.node_type = node_type or "unknown" self.identifier = identifier self.branch_name = branch_name @@ -88,6 +88,10 @@ def __str__(self) -> str: """ +class NodeInvalidError(NodeNotFoundError): + pass + + class ResourceNotDefinedError(Error): """Raised when trying to access a resource that hasn't been defined.""" diff --git a/infrahub_sdk/generator.py b/infrahub_sdk/generator.py index e08f6642..24e4bebc 100644 --- a/infrahub_sdk/generator.py +++ b/infrahub_sdk/generator.py @@ -137,7 +137,7 @@ async def process_nodes(self, data: dict) -> None: for node in self._nodes + self._related_nodes: if node.id: - self._init_client.store.set(key=node.id, node=node) + self._init_client.store.set(node=node) @abstractmethod async def generate(self, data: dict) -> None: diff --git a/infrahub_sdk/node.py b/infrahub_sdk/node.py index 45ca0bd3..6033fe3b 100644 --- a/infrahub_sdk/node.py +++ b/infrahub_sdk/node.py @@ -15,7 +15,7 @@ ) from .graphql import Mutation, Query from .schema import GenericSchemaAPI, RelationshipCardinality, RelationshipKind -from .utils import compare_lists, get_flat_value +from .utils import compare_lists, generate_short_id, get_flat_value from .uuidt import UUIDT if TYPE_CHECKING: @@ -43,6 +43,20 @@ "calling generate is only supported for CoreArtifactDefinition nodes" ) +HFID_STR_SEPARATOR = "__" + + +def parse_human_friendly_id(hfid: str | list[str]) -> tuple[str | None, list[str]]: + """Parse a human friendly ID into a kind and an identifier.""" + if isinstance(hfid, str): + hfid_parts = hfid.split(HFID_STR_SEPARATOR) + if len(hfid_parts) == 1: + return None, hfid_parts + return hfid_parts[0], hfid_parts[1:] + if isinstance(hfid, list): + return None, hfid + raise ValueError(f"Invalid human friendly ID: {hfid}") + class Attribute: """Represents an attribute of a Node, including its schema, value, and properties.""" @@ -340,10 +354,10 @@ def get(self) -> InfrahubNode: return self._peer # type: ignore[return-value] if self.id and self.typename: - return self._client.store.get(key=self.id, kind=self.typename) # type: ignore[return-value] + return self._client.store.get(key=self.id, kind=self.typename, branch=self._branch) # type: ignore[return-value] if self.hfid_str: - return self._client.store.get_by_hfid(key=self.hfid_str) # type: ignore[return-value] + return self._client.store.get(key=self.hfid_str, branch=self._branch) # type: ignore[return-value] raise ValueError("Node must have at least one identifier (ID or HFID) to query it.") @@ -387,10 +401,10 @@ def get(self) -> InfrahubNodeSync: return self._peer # type: ignore[return-value] if self.id and self.typename: - return self._client.store.get(key=self.id, kind=self.typename) # type: ignore[return-value] + return self._client.store.get(key=self.id, kind=self.typename, branch=self._branch) # type: ignore[return-value] if self.hfid_str: - return self._client.store.get_by_hfid(key=self.hfid_str) # type: ignore[return-value] + return self._client.store.get(key=self.hfid_str, branch=self._branch) # type: ignore[return-value] raise ValueError("Node must have at least one identifier (ID or HFID) to query it.") @@ -678,6 +692,11 @@ def __init__(self, schema: MainSchemaTypesAPI, branch: str, data: dict | None = self._branch = branch self._existing: bool = True + # Generate a unique ID only to be used inside the SDK + # The format if this ID is purposely different from the ID used by the API + # This is done to avoid confusion and potential conflicts between the IDs + self._internal_id = generate_short_id() + self.id = data.get("id", None) if isinstance(data, dict) else None self.display_label: str | None = data.get("display_label", None) if isinstance(data, dict) else None self.typename: str | None = data.get("__typename", schema.kind) if isinstance(data, dict) else schema.kind @@ -694,6 +713,9 @@ def __init__(self, schema: MainSchemaTypesAPI, branch: str, data: dict | None = self._init_attributes(data) self._init_relationships(data) + def get_branch(self) -> str: + return self._branch + def get_path_value(self, path: str) -> Any: path_parts = path.split("__") return_value = None @@ -794,6 +816,11 @@ def __repr__(self) -> str: def get_kind(self) -> str: return self._schema.kind + def get_all_kinds(self) -> list[str]: + if hasattr(self._schema, "inherit_from"): + return [self._schema.kind] + self._schema.inherit_from + return [self._schema.kind] + def is_ip_prefix(self) -> bool: builtin_ipprefix_kind = "BuiltinIPPrefix" return self.get_kind() == builtin_ipprefix_kind or builtin_ipprefix_kind in self._schema.inherit_from # type: ignore[union-attr] @@ -1201,7 +1228,7 @@ async def save( else: await self._client.group_context.add_related_nodes(ids=[self.id], update_group_context=update_group_context) - self._client.store.set(key=self.id, node=self) + self._client.store.set(node=self) async def generate_query_data( self, @@ -1726,7 +1753,7 @@ def save( else: self._client.group_context.add_related_nodes(ids=[self.id], update_group_context=update_group_context) - self._client.store.set(key=self.id, node=self) + self._client.store.set(node=self) def generate_query_data( self, diff --git a/infrahub_sdk/protocols_base.py b/infrahub_sdk/protocols_base.py index c634d37f..2d533ac7 100644 --- a/infrahub_sdk/protocols_base.py +++ b/infrahub_sdk/protocols_base.py @@ -144,7 +144,8 @@ class AnyAttributeOptional(Attribute): @runtime_checkable class CoreNodeBase(Protocol): _schema: MainSchemaTypes - id: str + _internal_id: str + id: str # NOTE this is incorrect, should be str | None display_label: str | None @property @@ -153,10 +154,16 @@ def hfid(self) -> list[str] | None: ... @property def hfid_str(self) -> str | None: ... + def get_human_friendly_id(self) -> list[str] | None: ... + def get_human_friendly_id_as_string(self, include_kind: bool = False) -> str | None: ... def get_kind(self) -> str: ... + def get_all_kinds(self) -> list[str]: ... + + def get_branch(self) -> str: ... + def is_ip_prefix(self) -> bool: ... def is_ip_address(self) -> bool: ... diff --git a/infrahub_sdk/pytest_plugin/items/jinja2_transform.py b/infrahub_sdk/pytest_plugin/items/jinja2_transform.py index a5bba094..4ed2e2c5 100644 --- a/infrahub_sdk/pytest_plugin/items/jinja2_transform.py +++ b/infrahub_sdk/pytest_plugin/items/jinja2_transform.py @@ -1,51 +1,47 @@ from __future__ import annotations +import asyncio import difflib +from pathlib import Path from typing import TYPE_CHECKING, Any import jinja2 import ujson from httpx import HTTPStatusError -from rich.console import Console -from rich.traceback import Traceback -from ...jinja2 import identify_faulty_jinja_code -from ..exceptions import Jinja2TransformError, Jinja2TransformUndefinedError, OutputMatchError +from ...template import Jinja2Template +from ...template.exceptions import JinjaTemplateError +from ..exceptions import OutputMatchError from ..models import InfrahubInputOutputTest, InfrahubTestExpectedResult from .base import InfrahubItem if TYPE_CHECKING: - from pathlib import Path - from pytest import ExceptionInfo class InfrahubJinja2Item(InfrahubItem): + def _get_jinja2(self) -> Jinja2Template: + return Jinja2Template( + template=Path(self.resource_config.template_path), # type: ignore[attr-defined] + template_directory=Path(self.session.infrahub_config_path.parent), # type: ignore[attr-defined] + ) + def get_jinja2_environment(self) -> jinja2.Environment: - loader = jinja2.FileSystemLoader(self.session.infrahub_config_path.parent) # type: ignore[attr-defined] - return jinja2.Environment(loader=loader, trim_blocks=True, lstrip_blocks=True) + jinja2_template = self._get_jinja2() + return jinja2_template.get_environment() def get_jinja2_template(self) -> jinja2.Template: - return self.get_jinja2_environment().get_template(str(self.resource_config.template_path)) # type: ignore[attr-defined] + jinja2_template = self._get_jinja2() + return jinja2_template.get_template() def render_jinja2_template(self, variables: dict[str, Any]) -> str | None: + jinja2_template = self._get_jinja2() + try: - return self.get_jinja2_template().render(**variables) - except jinja2.UndefinedError as exc: - traceback = Traceback(show_locals=False) - errors = identify_faulty_jinja_code(traceback=traceback) - console = Console() - with console.capture() as capture: - console.print(f"An error occurred while rendering Jinja2 transform:{self.name!r}\n", soft_wrap=True) - console.print(f"{exc.message}\n", soft_wrap=True) - for frame, syntax in errors: - console.print(f"{frame.filename} on line {frame.lineno}\n", soft_wrap=True) - console.print(syntax, soft_wrap=True) - str_output = capture.get() + return asyncio.run(jinja2_template.render(variables=variables)) + except JinjaTemplateError as exc: if self.test.expect == InfrahubTestExpectedResult.PASS: - raise Jinja2TransformUndefinedError( - name=self.name, message=str_output, rtb=traceback, errors=errors - ) from exc + raise exc return None def get_result_differences(self, computed: Any) -> str | None: @@ -99,8 +95,8 @@ def runtest(self) -> None: raise OutputMatchError(name=self.name, differences=differences) def repr_failure(self, excinfo: ExceptionInfo, style: str | None = None) -> str: - if isinstance(excinfo.value, (Jinja2TransformUndefinedError, Jinja2TransformError)): - return excinfo.value.message + if isinstance(excinfo.value, (JinjaTemplateError)): + return str(excinfo.value.message) return super().repr_failure(excinfo, style=style) diff --git a/infrahub_sdk/store.py b/infrahub_sdk/store.py index 624722e2..99659fc0 100644 --- a/infrahub_sdk/store.py +++ b/infrahub_sdk/store.py @@ -1,16 +1,18 @@ from __future__ import annotations -from collections import defaultdict -from typing import TYPE_CHECKING, Any, Literal, overload +import warnings +from typing import TYPE_CHECKING, Literal, overload -from .exceptions import NodeNotFoundError +from .exceptions import NodeInvalidError, NodeNotFoundError +from .node import parse_human_friendly_id if TYPE_CHECKING: - from .client import SchemaType + from .client import SchemaType, SchemaTypeSync from .node import InfrahubNode, InfrahubNodeSync + from .protocols_base import CoreNode, CoreNodeSync -def get_schema_name(schema: str | type[SchemaType] | None = None) -> str | None: +def get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str | None = None) -> str | None: if isinstance(schema, str): return schema @@ -20,130 +22,404 @@ def get_schema_name(schema: str | type[SchemaType] | None = None) -> str | None: return None -class NodeStoreBase: - """Internal Store for InfrahubNode objects. - - Often while creating a lot of new objects, - we need to save them in order to reuse them later to associate them with another node for example. - """ +class NodeStoreBranch: + def __init__(self, name: str) -> None: + self.branch_name = name - def __init__(self) -> None: - self._store: dict[str, dict] = defaultdict(dict) - self._store_by_hfid: dict[str, Any] = defaultdict(dict) + self._objs: dict[str, InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync] = {} + self._hfids: dict[str, dict[tuple, str]] = {} + self._keys: dict[str, str] = {} + self._uuids: dict[str, str] = {} - def _set(self, node: InfrahubNode | InfrahubNodeSync | SchemaType, key: str | None = None) -> None: - hfid = node.get_human_friendly_id_as_string(include_kind=True) + def count(self) -> int: + return len(self._objs) - if not key and not hfid: - raise ValueError("Cannot store node without human friendly ID or key.") + def set(self, node: InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync, key: str | None = None) -> None: + self._objs[node._internal_id] = node if key: - node_kind = node._schema.kind - self._store[node_kind][key] = node + self._keys[key] = node._internal_id - if hfid: - self._store_by_hfid[hfid] = node + if node.id: + self._uuids[node.id] = node._internal_id + + if hfid := node.get_human_friendly_id(): + for kind in node.get_all_kinds(): + if kind not in self._hfids: + self._hfids[kind] = {} + self._hfids[kind][tuple(hfid)] = node._internal_id + + def get( + self, + key: str | list[str], + kind: type[SchemaType | SchemaTypeSync] | str | None = None, + raise_when_missing: bool = True, + ) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync | None: + found_invalid = False - def _get(self, key: str, kind: str | type[SchemaType] | None = None, raise_when_missing: bool = True): # type: ignore[no-untyped-def] kind_name = get_schema_name(schema=kind) - if kind_name and kind_name not in self._store and key not in self._store[kind_name]: # type: ignore[attr-defined] - if not raise_when_missing: - return None + + if isinstance(key, list): + try: + return self._get_by_hfid(key, kind=kind_name) + except NodeNotFoundError: + pass + + elif isinstance(key, str): + try: + return self._get_by_internal_id(key, kind=kind_name) + except NodeInvalidError: + found_invalid = True + except NodeNotFoundError: + pass + + try: + return self._get_by_id(key, kind=kind_name) + except NodeInvalidError: + found_invalid = True + except NodeNotFoundError: + pass + + try: + return self._get_by_key(key, kind=kind_name) + except NodeInvalidError: + found_invalid = True + except NodeNotFoundError: + pass + + try: + return self._get_by_hfid(key, kind=kind_name) + except NodeNotFoundError: + pass + + if not raise_when_missing: + return None + + if kind and found_invalid: + raise NodeInvalidError( + identifier={"key": [key] if isinstance(key, str) else key}, + message=f"Found a node of a different kind instead of {kind} for key {key!r} in the store ({self.branch_name})", + ) + + raise NodeNotFoundError( + identifier={"key": [key] if isinstance(key, str) else key}, + message=f"Unable to find the node {key!r} in the store ({self.branch_name})", + ) + + def _get_by_internal_id( + self, internal_id: str, kind: str | None = None + ) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync: + if internal_id not in self._objs: + raise NodeNotFoundError( + identifier={"internal_id": [internal_id]}, + message=f"Unable to find the node {internal_id!r} in the store ({self.branch_name})", + ) + + node = self._objs[internal_id] + if kind and kind not in node.get_all_kinds(): + raise NodeInvalidError( + node_type=kind, + identifier={"internal_id": [internal_id]}, + message=f"Found a node of kind {node.get_kind()} instead of {kind} for internal_id {internal_id!r} in the store ({self.branch_name})", + ) + + return node + + def _get_by_key( + self, key: str, kind: str | None = None + ) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync: + if key not in self._keys: raise NodeNotFoundError( - node_type=kind_name, identifier={"key": [key]}, - message="Unable to find the node in the Store", + message=f"Unable to find the node {key!r} in the store ({self.branch_name})", ) - if kind_name and kind_name in self._store and key in self._store[kind_name]: # type: ignore[attr-defined] - return self._store[kind_name][key] # type: ignore[attr-defined] + node = self._get_by_internal_id(self._keys[key]) - for item in self._store.values(): # type: ignore[attr-defined] - if key in item: - return item[key] + if kind and node.get_kind() != kind: + raise NodeInvalidError( + node_type=kind, + identifier={"key": [key]}, + message=f"Found a node of kind {node.get_kind()} instead of {kind} for key {key!r} in the store ({self.branch_name})", + ) - if not raise_when_missing: - return None - raise NodeNotFoundError( - node_type="n/a", - identifier={"key": [key]}, - message=f"Unable to find the node {key!r} in the store", + return node + + def _get_by_id(self, id: str, kind: str | None = None) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync: + if id not in self._uuids: + raise NodeNotFoundError( + identifier={"id": [id]}, + message=f"Unable to find the node {id!r} in the store ({self.branch_name})", + ) + + node = self._get_by_internal_id(self._uuids[id]) + if kind and kind not in node.get_all_kinds(): + raise NodeInvalidError( + node_type=kind, + identifier={"id": [id]}, + message=f"Found a node of kind {node.get_kind()} instead of {kind} for id {id!r} in the store ({self.branch_name})", + ) + + return node + + def _get_by_hfid( + self, hfid: str | list[str], kind: str | None = None + ) -> InfrahubNode | InfrahubNodeSync | CoreNode | CoreNodeSync: + if not kind: + node_kind, node_hfid = parse_human_friendly_id(hfid) + elif kind and isinstance(hfid, str) and hfid.startswith(kind): + node_kind, node_hfid = parse_human_friendly_id(hfid) + else: + node_kind = kind + node_hfid = [hfid] if isinstance(hfid, str) else hfid + + exception_to_raise_if_not_found = NodeNotFoundError( + node_type=node_kind, + identifier={"hfid": node_hfid}, + message=f"Unable to find the node {hfid!r} in the store ({self.branch_name})", ) - def _get_by_hfid(self, key: str, raise_when_missing: bool = True): # type: ignore[no-untyped-def] - try: - return self._store_by_hfid[key] - except KeyError as exc: - if raise_when_missing: - raise NodeNotFoundError( - node_type="n/a", - identifier={"key": [key]}, - message=f"Unable to find the node {key!r} in the store", - ) from exc - return None + if node_kind not in self._hfids: + raise exception_to_raise_if_not_found + + if tuple(node_hfid) not in self._hfids[node_kind]: + raise exception_to_raise_if_not_found + + internal_id = self._hfids[node_kind][tuple(node_hfid)] + return self._objs[internal_id] + + +class NodeStoreBase: + """Internal Store for InfrahubNode objects. + + Often while creating a lot of new objects, + we need to save them in order to reuse them later to associate them with another node for example. + """ + + def __init__(self, default_branch: str | None = None) -> None: + self._branches: dict[str, NodeStoreBranch] = {} + + if default_branch is None: + default_branch = "main" + warnings.warn( + "Using a store without specifying a default branch is deprecated and will be removed in a future version. " + "Please explicitly specify a branch name.", + DeprecationWarning, + stacklevel=2, + ) + + self._default_branch = default_branch + + def _get_branch(self, branch: str | None = None) -> str: + return branch or self._default_branch + + def _set( + self, + node: InfrahubNode | InfrahubNodeSync | SchemaType | SchemaTypeSync, + key: str | None = None, + branch: str | None = None, + ) -> None: + branch = self._get_branch(branch or node.get_branch()) + + if branch not in self._branches: + self._branches[branch] = NodeStoreBranch(name=branch) + + self._branches[branch].set(node=node, key=key) + + def _get( # type: ignore[no-untyped-def] + self, + key: str | list[str], + kind: type[SchemaType | SchemaTypeSync] | str | None = None, + raise_when_missing: bool = True, + branch: str | None = None, + ): + branch = self._get_branch(branch) + + if branch not in self._branches: + self._branches[branch] = NodeStoreBranch(name=branch) + + return self._branches[branch].get(key=key, kind=kind, raise_when_missing=raise_when_missing) + + def count(self, branch: str | None = None) -> int: + branch = self._get_branch(branch) + + if branch not in self._branches: + return 0 + + return self._branches[branch].count() class NodeStore(NodeStoreBase): @overload - def get(self, key: str, kind: type[SchemaType], raise_when_missing: Literal[True] = True) -> SchemaType: ... + def get( + self, + key: str | list[str], + kind: type[SchemaType], + raise_when_missing: Literal[True] = True, + branch: str | None = ..., + ) -> SchemaType: ... @overload def get( - self, key: str, kind: type[SchemaType], raise_when_missing: Literal[False] = False + self, + key: str | list[str], + kind: type[SchemaType], + raise_when_missing: Literal[False] = False, + branch: str | None = ..., ) -> SchemaType | None: ... @overload - def get(self, key: str, kind: type[SchemaType], raise_when_missing: bool = ...) -> SchemaType: ... + def get( + self, + key: str | list[str], + kind: type[SchemaType], + raise_when_missing: bool = ..., + branch: str | None = ..., + ) -> SchemaType: ... @overload def get( - self, key: str, kind: str | None = ..., raise_when_missing: Literal[False] = False - ) -> InfrahubNode | None: ... + self, + key: str | list[str], + kind: str | None = ..., + raise_when_missing: Literal[True] = True, + branch: str | None = ..., + ) -> InfrahubNode: ... @overload - def get(self, key: str, kind: str | None = ..., raise_when_missing: Literal[True] = True) -> InfrahubNode: ... + def get( + self, + key: str | list[str], + kind: str | None = ..., + raise_when_missing: Literal[False] = False, + branch: str | None = ..., + ) -> InfrahubNode | None: ... @overload - def get(self, key: str, kind: str | None = ..., raise_when_missing: bool = ...) -> InfrahubNode: ... + def get( + self, + key: str | list[str], + kind: str | None = ..., + raise_when_missing: bool = ..., + branch: str | None = ..., + ) -> InfrahubNode: ... def get( - self, key: str, kind: str | type[SchemaType] | None = None, raise_when_missing: bool = True + self, + key: str | list[str], + kind: str | type[SchemaType] | None = None, + raise_when_missing: bool = True, + branch: str | None = None, ) -> InfrahubNode | SchemaType | None: - return self._get(key=key, kind=kind, raise_when_missing=raise_when_missing) + return self._get(key=key, kind=kind, raise_when_missing=raise_when_missing, branch=branch) @overload - def get_by_hfid(self, key: str, raise_when_missing: Literal[True] = True) -> InfrahubNode: ... + def get_by_hfid( + self, key: str | list[str], raise_when_missing: Literal[True] = True, branch: str | None = ... + ) -> InfrahubNode: ... @overload - def get_by_hfid(self, key: str, raise_when_missing: Literal[False] = False) -> InfrahubNode | None: ... + def get_by_hfid( + self, key: str | list[str], raise_when_missing: Literal[False] = False, branch: str | None = ... + ) -> InfrahubNode | None: ... - def get_by_hfid(self, key: str, raise_when_missing: bool = True) -> InfrahubNode | None: - return self._get_by_hfid(key=key, raise_when_missing=raise_when_missing) + def get_by_hfid( + self, key: str | list[str], raise_when_missing: bool = True, branch: str | None = None + ) -> InfrahubNode | None: + warnings.warn( + "get_by_hfid() is deprecated and will be removed in a future version. Use get() instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.get(key=key, raise_when_missing=raise_when_missing, branch=branch) - def set(self, node: Any, key: str | None = None) -> None: - return self._set(node=node, key=key) + def set(self, node: InfrahubNode | SchemaType, key: str | None = None, branch: str | None = None) -> None: + return self._set(node=node, key=key, branch=branch) class NodeStoreSync(NodeStoreBase): @overload - def get(self, key: str, kind: str | None = None, raise_when_missing: Literal[True] = True) -> InfrahubNodeSync: ... + def get( + self, + key: str | list[str], + kind: type[SchemaTypeSync], + raise_when_missing: Literal[True] = True, + branch: str | None = ..., + ) -> SchemaTypeSync: ... @overload def get( - self, key: str, kind: str | None = None, raise_when_missing: Literal[False] = False + self, + key: str | list[str], + kind: type[SchemaTypeSync], + raise_when_missing: Literal[False] = False, + branch: str | None = ..., + ) -> SchemaTypeSync | None: ... + + @overload + def get( + self, + key: str | list[str], + kind: type[SchemaTypeSync], + raise_when_missing: bool = ..., + branch: str | None = ..., + ) -> SchemaTypeSync: ... + + @overload + def get( + self, + key: str | list[str], + kind: str | None = ..., + raise_when_missing: Literal[True] = True, + branch: str | None = ..., + ) -> InfrahubNodeSync: ... + + @overload + def get( + self, + key: str | list[str], + kind: str | None = ..., + raise_when_missing: Literal[False] = False, + branch: str | None = ..., ) -> InfrahubNodeSync | None: ... - def get(self, key: str, kind: str | None = None, raise_when_missing: bool = True) -> InfrahubNodeSync | None: - return self._get(key=key, kind=kind, raise_when_missing=raise_when_missing) + @overload + def get( + self, + key: str | list[str], + kind: str | None = ..., + raise_when_missing: bool = ..., + branch: str | None = ..., + ) -> InfrahubNodeSync: ... + + def get( + self, + key: str | list[str], + kind: str | type[SchemaTypeSync] | None = None, + raise_when_missing: bool = True, + branch: str | None = None, + ) -> InfrahubNodeSync | SchemaTypeSync | None: + return self._get(key=key, kind=kind, raise_when_missing=raise_when_missing, branch=branch) @overload - def get_by_hfid(self, key: str, raise_when_missing: Literal[True] = True) -> InfrahubNodeSync: ... + def get_by_hfid( + self, key: str | list[str], raise_when_missing: Literal[True] = True, branch: str | None = ... + ) -> InfrahubNodeSync: ... @overload - def get_by_hfid(self, key: str, raise_when_missing: Literal[False] = False) -> InfrahubNodeSync | None: ... + def get_by_hfid( + self, key: str | list[str], raise_when_missing: Literal[False] = False, branch: str | None = ... + ) -> InfrahubNodeSync | None: ... - def get_by_hfid(self, key: str, raise_when_missing: bool = True) -> InfrahubNodeSync | None: - return self._get_by_hfid(key=key, raise_when_missing=raise_when_missing) + def get_by_hfid( + self, key: str | list[str], raise_when_missing: bool = True, branch: str | None = None + ) -> InfrahubNodeSync | None: + warnings.warn( + "get_by_hfid() is deprecated and will be removed in a future version. Use get() instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.get(key=key, raise_when_missing=raise_when_missing, branch=branch) - def set(self, node: InfrahubNodeSync, key: str | None = None) -> None: - return self._set(node=node, key=key) + def set(self, node: InfrahubNodeSync | SchemaTypeSync, key: str | None = None, branch: str | None = None) -> None: + return self._set(node=node, key=key, branch=branch) diff --git a/infrahub_sdk/template/__init__.py b/infrahub_sdk/template/__init__.py new file mode 100644 index 00000000..c43f7ad9 --- /dev/null +++ b/infrahub_sdk/template/__init__.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import linecache +from pathlib import Path +from typing import Any, Callable, NoReturn + +import jinja2 +from jinja2 import meta, nodes +from jinja2.sandbox import SandboxedEnvironment +from netutils.utils import jinja2_convenience_function +from rich.syntax import Syntax +from rich.traceback import Traceback + +from .exceptions import ( + JinjaTemplateError, + JinjaTemplateNotFoundError, + JinjaTemplateOperationViolationError, + JinjaTemplateSyntaxError, + JinjaTemplateUndefinedError, +) +from .filters import AVAILABLE_FILTERS +from .models import UndefinedJinja2Error + +netutils_filters = jinja2_convenience_function() + + +class Jinja2Template: + def __init__( + self, + template: str | Path, + template_directory: Path | None = None, + filters: dict[str, Callable] | None = None, + ) -> None: + self.is_string_based = isinstance(template, str) + self.is_file_based = isinstance(template, Path) + self._template = str(template) + self._template_directory = template_directory + self._environment: jinja2.Environment | None = None + + self._available_filters = [filter_definition.name for filter_definition in AVAILABLE_FILTERS] + self._trusted_filters = [ + filter_definition.name for filter_definition in AVAILABLE_FILTERS if filter_definition.trusted + ] + + self._filters = filters or {} + for user_filter in self._filters: + self._available_filters.append(user_filter) + self._trusted_filters.append(user_filter) + + self._template_definition: jinja2.Template | None = None + + def get_environment(self) -> jinja2.Environment: + if self._environment: + return self._environment + + if self.is_string_based: + return self._get_string_based_environment() + + return self._get_file_based_environment() + + def get_template(self) -> jinja2.Template: + if self._template_definition: + return self._template_definition + + try: + if self.is_string_based: + template = self._get_string_based_template() + else: + template = self._get_file_based_template() + except jinja2.TemplateSyntaxError as exc: + self._raise_template_syntax_error(error=exc) + except jinja2.TemplateNotFound as exc: + raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name)) + + return template + + def get_variables(self) -> list[str]: + env = self.get_environment() + + template_source = self._template + if self.is_file_based and env.loader: + template_source = env.loader.get_source(env, self._template)[0] + + template = env.parse(template_source) + + return sorted(meta.find_undeclared_variables(template)) + + def validate(self, restricted: bool = True) -> None: + allowed_list = self._available_filters + if restricted: + allowed_list = self._trusted_filters + + env = self.get_environment() + template_source = self._template + if self.is_file_based and env.loader: + template_source = env.loader.get_source(env, self._template)[0] + + template = env.parse(template_source) + for node in template.find_all(nodes.Filter): + if node.name not in allowed_list: + raise JinjaTemplateOperationViolationError(f"The '{node.name}' filter isn't allowed to be used") + + forbidden_operations = ["Call", "Import", "Include"] + if self.is_string_based and any(node.__class__.__name__ in forbidden_operations for node in template.body): + raise JinjaTemplateOperationViolationError( + f"These operations are forbidden for string based templates: {forbidden_operations}" + ) + + async def render(self, variables: dict[str, Any]) -> str: + template = self.get_template() + try: + output = await template.render_async(variables) + except jinja2.exceptions.TemplateNotFound as exc: + raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name) + except jinja2.TemplateSyntaxError as exc: + self._raise_template_syntax_error(error=exc) + except jinja2.UndefinedError as exc: + traceback = Traceback(show_locals=False) + errors = _identify_faulty_jinja_code(traceback=traceback) + raise JinjaTemplateUndefinedError(message=exc.message, errors=errors) + except Exception as exc: + if error_message := getattr(exc, "message", None): + message = error_message + else: + message = str(exc) + raise JinjaTemplateError(message=message or "Unknown template error") + + return output + + def _get_string_based_environment(self) -> jinja2.Environment: + env = SandboxedEnvironment(enable_async=True, undefined=jinja2.StrictUndefined) + self._set_filters(env=env) + self._environment = env + return self._environment + + def _get_file_based_environment(self) -> jinja2.Environment: + template_loader = jinja2.FileSystemLoader(searchpath=str(self._template_directory)) + env = jinja2.Environment( + loader=template_loader, + trim_blocks=True, + lstrip_blocks=True, + enable_async=True, + ) + self._set_filters(env=env) + self._environment = env + return self._environment + + def _set_filters(self, env: jinja2.Environment) -> None: + for default_filter in list(env.filters.keys()): + if default_filter not in self._available_filters: + del env.filters[default_filter] + + # Add filters from netutils + env.filters.update( + {name: jinja_filter for name, jinja_filter in netutils_filters.items() if name in self._available_filters} + ) + # Add user supplied filters + env.filters.update(self._filters) + + def _get_string_based_template(self) -> jinja2.Template: + env = self.get_environment() + self._template_definition = env.from_string(self._template) + return self._template_definition + + def _get_file_based_template(self) -> jinja2.Template: + env = self.get_environment() + self._template_definition = env.get_template(self._template) + return self._template_definition + + def _raise_template_syntax_error(self, error: jinja2.TemplateSyntaxError) -> NoReturn: + filename: str | None = None + if error.filename and self._template_directory: + filename = error.filename + if error.filename.startswith(str(self._template_directory)): + filename = error.filename[len(str(self._template_directory)) :] + + raise JinjaTemplateSyntaxError(message=error.message, filename=filename, lineno=error.lineno) + + +def _identify_faulty_jinja_code(traceback: Traceback, nbr_context_lines: int = 3) -> list[UndefinedJinja2Error]: + """This function identifies the faulty Jinja2 code and beautify it to provide meaningful information to the user. + + We use the rich's Traceback to parse the complete stack trace and extract Frames for each exception found in the trace. + """ + response = [] + + # Extract only the Jinja related exception + for frame in [frame for frame in traceback.trace.stacks[0].frames if not frame.filename.endswith(".py")]: + code = "".join(linecache.getlines(frame.filename)) + if frame.filename == "