Skip to content

Commit 30192f4

Browse files
committed
updates
1 parent dec0ccb commit 30192f4

File tree

5 files changed

+113
-39
lines changed

5 files changed

+113
-39
lines changed

infrahub_sdk/ctl/cli_commands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,9 @@ async def render_jinja2_template(template_path: Path, variables: dict[str, Any],
184184
raise typer.Exit(1)
185185

186186
variables["data"] = data
187-
jinja_template = Jinja2Template(template_directory=Path())
187+
jinja_template = Jinja2Template(template=Path(template_path), template_directory=Path())
188188
try:
189-
rendered_tpl = await jinja_template.render_from_file(template=template_path, variables=variables)
189+
rendered_tpl = await jinja_template.render(variables=variables)
190190
except JinjaTemplateNotFoundError as exc:
191191
console.print("[red]An error occurred while rendering the jinja template")
192192
console.print("")

infrahub_sdk/pytest_plugin/items/jinja2_transform.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ def get_jinja2_template(self) -> jinja2.Template:
2828
return self.get_jinja2_environment().get_template(str(self.resource_config.template_path)) # type: ignore[attr-defined]
2929

3030
def render_jinja2_template(self, variables: dict[str, Any]) -> str | None:
31-
jinja2_template = Jinja2Template(template_directory=Path(self.session.infrahub_config_path.parent)) # type: ignore[attr-defined]
31+
jinja2_template = Jinja2Template(
32+
template=Path(self.resource_config.template_path), # type: ignore[attr-defined]
33+
template_directory=Path(self.session.infrahub_config_path.parent), # type: ignore[attr-defined]
34+
)
3235

3336
try:
34-
return asyncio.run(
35-
jinja2_template.render_from_file(template=Path(self.resource_config.template_path), variables=variables) # type: ignore[attr-defined]
36-
)
37+
return asyncio.run(jinja2_template.render(variables=variables))
3738
except JinjaTemplateError as exc:
3839
if self.test.expect == InfrahubTestExpectedResult.PASS:
3940
raise exc

infrahub_sdk/template/__init__.py

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Callable, NoReturn
66

77
import jinja2
8+
from jinja2 import meta
89
from jinja2.sandbox import SandboxedEnvironment
910
from netutils.utils import jinja2_convenience_function
1011
from rich.syntax import Syntax
@@ -23,40 +24,56 @@
2324

2425

2526
class Jinja2Template:
26-
FORBIDDEN_OPERATIONS: list[str] | None = None
27-
28-
def __init__(self, template_directory: Path | None = None) -> None:
27+
def __init__(self, template: str | Path, template_directory: Path | None = None) -> None:
28+
self.is_string_based = isinstance(template, str)
29+
self.is_file_based = isinstance(template, Path)
30+
self._template = str(template)
2931
self._template_directory = template_directory
32+
self._environment: jinja2.Environment | None = None
3033
self._filters: dict[str, Callable] = {}
3134

3235
self._filters.update(
3336
{name: jinja_filter for name, jinja_filter in netutils_filters.items() if name in CURRATED_NETUTILS_FILTERS}
3437
)
38+
self._template_definition: jinja2.Template | None = None
3539

36-
async def render_from_file(self, template: Path, variables: dict[str, Any]) -> str:
37-
template_loader = jinja2.FileSystemLoader(searchpath=str(self._template_directory))
38-
env = jinja2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=True, enable_async=True)
39-
env.filters.update(self._filters)
40+
def get_environment(self) -> jinja2.Environment:
41+
if self._environment:
42+
return self._environment
43+
44+
if self.is_string_based:
45+
return self._get_string_based_environment()
46+
47+
return self._get_file_based_environment()
48+
49+
def get_template(self) -> jinja2.Template:
50+
if self._template_definition:
51+
return self._template_definition
4052

4153
try:
42-
jinja2_template = env.get_template(str(template))
54+
if self.is_string_based:
55+
template = self._get_string_based_template()
56+
else:
57+
template = self._get_file_based_template()
4358
except jinja2.TemplateSyntaxError as exc:
4459
self._raise_template_syntax_error(error=exc)
4560
except jinja2.TemplateNotFound as exc:
4661
raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name))
47-
return await self._render(template=jinja2_template, variables=variables)
4862

49-
async def render_from_string(self, template: str, variables: dict[str, Any]) -> str:
50-
env = SandboxedEnvironment(enable_async=True, undefined=jinja2.StrictUndefined)
51-
env.filters.update(self._filters)
52-
try:
53-
jinja2_template = env.from_string(template)
54-
except jinja2.TemplateSyntaxError as exc:
55-
self._raise_template_syntax_error(error=exc)
63+
return template
64+
65+
def get_variables(self) -> list[str]:
66+
env = self.get_environment()
5667

57-
return await self._render(template=jinja2_template, variables=variables)
68+
template_source = self._template
69+
if self.is_file_based and env.loader:
70+
template_source = env.loader.get_source(env, self._template)[0]
5871

59-
async def _render(self, template: jinja2.Template, variables: dict[str, Any]) -> str:
72+
template = env.parse(template_source)
73+
return sorted(meta.find_undeclared_variables(template))
74+
75+
async def render(self, variables: dict[str, Any]) -> str:
76+
template = self.get_template()
6077
try:
6178
output = await template.render_async(variables)
6279
except jinja2.exceptions.TemplateNotFound as exc:
@@ -76,6 +93,29 @@ async def _render(self, template: jinja2.Template, variables: dict[str, Any]) ->
7693

7794
return output
7895

96+
def _get_string_based_environment(self) -> jinja2.Environment:
97+
env = SandboxedEnvironment(enable_async=True, undefined=jinja2.StrictUndefined)
98+
env.filters.update(self._filters)
99+
self._environment = env
100+
return self._environment
101+
102+
def _get_file_based_environment(self) -> jinja2.Environment:
103+
template_loader = jinja2.FileSystemLoader(searchpath=str(self._template_directory))
104+
env = jinja2.Environment(loader=template_loader, trim_blocks=True, lstrip_blocks=True, enable_async=True)
105+
env.filters.update(self._filters)
106+
self._environment = env
107+
return self._environment
108+
109+
def _get_string_based_template(self) -> jinja2.Template:
110+
env = self.get_environment()
111+
self._template_definition = env.from_string(self._template)
112+
return self._template_definition
113+
114+
def _get_file_based_template(self) -> jinja2.Template:
115+
env = self.get_environment()
116+
self._template_definition = env.get_template(self._template)
117+
return self._template_definition
118+
79119
def _raise_template_syntax_error(self, error: jinja2.TemplateSyntaxError) -> NoReturn:
80120
filename: str | None = None
81121
if error.filename and self._template_directory:
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
<html>
2+
<body>
3+
<ul>
4+
{% for server in servers %}
5+
<li>{{server.name}}: {{ server.ip.primary }}</li>
6+
{% endfor %}
7+
</ul>
8+
9+
</body>
10+
11+
</html>

tests/unit/sdk/test_template.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass
1+
from dataclasses import dataclass, field
22
from pathlib import Path
33
from typing import Any
44

@@ -25,6 +25,7 @@ class JinjaTestCase:
2525
template: str
2626
variables: dict[str, Any]
2727
expected: str
28+
expected_variables: list[str] = field(default_factory=list)
2829

2930

3031
@dataclass
@@ -41,24 +42,28 @@ class JinjaTestCaseFailing:
4142
template="Hello {{ name }}",
4243
variables={"name": "Infrahub"},
4344
expected="Hello Infrahub",
45+
expected_variables=["name"],
4446
),
4547
JinjaTestCase(
4648
name="hello-if-defined",
4749
template="Hello {% if name is undefined %}stranger{% else %}{{name}}{% endif %}",
4850
variables={"name": "OpsMill"},
4951
expected="Hello OpsMill",
52+
expected_variables=["name"],
5053
),
5154
JinjaTestCase(
5255
name="hello-if-undefined",
5356
template="Hello {% if name is undefined %}stranger{% else %}{{name}}{% endif %}",
5457
variables={},
5558
expected="Hello stranger",
59+
expected_variables=["name"],
5660
),
5761
JinjaTestCase(
5862
name="netutils-ip-addition",
5963
template="IP={{ ip_address|ip_addition(200) }}",
6064
variables={"ip_address": "192.168.12.15"},
6165
expected="IP=192.168.12.215",
66+
expected_variables=["ip_address"],
6267
),
6368
]
6469

@@ -68,10 +73,9 @@ class JinjaTestCaseFailing:
6873
[pytest.param(tc, id=tc.name) for tc in SUCCESSFUL_STRING_TEST_CASES],
6974
)
7075
async def test_render_string(test_case: JinjaTestCase) -> None:
71-
jinja = Jinja2Template()
72-
assert test_case.expected == await jinja.render_from_string(
73-
template=test_case.template, variables=test_case.variables
74-
)
76+
jinja = Jinja2Template(template=test_case.template)
77+
assert test_case.expected == await jinja.render(variables=test_case.variables)
78+
assert test_case.expected_variables == jinja.get_variables()
7579

7680

7781
SUCCESSFUL_FILE_TEST_CASES = [
@@ -80,12 +84,14 @@ async def test_render_string(test_case: JinjaTestCase) -> None:
8084
template="hello-world.j2",
8185
variables={"name": "Infrahub"},
8286
expected="Hello Infrahub",
87+
expected_variables=["name"],
8388
),
8489
JinjaTestCase(
8590
name="netutils-convert-address",
8691
template="ip_report.j2",
8792
variables={"address": "192.168.18.40/255.255.255.0"},
8893
expected="IP Address: 192.168.18.40/24",
94+
expected_variables=["address"],
8995
),
9096
]
9197

@@ -95,10 +101,9 @@ async def test_render_string(test_case: JinjaTestCase) -> None:
95101
[pytest.param(tc, id=tc.name) for tc in SUCCESSFUL_FILE_TEST_CASES],
96102
)
97103
async def test_render_template_from_file(test_case: JinjaTestCase) -> None:
98-
jinja = Jinja2Template(template_directory=TEMPLATE_DIRECTORY)
99-
assert test_case.expected == await jinja.render_from_file(
100-
template=Path(test_case.template), variables=test_case.variables
101-
)
104+
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
105+
assert test_case.expected == await jinja.render(variables=test_case.variables)
106+
assert test_case.expected_variables == jinja.get_variables()
102107

103108

104109
FAILING_STRING_TEST_CASES = [
@@ -142,9 +147,9 @@ async def test_render_template_from_file(test_case: JinjaTestCase) -> None:
142147
[pytest.param(tc, id=tc.name) for tc in FAILING_STRING_TEST_CASES],
143148
)
144149
async def test_render_string_errors(test_case: JinjaTestCaseFailing) -> None:
145-
jinja = Jinja2Template(template_directory=TEMPLATE_DIRECTORY)
150+
jinja = Jinja2Template(template=test_case.template, template_directory=TEMPLATE_DIRECTORY)
146151
with pytest.raises(test_case.error.__class__) as exc:
147-
await jinja.render_from_string(template=test_case.template, variables=test_case.variables)
152+
await jinja.render(variables=test_case.variables)
148153

149154
_compare_errors(expected=test_case.error, received=exc.value)
150155

@@ -189,6 +194,23 @@ async def test_render_string_errors(test_case: JinjaTestCaseFailing) -> None:
189194
base_template="imports-missing-file.html",
190195
),
191196
),
197+
JinjaTestCaseFailing(
198+
name="invalid-variable-input",
199+
template="report.html",
200+
variables={"servers": [{"name": "server1", "ip": {"primary": "172.18.12.1"}}, {"name": "server1"}]},
201+
error=JinjaTemplateUndefinedError(
202+
message="'dict object' has no attribute 'ip'",
203+
errors=[
204+
UndefinedJinja2Error(
205+
frame=Frame(filename=f"{TEMPLATE_DIRECTORY}/report.html", lineno=5, name="top-level template code"),
206+
syntax=Syntax(
207+
code="<html>\n<body>\n<ul>\n{% for server in servers %}\n <li>{{server.name}}: {{ server.ip.primary }}</li>\n{% endfor %}\n</ul>\n\n</body>\n\n</html>\n", # noqa E501
208+
lexer="",
209+
),
210+
)
211+
],
212+
),
213+
),
192214
]
193215

194216

@@ -197,18 +219,18 @@ async def test_render_string_errors(test_case: JinjaTestCaseFailing) -> None:
197219
[pytest.param(tc, id=tc.name) for tc in FAILING_FILE_TEST_CASES],
198220
)
199221
async def test_manage_file_based_errors(test_case: JinjaTestCaseFailing) -> None:
200-
jinja = Jinja2Template(template_directory=TEMPLATE_DIRECTORY)
222+
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
201223
with pytest.raises(test_case.error.__class__) as exc:
202-
await jinja.render_from_file(template=Path(test_case.template), variables=test_case.variables)
224+
await jinja.render(variables=test_case.variables)
203225

204226
_compare_errors(expected=test_case.error, received=exc.value)
205227

206228

207229
async def test_manage_unhandled_error() -> None:
208-
jinja = Jinja2Template()
230+
jinja = Jinja2Template(template="Hello {{ number | divide_by_zero }}")
209231
jinja._filters["divide_by_zero"] = _divide_by_zero
210232
with pytest.raises(JinjaTemplateError) as exc:
211-
await jinja.render_from_string(template="Hello {{ number | divide_by_zero }}", variables={"number": 1})
233+
await jinja.render(variables={"number": 1})
212234

213235
assert exc.value.message == "division by zero"
214236

0 commit comments

Comments
 (0)