@@ -61,12 +61,62 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
 import json
 from pathlib import Path

+global contextmanager
+from contextlib import contextmanager
+
 global jinja2
 import jinja2
+global ImmutableSandboxedEnvironment
 from jinja2.sandbox import ImmutableSandboxedEnvironment
+from jinja2.ext import Extension

 def raise_exception(message):
     raise jinja2.exceptions.TemplateError(message)
+
+# Following the logic from:
+# https://github.com/huggingface/transformers/blob/7188e2e28c6d663284634732564143b820a03f8b/src/transformers/utils/chat_template_utils.py#L398
+class AssistantTracker(Extension):
+    # This extension is used to track the indices of assistant-generated tokens in the rendered chat
+    tags = {"generation"}
+
+    def __init__(self, environment: ImmutableSandboxedEnvironment):
+        # The class is only initiated by jinja.
+        super().__init__(environment)
+        environment.extend(activate_tracker=self.activate_tracker)
+        self._rendered_blocks = None
+        self._generation_indices = None
+
+    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
+        lineno = next(parser.stream).lineno
+        body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
+        return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
+
+    @jinja2.pass_eval_context
+    def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
+        rv = caller()
+        if self.is_active():
+            # Only track generation indices if the tracker is active
+            start_index = len("".join(self._rendered_blocks))
+            end_index = start_index + len(rv)
+            self._generation_indices.append((start_index, end_index))
+        return rv
+
+    def is_active(self) -> bool:
+        return self._rendered_blocks or self._generation_indices
+
+    @contextmanager
+    def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
+        try:
+            if self.is_active():
+                raise ValueError("AssistantTracker should not be reused before closed")
+            self._rendered_blocks = rendered_blocks
+            self._generation_indices = generation_indices
+
+            yield
+        finally:
+            self._rendered_blocks = None
+            self._generation_indices = None
+

 # Default chat template accepts only a single message and outputs only its 'content',
 # effectively turning it into a regular prompt.
@@ -83,7 +133,7 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
 # Try to read template from template.jinja file
 jinja_file = Path(templates_directory + "/template.jinja")
 template_loader = jinja2.FileSystemLoader(searchpath=templates_directory)
-jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, loader=template_loader)
+jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols], loader=template_loader)
 jinja_env.policies["json.dumps_kwargs"]["ensure_ascii"] = False
 jinja_env.globals["raise_exception"] = raise_exception
 if jinja_file.is_file():
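
Note (not part of the patch): below is a rough sketch of how the new extension is exercised at render time, mirroring the transformers logic linked in the diff. The template string, message list, and variable names are illustrative assumptions. activate_tracker is reachable on the environment only because AssistantTracker.__init__ registers it via environment.extend(), and Template.generate() is used instead of render() so already-emitted chunks can be collected while the template runs.

# Illustrative sketch only, assuming the jinja_env built in the hunk above;
# the template and messages are made up for this example.
EXAMPLE_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'assistant' %}"
    "{% generation %}{{ message['content'] }}{% endgeneration %}"
    "{% else %}"
    "{{ message['content'] }}"
    "{% endif %}"
    "{% endfor %}"
)
compiled_template = jinja_env.from_string(EXAMPLE_CHAT_TEMPLATE)

rendered_blocks = []       # chunks appended as Template.generate() yields them
generation_indices = []    # (start, end) offsets appended by _generation_support()

with jinja_env.activate_tracker(rendered_blocks, generation_indices):
    for block in compiled_template.generate(messages=[
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]):
        rendered_blocks.append(block)
    rendered_chat = "".join(rendered_blocks)

# rendered_chat is "HiHello!" for this toy template, and generation_indices marks
# the assistant span "Hello!" inside it as [(2, 8)].

The recorded offsets are character positions in the rendered prompt; mapping them onto token indices (as transformers does when building its assistant tokens mask) would be a separate step on top of this.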