Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion src/llm/servable_initializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,62 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
import json
from pathlib import Path

global contextmanager
from contextlib import contextmanager

global jinja2
import jinja2
global ImmutableSandboxedEnvironment
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2.ext import Extension

def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)

# Following the logic from:
# https://github.com/huggingface/transformers/blob/7188e2e28c6d663284634732564143b820a03f8b/src/transformers/utils/chat_template_utils.py#L398
class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}

def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only initiated by jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None

def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv

def is_active(self) -> bool:
return self._rendered_blocks or self._generation_indices

@contextmanager
def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices

yield
finally:
self._rendered_blocks = None
self._generation_indices = None


# Default chat template accepts only single message and outputs only it's 'content'
# effectively turning it into a regular prompt.
Expand All @@ -83,7 +133,7 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr<GenAiServ
# Try to read template from template.jinja file
jinja_file = Path(templates_directory + "/template.jinja")
template_loader = jinja2.FileSystemLoader(searchpath=templates_directory)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, loader=template_loader)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols], loader=template_loader)
jinja_env.policies["json.dumps_kwargs"]["ensure_ascii"] = False
jinja_env.globals["raise_exception"] = raise_exception
if jinja_file.is_file():
Expand Down