Skip to content

Commit 62db6f0

Browse files
committed
Add sync option for jinja2 templates and set it as default
1 parent f3334a6 commit 62db6f0

File tree

4 files changed

+132
-59
lines changed

4 files changed

+132
-59
lines changed

infrahub_sdk/ctl/cli_commands.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from ..ctl.validate import app as validate_app
4444
from ..exceptions import GraphQLError, ModuleImportError
4545
from ..schema import MainSchemaTypesAll, SchemaRoot
46-
from ..template import Jinja2Template
46+
from ..template import Jinja2TemplateSync
4747
from ..template.exceptions import JinjaTemplateError
4848
from ..utils import get_branch, write_to_file
4949
from ..yaml import SchemaFile
@@ -178,9 +178,9 @@ async def run(
178178

179179
async def render_jinja2_template(template_path: Path, variables: dict[str, Any], data: dict[str, Any]) -> str:
180180
variables["data"] = data
181-
jinja_template = Jinja2Template(template=Path(template_path), template_directory=Path())
181+
jinja_template = Jinja2TemplateSync(template=Path(template_path), template_directory=Path())
182182
try:
183-
rendered_tpl = await jinja_template.render(variables=variables)
183+
rendered_tpl = jinja_template.render(variables=variables)
184184
except JinjaTemplateError as exc:
185185
print_template_errors(error=exc, console=console)
186186
raise typer.Exit(1) from exc

infrahub_sdk/pytest_plugin/items/jinja2_transform.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import asyncio
43
import difflib
54
from pathlib import Path
65
from typing import TYPE_CHECKING, Any
@@ -9,7 +8,7 @@
98
import ujson
109
from httpx import HTTPStatusError
1110

12-
from ...template import Jinja2Template
11+
from ...template import Jinja2TemplateSync
1312
from ...template.exceptions import JinjaTemplateError
1413
from ..exceptions import OutputMatchError
1514
from ..models import InfrahubInputOutputTest, InfrahubTestExpectedResult
@@ -20,8 +19,8 @@
2019

2120

2221
class InfrahubJinja2Item(InfrahubItem):
23-
def _get_jinja2(self) -> Jinja2Template:
24-
return Jinja2Template(
22+
def _get_jinja2(self) -> Jinja2TemplateSync:
23+
return Jinja2TemplateSync(
2524
template=Path(self.resource_config.template_path), # type: ignore[attr-defined]
2625
template_directory=Path(self.session.infrahub_config_path.parent), # type: ignore[attr-defined]
2726
)
@@ -38,7 +37,7 @@ def render_jinja2_template(self, variables: dict[str, Any]) -> str | None:
3837
jinja2_template = self._get_jinja2()
3938

4039
try:
41-
return asyncio.run(jinja2_template.render(variables=variables))
40+
return jinja2_template.render(variables=variables)
4241
except JinjaTemplateError as exc:
4342
if self.test.expect == InfrahubTestExpectedResult.PASS:
4443
raise exc

infrahub_sdk/template/__init__.py

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import linecache
44
from pathlib import Path
5-
from typing import Any, Callable, NoReturn
5+
from typing import Any, Callable, ClassVar, NoReturn
66

77
import jinja2
88
from jinja2 import meta, nodes
@@ -24,7 +24,9 @@
2424
netutils_filters = jinja2_convenience_function()
2525

2626

27-
class Jinja2Template:
27+
class Jinja2TemplateBase:
28+
_is_async: ClassVar[bool] = True
29+
2830
def __init__(
2931
self,
3032
template: str | Path,
@@ -106,29 +108,8 @@ def validate(self, restricted: bool = True) -> None:
106108
f"These operations are forbidden for string based templates: {forbidden_operations}"
107109
)
108110

109-
async def render(self, variables: dict[str, Any]) -> str:
110-
template = self.get_template()
111-
try:
112-
output = await template.render_async(variables)
113-
except jinja2.exceptions.TemplateNotFound as exc:
114-
raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name)
115-
except jinja2.TemplateSyntaxError as exc:
116-
self._raise_template_syntax_error(error=exc)
117-
except jinja2.UndefinedError as exc:
118-
traceback = Traceback(show_locals=False)
119-
errors = _identify_faulty_jinja_code(traceback=traceback)
120-
raise JinjaTemplateUndefinedError(message=exc.message, errors=errors)
121-
except Exception as exc:
122-
if error_message := getattr(exc, "message", None):
123-
message = error_message
124-
else:
125-
message = str(exc)
126-
raise JinjaTemplateError(message=message or "Unknown template error")
127-
128-
return output
129-
130111
def _get_string_based_environment(self) -> jinja2.Environment:
131-
env = SandboxedEnvironment(enable_async=True, undefined=jinja2.StrictUndefined)
112+
env = SandboxedEnvironment(enable_async=self._is_async, undefined=jinja2.StrictUndefined)
132113
self._set_filters(env=env)
133114
self._environment = env
134115
return self._environment
@@ -139,7 +120,7 @@ def _get_file_based_environment(self) -> jinja2.Environment:
139120
loader=template_loader,
140121
trim_blocks=True,
141122
lstrip_blocks=True,
142-
enable_async=True,
123+
enable_async=self._is_async,
143124
)
144125
self._set_filters(env=env)
145126
self._environment = env
@@ -177,6 +158,54 @@ def _raise_template_syntax_error(self, error: jinja2.TemplateSyntaxError) -> NoR
177158
raise JinjaTemplateSyntaxError(message=error.message, filename=filename, lineno=error.lineno)
178159

179160

161+
class Jinja2Template(Jinja2TemplateBase):
162+
async def render(self, variables: dict[str, Any]) -> str:
163+
template = self.get_template()
164+
try:
165+
output = await template.render_async(variables)
166+
except jinja2.exceptions.TemplateNotFound as exc:
167+
raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name)
168+
except jinja2.TemplateSyntaxError as exc:
169+
self._raise_template_syntax_error(error=exc)
170+
except jinja2.UndefinedError as exc:
171+
traceback = Traceback(show_locals=False)
172+
errors = _identify_faulty_jinja_code(traceback=traceback)
173+
raise JinjaTemplateUndefinedError(message=exc.message, errors=errors)
174+
except Exception as exc:
175+
if error_message := getattr(exc, "message", None):
176+
message = error_message
177+
else:
178+
message = str(exc)
179+
raise JinjaTemplateError(message=message or "Unknown template error")
180+
181+
return output
182+
183+
184+
class Jinja2TemplateSync(Jinja2TemplateBase):
185+
_is_async: ClassVar[bool] = False
186+
187+
def render(self, variables: dict[str, Any]) -> str:
188+
template = self.get_template()
189+
try:
190+
output = template.render(variables)
191+
except jinja2.exceptions.TemplateNotFound as exc:
192+
raise JinjaTemplateNotFoundError(message=exc.message, filename=str(exc.name), base_template=template.name)
193+
except jinja2.TemplateSyntaxError as exc:
194+
self._raise_template_syntax_error(error=exc)
195+
except jinja2.UndefinedError as exc:
196+
traceback = Traceback(show_locals=False)
197+
errors = _identify_faulty_jinja_code(traceback=traceback)
198+
raise JinjaTemplateUndefinedError(message=exc.message, errors=errors)
199+
except Exception as exc:
200+
if error_message := getattr(exc, "message", None):
201+
message = error_message
202+
else:
203+
message = str(exc)
204+
raise JinjaTemplateError(message=message or "Unknown template error")
205+
206+
return output
207+
208+
180209
def _identify_faulty_jinja_code(traceback: Traceback, nbr_context_lines: int = 3) -> list[UndefinedJinja2Error]:
181210
"""This function identifies the faulty Jinja2 code and beautify it to provide meaningful information to the user.
182211

tests/unit/sdk/test_template.py

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from rich.syntax import Syntax
77
from rich.traceback import Frame
88

9-
from infrahub_sdk.template import Jinja2Template
9+
from infrahub_sdk.template import Jinja2Template, Jinja2TemplateSync
1010
from infrahub_sdk.template.exceptions import (
1111
JinjaTemplateError,
1212
JinjaTemplateNotFoundError,
@@ -78,9 +78,15 @@ class JinjaTestCaseFailing:
7878
"test_case",
7979
[pytest.param(tc, id=tc.name) for tc in SUCCESSFUL_STRING_TEST_CASES],
8080
)
81-
async def test_render_string(test_case: JinjaTestCase) -> None:
82-
jinja = Jinja2Template(template=test_case.template)
83-
assert test_case.expected == await jinja.render(variables=test_case.variables)
81+
@pytest.mark.parametrize("is_async", [True, False])
82+
async def test_render_string(test_case: JinjaTestCase, is_async: bool) -> None:
83+
if is_async:
84+
jinja = Jinja2Template(template=test_case.template)
85+
assert test_case.expected == await jinja.render(variables=test_case.variables)
86+
else:
87+
jinja = Jinja2TemplateSync(template=test_case.template)
88+
assert test_case.expected == jinja.render(variables=test_case.variables)
89+
8490
assert test_case.expected_variables == jinja.get_variables()
8591

8692

@@ -106,9 +112,14 @@ async def test_render_string(test_case: JinjaTestCase) -> None:
106112
"test_case",
107113
[pytest.param(tc, id=tc.name) for tc in SUCCESSFUL_FILE_TEST_CASES],
108114
)
109-
async def test_render_template_from_file(test_case: JinjaTestCase) -> None:
110-
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
111-
assert test_case.expected == await jinja.render(variables=test_case.variables)
115+
@pytest.mark.parametrize("is_async", [True, False])
116+
async def test_render_template_from_file(test_case: JinjaTestCase, is_async: bool) -> None:
117+
if is_async:
118+
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
119+
assert test_case.expected == await jinja.render(variables=test_case.variables)
120+
else:
121+
jinja = Jinja2TemplateSync(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
122+
assert test_case.expected == jinja.render(variables=test_case.variables)
112123
assert test_case.expected_variables == jinja.get_variables()
113124
assert jinja.get_template()
114125

@@ -153,10 +164,16 @@ async def test_render_template_from_file(test_case: JinjaTestCase) -> None:
153164
"test_case",
154165
[pytest.param(tc, id=tc.name) for tc in FAILING_STRING_TEST_CASES],
155166
)
156-
async def test_render_string_errors(test_case: JinjaTestCaseFailing) -> None:
157-
jinja = Jinja2Template(template=test_case.template, template_directory=TEMPLATE_DIRECTORY)
158-
with pytest.raises(test_case.error.__class__) as exc:
159-
await jinja.render(variables=test_case.variables)
167+
@pytest.mark.parametrize("is_async", [True, False])
168+
async def test_render_string_errors(test_case: JinjaTestCaseFailing, is_async: bool) -> None:
169+
if is_async:
170+
jinja = Jinja2Template(template=test_case.template, template_directory=TEMPLATE_DIRECTORY)
171+
with pytest.raises(test_case.error.__class__) as exc:
172+
await jinja.render(variables=test_case.variables)
173+
else:
174+
jinja = Jinja2TemplateSync(template=test_case.template, template_directory=TEMPLATE_DIRECTORY)
175+
with pytest.raises(test_case.error.__class__) as exc:
176+
jinja.render(variables=test_case.variables)
160177

161178
_compare_errors(expected=test_case.error, received=exc.value)
162179

@@ -234,36 +251,64 @@ async def test_render_string_errors(test_case: JinjaTestCaseFailing) -> None:
234251
"test_case",
235252
[pytest.param(tc, id=tc.name) for tc in FAILING_FILE_TEST_CASES],
236253
)
237-
async def test_manage_file_based_errors(test_case: JinjaTestCaseFailing) -> None:
238-
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
239-
with pytest.raises(test_case.error.__class__) as exc:
240-
await jinja.render(variables=test_case.variables)
254+
@pytest.mark.parametrize("is_async", [True, False])
255+
async def test_manage_file_based_errors(test_case: JinjaTestCaseFailing, is_async: bool) -> None:
256+
if is_async:
257+
jinja = Jinja2Template(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
258+
with pytest.raises(test_case.error.__class__) as exc:
259+
await jinja.render(variables=test_case.variables)
260+
else:
261+
jinja = Jinja2TemplateSync(template=Path(test_case.template), template_directory=TEMPLATE_DIRECTORY)
262+
with pytest.raises(test_case.error.__class__) as exc:
263+
jinja.render(variables=test_case.variables)
241264

242265
_compare_errors(expected=test_case.error, received=exc.value)
243266

244267

245-
async def test_manage_unhandled_error() -> None:
246-
jinja = Jinja2Template(
247-
template="Hello {{ number | divide_by_zero }}",
248-
filters={"divide_by_zero": _divide_by_zero},
249-
)
250-
with pytest.raises(JinjaTemplateError) as exc:
251-
await jinja.render(variables={"number": 1})
268+
@pytest.mark.parametrize("is_async", [True, False])
269+
async def test_manage_unhandled_error(is_async: bool) -> None:
270+
template = "Hello {{ number | divide_by_zero }}"
271+
filters = {"divide_by_zero": _divide_by_zero}
272+
if is_async:
273+
jinja = Jinja2Template(
274+
template=template,
275+
filters=filters,
276+
)
277+
with pytest.raises(JinjaTemplateError) as exc:
278+
await jinja.render(variables={"number": 1})
279+
else:
280+
jinja = Jinja2TemplateSync(
281+
template=template,
282+
filters=filters,
283+
)
284+
with pytest.raises(JinjaTemplateError) as exc:
285+
jinja.render(variables={"number": 1})
252286

253287
assert exc.value.message == "division by zero"
254288

255289

256-
async def test_validate_filter() -> None:
257-
jinja = Jinja2Template(template="{{ network | get_all_host }}")
290+
@pytest.mark.parametrize("is_async", [True, False])
291+
async def test_validate_filter(is_async: bool) -> None:
292+
template = "{{ network | get_all_host }}"
293+
if is_async:
294+
jinja = Jinja2Template(template=template)
295+
else:
296+
jinja = Jinja2TemplateSync(template=template)
297+
258298
jinja.validate(restricted=False)
259299
with pytest.raises(JinjaTemplateOperationViolationError) as exc:
260300
jinja.validate(restricted=True)
261301

262302
assert exc.value.message == "The 'get_all_host' filter isn't allowed to be used"
263303

264304

265-
async def test_validate_operation() -> None:
266-
jinja = Jinja2Template(template="Hello {% include 'very-forbidden.j2' %}")
305+
@pytest.mark.parametrize("is_async", [True, False])
306+
async def test_validate_operation(is_async: bool) -> None:
307+
if is_async:
308+
jinja = Jinja2Template(template="Hello {% include 'very-forbidden.j2' %}")
309+
else:
310+
jinja = Jinja2TemplateSync(template="Hello {% include 'very-forbidden.j2' %}")
311+
267312
with pytest.raises(JinjaTemplateOperationViolationError) as exc:
268313
jinja.validate(restricted=True)
269314

0 commit comments

Comments
 (0)