diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index ea8d07feb..312157e63 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -9,6 +9,7 @@ import string from contextlib import ExitStack +from datetime import datetime from typing import ( Any, Dict, @@ -23,6 +24,7 @@ ) import jinja2 +import jinja2.ext as jinja2_ext from jinja2.sandbox import ImmutableSandboxedEnvironment import numpy as np @@ -207,11 +209,14 @@ def __init__( set(stop_token_ids) if stop_token_ids is not None else None ) - self._environment = ImmutableSandboxedEnvironment( - loader=jinja2.BaseLoader(), + environment = ImmutableSandboxedEnvironment( trim_blocks=True, lstrip_blocks=True, - ).from_string(self.template) + extensions=[jinja2_ext.loopcontrols], + ) + environment.filters["tojson"] = lambda x, indent=None, separators=None, sort_keys=False, ensure_ascii=False: json.dumps(x, indent=indent, separators=separators, sort_keys=sort_keys, ensure_ascii=ensure_ascii) + environment.globals["strftime_now"] = lambda format: datetime.now().strftime(format) + self._environment = environment.from_string(self.template) def __call__( self,