Skip to content

Commit f88f031

Browse files
SandboxedEnvironment in Jinja template (#456)
Signed-off-by: Abhishek <[email protected]>
1 parent 59a72cd commit f88f031

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

tests/data/test_data_handlers.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -103,15 +103,23 @@ def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
103103
)
104104

105105

106-
def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys():
106+
@pytest.mark.parametrize(
107+
"template",
108+
[
109+
"### Input: {{ not found }} \n\n ### Response: {{ text_label }}",
110+
"### Input: }} Tweet text {{ \n\n ### Response: {{ text_label }}",
111+
"### Input: {{ Tweet text }} \n\n ### Response: {{ ''.__class__ }}",
112+
"### Input: {{ Tweet text }} \n\n ### Response: {{ undefined_variable.split() }}",
113+
],
114+
)
115+
def test_apply_custom_formatting_jinja_template_gives_error_with_wrong_keys(template):
107116
"""Tests that the jinja formatting function will throw error if wrong keys are passed to template"""
108117
json_dataset = datasets.load_dataset(
109118
"json", data_files=TWITTER_COMPLAINTS_DATA_JSONL
110119
)
111-
template = "### Input: {{not found}} \n\n ### Response: {{text_label}}"
112120
formatted_dataset_field = "formatted_data_field"
113121
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
114-
with pytest.raises(KeyError):
122+
with pytest.raises((KeyError, ValueError)):
115123
json_dataset.map(
116124
apply_custom_data_formatting_jinja_template,
117125
fn_kwargs={

tuning/data/data_handlers.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,8 @@
1919
import re
2020

2121
# Third Party
22-
from jinja2 import Environment, StrictUndefined
22+
from jinja2 import StrictUndefined, TemplateSyntaxError, UndefinedError
23+
from jinja2.sandbox import SandboxedEnvironment, SecurityError
2324
from transformers import AutoTokenizer
2425

2526
# Local
@@ -165,13 +166,29 @@ def apply_custom_data_formatting_jinja_template(
165166

166167
template += tokenizer.eos_token
167168
template = process_jinja_placeholders(template)
168-
env = Environment(undefined=StrictUndefined)
169-
jinja_template = env.from_string(template)
169+
env = SandboxedEnvironment(undefined=StrictUndefined)
170+
171+
try:
172+
jinja_template = env.from_string(template)
173+
except TemplateSyntaxError as e:
174+
raise ValueError(
175+
f"Invalid template syntax in provided Jinja template. {e.message}"
176+
) from e
170177

171178
try:
172179
rendered_text = jinja_template.render(element=element, **element)
180+
except UndefinedError as e:
181+
raise KeyError(
182+
f"The dataset does not contain the key used in the provided Jinja template. {e.message}"
183+
) from e
184+
except SecurityError as e:
185+
raise ValueError(
186+
f"Unsafe operation detected in the provided Jinja template. {e.message}"
187+
) from e
173188
except Exception as e:
174-
raise KeyError(f"Dataset does not contain field in template. {e}") from e
189+
raise ValueError(
190+
f"Error occurred while rendering the provided Jinja template. {e.message}"
191+
) from e
175192

176193
return {dataset_text_field: rendered_text}
177194

0 commit comments

Comments
 (0)