diff --git a/src/llm/servable_initializer.cpp b/src/llm/servable_initializer.cpp index 431e5a54f8..edcbcb3a67 100644 --- a/src/llm/servable_initializer.cpp +++ b/src/llm/servable_initializer.cpp @@ -61,12 +61,62 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr jinja2.nodes.CallBlock: + lineno = next(parser.stream).lineno + body = parser.parse_statements(["name:endgeneration"], drop_needle=True) + return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) + + @jinja2.pass_eval_context + def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str: + rv = caller() + if self.is_active(): + # Only track generation indices if the tracker is active + start_index = len("".join(self._rendered_blocks)) + end_index = start_index + len(rv) + self._generation_indices.append((start_index, end_index)) + return rv + + def is_active(self) -> bool: + return self._rendered_blocks or self._generation_indices + + @contextmanager + def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]): + try: + if self.is_active(): + raise ValueError("AssistantTracker should not be reused before closed") + self._rendered_blocks = rendered_blocks + self._generation_indices = generation_indices + + yield + finally: + self._rendered_blocks = None + self._generation_indices = None + # Default chat template accepts only single message and outputs only it's 'content' # effectively turning it into a regular prompt. @@ -83,7 +133,7 @@ void GenAiServableInitializer::loadPyTemplateProcessor(std::shared_ptr