Skip to content

Commit f8f5731

Browse files
committed
updates
1 parent b29559e commit f8f5731

File tree

13 files changed

+440
-59
lines changed

13 files changed

+440
-59
lines changed

infrahub_sdk/ctl/cli_commands.py

Lines changed: 22 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from ..ctl.generator import run as run_generator
3030
from ..ctl.menu import app as menu_app
3131
from ..ctl.object import app as object_app
32-
from ..ctl.render import list_jinja2_transforms
32+
from ..ctl.render import list_jinja2_transforms, print_template_errors
3333
from ..ctl.repository import app as repository_app
3434
from ..ctl.repository import get_repository_config
3535
from ..ctl.schema import app as schema_app
@@ -44,12 +44,7 @@
4444
from ..exceptions import GraphQLError, ModuleImportError
4545
from ..schema import MainSchemaTypesAll, SchemaRoot
4646
from ..template import Jinja2Template
47-
from ..template.exceptions import (
48-
JinjaTemplateError,
49-
JinjaTemplateNotFoundError,
50-
JinjaTemplateSyntaxError,
51-
JinjaTemplateUndefinedError,
52-
)
47+
from ..template.exceptions import JinjaTemplateError
5348
from ..utils import get_branch, write_to_file
5449
from ..yaml import SchemaFile
5550
from .exporter import dump
@@ -172,50 +167,22 @@ async def run(
172167
raise typer.Abort(f"Unable to Load the method {method} in the Python script at {script}")
173168

174169
client = initialize_client(
175-
branch=branch, timeout=timeout, max_concurrent_execution=concurrent, identifier=module_name
170+
branch=branch,
171+
timeout=timeout,
172+
max_concurrent_execution=concurrent,
173+
identifier=module_name,
176174
)
177175
func = getattr(module, method)
178176
await func(client=client, log=log, branch=branch, **variables_dict)
179177

180178

181179
async def render_jinja2_template(template_path: Path, variables: dict[str, Any], data: dict[str, Any]) -> str:
182-
if not template_path.is_file():
183-
console.print(f"[red]Unable to locate the template at {template_path}")
184-
raise typer.Exit(1)
185-
186180
variables["data"] = data
187181
jinja_template = Jinja2Template(template=Path(template_path), template_directory=Path())
188182
try:
189183
rendered_tpl = await jinja_template.render(variables=variables)
190-
except JinjaTemplateNotFoundError as exc:
191-
console.print("[red]An error occurred while rendering the jinja template")
192-
console.print("")
193-
if exc.base_template:
194-
console.print(f"Base template: [yellow]{exc.base_template}")
195-
console.print(f"Missing template: [yellow]{exc.filename}")
196-
raise typer.Exit(1) from exc
197-
198-
except JinjaTemplateUndefinedError as exc:
199-
console.print("[red]An error occurred while rendering the jinja template")
200-
for error in exc.errors:
201-
console.print(f"[yellow]{error.frame.filename} on line {error.frame.lineno}\n")
202-
console.print(error.syntax)
203-
console.print("")
204-
console.print(exc.message)
205-
raise typer.Exit(1) from exc
206-
except JinjaTemplateSyntaxError as exc:
207-
console.print("[red]A syntax error was encountered within the template")
208-
console.print("")
209-
if exc.filename:
210-
console.print(f"Filename: [yellow]{exc.filename}")
211-
console.print(f"Line number: [yellow]{exc.lineno}")
212-
console.print()
213-
console.print(exc.message)
214-
raise typer.Exit(1) from exc
215184
except JinjaTemplateError as exc:
216-
console.print("[red]An error occurred while rendering the jinja template")
217-
console.print("")
218-
console.print(f"[yellow]{exc.message}")
185+
print_template_errors(error=exc, console=console)
219186
raise typer.Exit(1) from exc
220187

221188
return rendered_tpl
@@ -244,7 +211,11 @@ async def _run_transform(
244211

245212
try:
246213
response = execute_graphql_query(
247-
query=query_name, variables_dict=variables, branch=branch, debug=debug, repository_config=repository_config
214+
query=query_name,
215+
variables_dict=variables,
216+
branch=branch,
217+
debug=debug,
218+
repository_config=repository_config,
248219
)
249220

250221
# TODO: response is a dict and can't be printed to the console in this way.
@@ -427,7 +398,10 @@ def version() -> None:
427398

428399
@app.command(name="info")
429400
@catch_exception(console=console)
430-
def info(detail: bool = typer.Option(False, help="Display detailed information."), _: str = CONFIG_PARAM) -> None: # noqa: PLR0915
401+
def info( # noqa: PLR0915
402+
detail: bool = typer.Option(False, help="Display detailed information."),
403+
_: str = CONFIG_PARAM,
404+
) -> None:
431405
"""Display the status of the Python SDK."""
432406

433407
info: dict[str, Any] = {
@@ -493,10 +467,14 @@ def info(detail: bool = typer.Option(False, help="Display detailed information."
493467
infrahub_info = Table(show_header=False, box=None)
494468
if info["user_info"]:
495469
infrahub_info.add_row("User:", info["user_info"]["AccountProfile"]["display_label"])
496-
infrahub_info.add_row("Description:", info["user_info"]["AccountProfile"]["description"]["value"])
470+
infrahub_info.add_row(
471+
"Description:",
472+
info["user_info"]["AccountProfile"]["description"]["value"],
473+
)
497474
infrahub_info.add_row("Status:", info["user_info"]["AccountProfile"]["status"]["label"])
498475
infrahub_info.add_row(
499-
"Number of Groups:", str(info["user_info"]["AccountProfile"]["member_of_groups"]["count"])
476+
"Number of Groups:",
477+
str(info["user_info"]["AccountProfile"]["member_of_groups"]["count"]),
500478
)
501479

502480
if groups := info["groups"]:

infrahub_sdk/ctl/render.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
from rich.console import Console
22

33
from ..schema.repository import InfrahubRepositoryConfig
4+
from ..template.exceptions import (
5+
JinjaTemplateError,
6+
JinjaTemplateNotFoundError,
7+
JinjaTemplateSyntaxError,
8+
JinjaTemplateUndefinedError,
9+
)
410

511

612
def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None:
@@ -9,3 +15,36 @@ def list_jinja2_transforms(config: InfrahubRepositoryConfig) -> None:
915

1016
for transform in config.jinja2_transforms:
1117
console.print(f"{transform.name} ({transform.template_path})")
18+
19+
20+
def print_template_errors(error: JinjaTemplateError, console: Console) -> None:
21+
if isinstance(error, JinjaTemplateNotFoundError):
22+
console.print("[red]An error occurred while rendering the jinja template")
23+
console.print("")
24+
if error.base_template:
25+
console.print(f"Base template: [yellow]{error.base_template}")
26+
console.print(f"Missing template: [yellow]{error.filename}")
27+
return
28+
29+
if isinstance(error, JinjaTemplateUndefinedError):
30+
console.print("[red]An error occurred while rendering the jinja template")
31+
for current_error in error.errors:
32+
console.print(f"[yellow]{current_error.frame.filename} on line {current_error.frame.lineno}\n")
33+
console.print(current_error.syntax)
34+
console.print("")
35+
console.print(error.message)
36+
return
37+
38+
if isinstance(error, JinjaTemplateSyntaxError):
39+
console.print("[red]A syntax error was encountered within the template")
40+
console.print("")
41+
if error.filename:
42+
console.print(f"Filename: [yellow]{error.filename}")
43+
console.print(f"Line number: [yellow]{error.lineno}")
44+
console.print()
45+
console.print(error.message)
46+
return
47+
48+
console.print("[red]An error occurred while rendering the jinja template")
49+
console.print("")
50+
console.print(f"[yellow]{error.message}")

infrahub_sdk/template/__init__.py

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,48 @@
55
from typing import Any, Callable, NoReturn
66

77
import jinja2
8-
from jinja2 import meta
8+
from jinja2 import meta, nodes
99
from jinja2.sandbox import SandboxedEnvironment
1010
from netutils.utils import jinja2_convenience_function
1111
from rich.syntax import Syntax
1212
from rich.traceback import Traceback
1313

14-
from .constants import CURRATED_NETUTILS_FILTERS
1514
from .exceptions import (
1615
JinjaTemplateError,
1716
JinjaTemplateNotFoundError,
1817
JinjaTemplateOperationViolationError,
1918
JinjaTemplateSyntaxError,
2019
JinjaTemplateUndefinedError,
2120
)
21+
from .filters import AVAILABLE_FILTERS
2222
from .models import UndefinedJinja2Error
2323

2424
netutils_filters = jinja2_convenience_function()
2525

2626

2727
class Jinja2Template:
28-
def __init__(self, template: str | Path, template_directory: Path | None = None) -> None:
28+
def __init__(
29+
self,
30+
template: str | Path,
31+
template_directory: Path | None = None,
32+
filters: dict[str, Callable] | None = None,
33+
) -> None:
2934
self.is_string_based = isinstance(template, str)
3035
self.is_file_based = isinstance(template, Path)
3136
self._template = str(template)
3237
self._template_directory = template_directory
3338
self._environment: jinja2.Environment | None = None
34-
self._filters: dict[str, Callable] = {}
3539

36-
self._filters.update(
37-
{name: jinja_filter for name, jinja_filter in netutils_filters.items() if name in CURRATED_NETUTILS_FILTERS}
38-
)
40+
self._available_filters = [filter_definition.name for filter_definition in AVAILABLE_FILTERS]
41+
self._trusted_filters = [
42+
filter_definition.name for filter_definition in AVAILABLE_FILTERS if filter_definition.trusted
43+
]
44+
45+
self._filters = filters or {}
46+
for user_filter in self._filters:
47+
self._available_filters.append(user_filter)
48+
self._trusted_filters.append(user_filter)
49+
3950
self._template_definition: jinja2.Template | None = None
4051

4152
def get_environment(self) -> jinja2.Environment:
@@ -78,6 +89,21 @@ def get_variables(self) -> list[str]:
7889

7990
return sorted(meta.find_undeclared_variables(template))
8091

92+
def validate_filters(self, restricted: bool = True) -> None:
93+
allowed_list = self._available_filters
94+
if restricted:
95+
allowed_list = self._trusted_filters
96+
97+
env = self.get_environment()
98+
template_source = self._template
99+
if self.is_file_based and env.loader:
100+
template_source = env.loader.get_source(env, self._template)[0]
101+
102+
template = env.parse(template_source)
103+
for node in template.find_all(nodes.Filter):
104+
if node.name not in allowed_list:
105+
raise JinjaTemplateOperationViolationError(f"The '{node.name}' filter isn't allowed to be used")
106+
81107
async def render(self, variables: dict[str, Any]) -> str:
82108
template = self.get_template()
83109
try:
@@ -101,17 +127,34 @@ async def render(self, variables: dict[str, Any]) -> str:
101127

102128
def _get_string_based_environment(self) -> jinja2.Environment:
103129
env = SandboxedEnvironment(enable_async=True, undefined=jinja2.StrictUndefined)
104-
env.filters.update(self._filters)
130+
self._set_filters(env=env)
105131
self._environment = env
106132
return self._environment
107133

108134
def _get_file_based_environment(self) -> jinja2.Environment:
109135
template_loader = jinja2.FileSystemLoader(searchpath=str(self._template_directory))
110-
env = jinja2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=True, enable_async=True)
111-
env.filters.update(self._filters)
136+
env = jinja2.Environment(
137+
loader=template_loader,
138+
trim_blocks=True,
139+
lstrip_blocks=True,
140+
enable_async=True,
141+
)
142+
self._set_filters(env=env)
112143
self._environment = env
113144
return self._environment
114145

146+
def _set_filters(self, env: jinja2.Environment) -> None:
147+
for default_filter in list(env.filters.keys()):
148+
if default_filter not in self._available_filters:
149+
del env.filters[default_filter]
150+
151+
# Add filters from netutils
152+
env.filters.update(
153+
{name: jinja_filter for name, jinja_filter in netutils_filters.items() if name in self._available_filters}
154+
)
155+
# Add user supplied filters
156+
env.filters.update(self._filters)
157+
115158
def _get_string_based_template(self) -> jinja2.Template:
116159
env = self.get_environment()
117160
self._template_definition = env.from_string(self._template)
@@ -150,7 +193,10 @@ def _identify_faulty_jinja_code(traceback: Traceback, nbr_context_lines: int = 3
150193
code,
151194
lexer_name,
152195
line_numbers=True,
153-
line_range=(frame.lineno - nbr_context_lines, frame.lineno + nbr_context_lines),
196+
line_range=(
197+
frame.lineno - nbr_context_lines,
198+
frame.lineno + nbr_context_lines,
199+
),
154200
highlight_lines={frame.lineno},
155201
code_width=88,
156202
theme=traceback.theme,

0 commit comments

Comments
 (0)