Skip to content

Commit 53126f6

Browse files
committed
added jinja2 extension class to enable Phi4-reasoning loading (#3564)
* added jinja2 extension class to enable Phi4-reasoning loading * Added comment about the source of the new class
1 parent 8dc01c4 commit 53126f6

File tree

1 file changed

+51
-1
lines changed

1 file changed

+51
-1
lines changed

src/llm/servable_initializer.cpp

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,62 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
6161
import json
6262
from pathlib import Path
6363
64+
global contextmanager
65+
from contextlib import contextmanager
66+
6467
global jinja2
6568
import jinja2
69+
global ImmutableSandboxedEnvironment
6670
from jinja2.sandbox import ImmutableSandboxedEnvironment
71+
from jinja2.ext import Extension
6772
6873
def raise_exception(message):
6974
raise jinja2.exceptions.TemplateError(message)
75+
76+
# Following the logic from:
77+
# https://github.com/huggingface/transformers/blob/7188e2e28c6d663284634732564143b820a03f8b/src/transformers/utils/chat_template_utils.py#L398
78+
class AssistantTracker(Extension):
79+
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
80+
tags = {"generation"}
81+
82+
def __init__(self, environment: ImmutableSandboxedEnvironment):
83+
# The class is only initiated by jinja.
84+
super().__init__(environment)
85+
environment.extend(activate_tracker=self.activate_tracker)
86+
self._rendered_blocks = None
87+
self._generation_indices = None
88+
89+
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
90+
lineno = next(parser.stream).lineno
91+
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
92+
return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
93+
94+
@jinja2.pass_eval_context
95+
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
96+
rv = caller()
97+
if self.is_active():
98+
# Only track generation indices if the tracker is active
99+
start_index = len("".join(self._rendered_blocks))
100+
end_index = start_index + len(rv)
101+
self._generation_indices.append((start_index, end_index))
102+
return rv
103+
104+
def is_active(self) -> bool:
105+
return self._rendered_blocks or self._generation_indices
106+
107+
@contextmanager
108+
def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
109+
try:
110+
if self.is_active():
111+
raise ValueError("AssistantTracker should not be reused before closed")
112+
self._rendered_blocks = rendered_blocks
113+
self._generation_indices = generation_indices
114+
115+
yield
116+
finally:
117+
self._rendered_blocks = None
118+
self._generation_indices = None
119+
70120
71121
# Default chat template accepts only single message and outputs only it's 'content'
72122
# effectively turning it into a regular prompt.
@@ -83,7 +133,7 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
83133
# Try to read template from template.jinja file
84134
jinja_file = Path(templates_directory + "/template.jinja")
85135
template_loader = jinja2.FileSystemLoader(searchpath=templates_directory)
86-
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, loader=template_loader)
136+
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols], loader=template_loader)
87137
jinja_env.policies["json.dumps_kwargs"]["ensure_ascii"] = False
88138
jinja_env.globals["raise_exception"] = raise_exception
89139
if jinja_file.is_file():

0 commit comments

Comments
 (0)