diff --git a/docs/features/models/llamacpp.md b/docs/features/models/llamacpp.md
index 0f1717e24..867d325c5 100644
--- a/docs/features/models/llamacpp.md
+++ b/docs/features/models/llamacpp.md
@@ -15,7 +15,9 @@ Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/l
 
 ## Model Initialization
 
-To load the model, you can use the `from_llamacpp` function. The single argument of the function is a `Llama` model instance from the `llama_cpp` library. Consult the [Llama class API reference](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) for detailed information on how to create a model instance and on the various available parameters.
+To load the model, you can use the `from_llamacpp` function. The first argument of the function is a `Llama` model instance from the `llama_cpp` library. Consult the [Llama class API reference](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) for detailed information on how to create a model instance and on the various available parameters.
+
+You can also pass a `chat_mode` argument to `from_llamacpp`. If `True` (default), the model will regard all `str` inputs as user messages in a chat conversation. If `False`, the model will regard all `str` inputs as plain text prompts.
 
 For instance:
 
@@ -31,6 +33,21 @@ model = outlines.from_llamacpp(
 )
 ```
 
+You can also disable chat mode:
+
+```python
+import outlines
+from llama_cpp import Llama
+
+model = outlines.from_llamacpp(
+    Llama.from_pretrained(
+        repo_id="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
+        filename="mistral-7b-instruct-v0.2.Q5_K_M.gguf",
+    ),
+    chat_mode=False,
+)
+```
+
 ## Text Generation
 
 To generate text, you can simply call the model with a prompt.
diff --git a/outlines/models/llamacpp.py b/outlines/models/llamacpp.py
index 8b788a7f9..e931d9c8a 100644
--- a/outlines/models/llamacpp.py
+++ b/outlines/models/llamacpp.py
@@ -152,6 +152,15 @@ class LlamaCppTypeAdapter(ModelTypeAdapter):
 
     """
 
+    def __init__(self, has_chat_template: bool = False):
+        """
+        Parameters
+        ----------
+        has_chat_template
+            Whether the model has a chat template defined.
+        """
+        self.has_chat_template = has_chat_template
+
     @singledispatchmethod
     def format_input(self, model_input):
         """Generate the prompt argument to pass to the model.
@@ -173,7 +182,9 @@ def format_input(self, model_input):
         )
 
     @format_input.register(str)
-    def format_str_input(self, model_input: str) -> str:
+    def format_str_input(self, model_input: str) -> str | list:
+        if self.has_chat_template:
+            return [{"role": "user", "content": model_input}]
         return model_input
 
     @format_input.register(Chat)
@@ -227,17 +238,26 @@ class LlamaCpp(Model):
 
     tensor_library_name = "numpy"
 
-    def __init__(self, model: "Llama"):
+    def __init__(self, model: "Llama", chat_mode: bool = True):
         """
         Parameters
         ----------
         model
             A `llama_cpp.Llama` model instance.
+        chat_mode
+            Whether to enable chat mode. If `False`, the model will regard
+            all `str` inputs as plain text prompts. If `True`, the model will
+            regard all `str` inputs as user messages in a chat conversation.
 
         """
         self.model = model
         self.tokenizer = LlamaCppTokenizer(self.model)
-        self.type_adapter = LlamaCppTypeAdapter()
+
+        # Note: llama-cpp-python provides a default chat-template fallback even when
+        # the user hasn't explicitly configured one:
+        # https://github.com/abetlen/llama-cpp-python/blob/c37132b/llama_cpp/llama.py#L540-L545
+        # We keep the default as True because the upstream library generally favors chat-style usage.
+        self.type_adapter = LlamaCppTypeAdapter(has_chat_template=chat_mode)
 
     def generate(
         self,
@@ -273,13 +293,15 @@ def generate(
                 **inference_kwargs,
             )
             result = completion["choices"][0]["text"]
-        elif isinstance(prompt, list):  # pragma: no cover
+        elif isinstance(prompt, list):
             completion = self.model.create_chat_completion(
                 prompt,
                 logits_processor=self.type_adapter.format_output_type(output_type),
                 **inference_kwargs,
             )
             result = completion["choices"][0]["message"]["content"]
+        else:  # Never reached  # pragma: no cover
+            raise ValueError("Unexpected prompt type.")
 
         self.model.reset()
 
@@ -330,7 +352,7 @@ def generate_stream(
             for chunk in generator:
                 yield chunk["choices"][0]["text"]
 
-        elif isinstance(prompt, list):  # pragma: no cover
+        elif isinstance(prompt, list):
             generator = self.model.create_chat_completion(
                 prompt,
                 logits_processor=self.type_adapter.format_output_type(output_type),
@@ -339,9 +361,10 @@ def generate_stream(
             for chunk in generator:
                 yield chunk["choices"][0]["delta"].get("content", "")
+        else:  # Never reached  # pragma: no cover
+            raise ValueError("Unexpected prompt type.")
 
-
-def from_llamacpp(model: "Llama"):
+def from_llamacpp(model: "Llama", chat_mode: bool = True) -> LlamaCpp:
     """Create an Outlines `LlamaCpp` model instance from a
     `llama_cpp.Llama` instance.
 
@@ -349,6 +372,10 @@
     ----------
     model
         A `llama_cpp.Llama` instance.
+    chat_mode
+        Whether to enable chat mode. If `False`, the model will regard
+        all `str` inputs as plain text prompts. If `True`, the model will
+        regard all `str` inputs as user messages in a chat conversation.
 
     Returns
     -------
@@ -356,4 +383,4 @@
         An Outlines `LlamaCpp` model instance.
 
     """
-    return LlamaCpp(model)
+    return LlamaCpp(model, chat_mode=chat_mode)
diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py
index 5111eccaa..5f8bed685 100644
--- a/outlines/models/mlxlm.py
+++ b/outlines/models/mlxlm.py
@@ -5,6 +5,7 @@
 
 from outlines.inputs import Chat
 from outlines.models.base import Model, ModelTypeAdapter
+from outlines.models.tokenizer import _check_hf_chat_template
 from outlines.models.transformers import TransformerTokenizer
 from outlines.processors import OutlinesLogitsProcessor
 
@@ -18,8 +19,9 @@ class MLXLMTypeAdapter(ModelTypeAdapter):
     """Type adapter for the `MLXLM` model."""
 
-    def __init__(self, **kwargs):
-        self.tokenizer = kwargs.get("tokenizer")
+    def __init__(self, tokenizer: "PreTrainedTokenizer", has_chat_template: bool = False):
+        self.tokenizer = tokenizer
+        self.has_chat_template = has_chat_template
 
     @singledispatchmethod
     def format_input(self, model_input):
@@ -42,7 +44,9 @@ def format_input(self, model_input):
         )
 
     @format_input.register(str)
-    def format_str_input(self, model_input: str):
+    def format_str_input(self, model_input: str) -> str:
+        if self.has_chat_template:
+            return self.format_chat_input(Chat([{"role": "user", "content": model_input}]))
         return model_input
 
     @format_input.register(Chat)
@@ -113,7 +117,10 @@ def __init__(
         self.mlx_tokenizer = tokenizer
         # self.tokenizer is used by the logits processor
         self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)
-        self.type_adapter = MLXLMTypeAdapter(tokenizer=tokenizer)
+        self.type_adapter = MLXLMTypeAdapter(
+            tokenizer=tokenizer,
+            has_chat_template=_check_hf_chat_template(tokenizer)
+        )
 
     def generate(
         self,
diff --git a/outlines/models/tokenizer.py b/outlines/models/tokenizer.py
index 5463c5eaa..37a918b93 100644
--- a/outlines/models/tokenizer.py
+++ b/outlines/models/tokenizer.py
@@ -4,6 +4,7 @@
 if TYPE_CHECKING:
     import numpy as np
     from numpy.typing import NDArray
+    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 
 class Tokenizer(Hashable, Protocol):
@@ -31,3 +32,12 @@ def convert_token_to_string(self, token: str) -> str:
         token that includes `Ġ` with a string.
         """
         ...
+
+
+def _check_hf_chat_template(tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast") -> bool:
+    """Check if the HuggingFace tokenizer has a chat template."""
+    try:
+        tokenizer.get_chat_template()
+        return True
+    except ValueError:
+        return False
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 42f05553c..1dc911f41 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -8,7 +8,7 @@
 
 from outlines.inputs import Audio, Chat, Image, Video
 from outlines.models.base import Model, ModelTypeAdapter
-from outlines.models.tokenizer import Tokenizer
+from outlines.models.tokenizer import Tokenizer, _check_hf_chat_template
 from outlines.processors import OutlinesLogitsProcessor
 
 if TYPE_CHECKING:
@@ -136,8 +136,9 @@ def __setstate__(self, state):
 class TransformersTypeAdapter(ModelTypeAdapter):
     """Type adapter for the `Transformers` model."""
 
-    def __init__(self, **kwargs):
-        self.tokenizer = kwargs.get("tokenizer")
+    def __init__(self, tokenizer: "PreTrainedTokenizer", has_chat_template: bool = False):
+        self.tokenizer = tokenizer
+        self.has_chat_template = has_chat_template
 
     @singledispatchmethod
     def format_input(self, model_input):
@@ -161,6 +162,8 @@ def format_input(self, model_input):
 
     @format_input.register(str)
     def format_str_input(self, model_input: str) -> str:
+        if self.has_chat_template:
+            return self.format_chat_input(Chat([{"role": "user", "content": model_input}]))
         return model_input
 
     @format_input.register(Chat)
@@ -243,7 +246,10 @@ def __init__(
         self.hf_tokenizer = tokenizer
         self.tokenizer = TransformerTokenizer(tokenizer)
         self.device_dtype = device_dtype
-        self.type_adapter = TransformersTypeAdapter(tokenizer=tokenizer)
+        self.type_adapter = TransformersTypeAdapter(
+            tokenizer=tokenizer,
+            has_chat_template=_check_hf_chat_template(tokenizer)
+        )
 
         if (
             FlaxPreTrainedModel is not None
diff --git a/outlines/models/vllm_offline.py b/outlines/models/vllm_offline.py
index 9d1ea08f3..59d1eb477 100644
--- a/outlines/models/vllm_offline.py
+++ b/outlines/models/vllm_offline.py
@@ -19,6 +19,9 @@ class VLLMOfflineTypeAdapter(ModelTypeAdapter):
     """Type adapter for the `VLLMOffline` model."""
 
+    def __init__(self, has_chat_template: bool = False):
+        self.has_chat_template = has_chat_template
+
     @singledispatchmethod
     def format_input(self, model_input):
         """Generate the prompt argument to pass to the model.
@@ -36,10 +39,12 @@ def format_input(self, model_input):
         )
 
     @format_input.register(str)
-    def format_input_str(self, model_input: str) -> str:
+    def format_input_str(self, model_input: str) -> str | list:
         """Format a `str` input.
 
         """
+        if self.has_chat_template:
+            return self.format_input_chat(Chat([{"role": "user", "content": model_input}]))
         return model_input
 
     @format_input.register(Chat)
@@ -109,7 +114,8 @@ def __init__(self, model: "LLM"):
 
         """
         self.model = model
-        self.type_adapter = VLLMOfflineTypeAdapter()
+        self.tokenizer = self.model.get_tokenizer()
+        self.type_adapter = VLLMOfflineTypeAdapter(has_chat_template=self._check_chat_template())
 
     def _build_generation_args(
         self,
@@ -163,15 +169,17 @@ def generate(
             output_type,
         )
 
-        if isinstance(model_input, Chat):
+        model_input = self.type_adapter.format_input(model_input)
+
+        if isinstance(model_input, list):
             results = self.model.chat(
-                messages=self.type_adapter.format_input(model_input),
+                messages=model_input,
                 sampling_params=sampling_params,
                 **inference_kwargs,
             )
         else:
             results = self.model.generate(
-                prompts=self.type_adapter.format_input(model_input),
+                prompts=model_input,
                 sampling_params=sampling_params,
                 **inference_kwargs,
             )
@@ -213,16 +221,20 @@ def generate_batch(
             output_type,
         )
 
-        if any(isinstance(item, Chat) for item in model_input):
-            raise TypeError(
-                "Batch generation is not available for the `Chat` input type."
-            )
+        model_inputs = [self.type_adapter.format_input(item) for item in model_input]
 
-        results = self.model.generate(
-            prompts=[self.type_adapter.format_input(item) for item in model_input],
-            sampling_params=sampling_params,
-            **inference_kwargs,
-        )
+        if model_inputs and isinstance(model_inputs[0], list):
+            results = self.model.chat(
+                messages=model_inputs,
+                sampling_params=sampling_params,
+                **inference_kwargs,
+            )
+        else:
+            results = self.model.generate(
+                prompts=model_inputs,
+                sampling_params=sampling_params,
+                **inference_kwargs,
+            )
         return [[sample.text for sample in batch.outputs] for batch in results]
 
     def generate_stream(self, model_input, output_type, **inference_kwargs):
@@ -235,6 +247,28 @@ def generate_stream(self, model_input, output_type, **inference_kwargs):
             "Streaming is not available for the vLLM offline integration."
         )
 
+    def _check_chat_template(self) -> bool:
+        """Check if the tokenizer has a chat template."""
+        from vllm.transformers_utils.tokenizer import (
+            PreTrainedTokenizer,
+            PreTrainedTokenizerFast,
+            TokenizerBase
+        )
+        from outlines.models.tokenizer import _check_hf_chat_template
+
+        if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
+            return _check_hf_chat_template(self.tokenizer)
+        elif isinstance(self.tokenizer, TokenizerBase):
+            # vLLM defines its own TokenizerBase class, and only provides
+            # limited compatibility with HuggingFace tokenizers. So we
+            # need to check for chat template support differently.
+            try:
+                self.tokenizer.apply_chat_template([{"role": "user", "content": "test"}])
+                return True
+            except Exception:
+                return False
+        else:  # Never reached  # pragma: no cover
+            return False
 
 def from_vllm_offline(model: "LLM") -> VLLMOffline:
     """Create an Outlines `VLLMOffline` model instance from a `vllm.LLM`
diff --git a/tests/models/test_llamacpp.py b/tests/models/test_llamacpp.py
index 776da800b..18c57ce6c 100644
--- a/tests/models/test_llamacpp.py
+++ b/tests/models/test_llamacpp.py
@@ -41,6 +41,16 @@ def model(tmp_path_factory):
         )
     )
 
+@pytest.fixture(scope="session")
+def model_no_chat(tmp_path_factory):
+    return LlamaCpp(
+        Llama.from_pretrained(
+            repo_id="tensorblock/Llama3-1B-Base-GGUF",
+            filename="Llama3-1B-Base-Q2_K.gguf",
+        ),
+        chat_mode=False
+    )
+
 @pytest.fixture
 def lark_grammar():
     return """
@@ -171,8 +181,10 @@ class Foo(BaseModel):
 
     generator = model.stream("foo?", Foo)
 
-    x = next(generator)
-    assert x == "{"
+    # NOTE: The first few chunks may be empty (role info, control tokens, finish chunks)
+    # Relevant issue: https://github.com/abetlen/llama-cpp-python/issues/372
+    first_non_empty_token = next(x for x in generator if x)
+    assert first_non_empty_token == "{"
 
 
 def test_llamacpp_stream_cfg(model, ebnf_grammar):
@@ -204,8 +216,8 @@ class Foo(Enum):
 
     generator = model.stream("foo?", Foo)
 
-    x = next(generator)
-    assert x[0] in ("B", "F")
+    first_non_empty_token = next(x for x in generator if x)
+    assert first_non_empty_token[0] in ("B", "F")
 
 
 def test_llamacpp_stream_text_stop(model):
@@ -221,3 +233,11 @@ def test_llamacpp_batch(model):
     model.batch(
         ["Respond with one word.", "Respond with one word."],
     )
+
+def test_llamacpp_no_chat(model_no_chat):
+    result = model_no_chat.generate("Respond with one word. Not more.", None)
+    assert isinstance(result, str)
+
+    generator = model_no_chat.stream("Respond with one word. Not more.", None)
+    for x in generator:
+        assert isinstance(x, str)
diff --git a/tests/models/test_llamacpp_type_adapter.py b/tests/models/test_llamacpp_type_adapter.py
index 403b3589c..86f9fdf9f 100644
--- a/tests/models/test_llamacpp_type_adapter.py
+++ b/tests/models/test_llamacpp_type_adapter.py
@@ -62,6 +62,22 @@ def test_llamacpp_type_adapter_format_input(adapter, image):
     ]))
 
 
+def test_llamacpp_type_adapter_format_input_with_chat_template():
+    adapter = LlamaCppTypeAdapter(has_chat_template=True)
+    message = "prompt"
+    result = adapter.format_input(message)
+
+    assert result == [{"role": "user", "content": "prompt"}]
+
+
+def test_llamacpp_type_adapter_format_input_without_chat_template():
+    adapter = LlamaCppTypeAdapter(has_chat_template=False)
+    message = "prompt"
+    result = adapter.format_input(message)
+
+    assert result == "prompt"
+
+
 def test_llamacpp_type_adapter_format_output_type(adapter, logits_processor):
     formatted = adapter.format_output_type(logits_processor)
     assert isinstance(formatted, LogitsProcessorList)
diff --git a/tests/models/test_mlxlm_type_adapter.py b/tests/models/test_mlxlm_type_adapter.py
index dd2feb16d..152eead01 100644
--- a/tests/models/test_mlxlm_type_adapter.py
+++ b/tests/models/test_mlxlm_type_adapter.py
@@ -1,5 +1,6 @@
 import pytest
 import io
+from unittest.mock import MagicMock
 
 from outlines_core import Index, Vocabulary
 from PIL import Image as PILImage
@@ -46,6 +47,34 @@ def image():
     return image
 
 
+def test_mlxlm_type_adapter_format_input_with_template():
+    tokenizer = MagicMock()
+    tokenizer.chat_template = "some_template"
+    tokenizer.apply_chat_template.return_value = "formatted_prompt"
+
+    adapter = MLXLMTypeAdapter(tokenizer=tokenizer, has_chat_template=True)
+    message = "prompt"
+    result = adapter.format_input(message)
+
+    assert result == "formatted_prompt"
+    tokenizer.apply_chat_template.assert_called_once_with(
+        [{"role": "user", "content": "prompt"}],
+        tokenize=False,
+        add_generation_prompt=True,
+    )
+
+
+def test_mlxlm_type_adapter_format_input_without_template():
+    tokenizer = MagicMock()
+    tokenizer.chat_template = None
+
+    adapter = MLXLMTypeAdapter(tokenizer=tokenizer, has_chat_template=False)
+    message = "prompt"
+    result = adapter.format_input(message)
+
+    assert result == "prompt"
+
+
 @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
 def test_mlxlm_type_adapter_format_input(adapter, image):
     # Anything else than a string/Chat (invalid)
diff --git a/tests/models/test_tokenizer.py b/tests/models/test_tokenizer.py
index 831f7fe3e..13ba45f3e 100644
--- a/tests/models/test_tokenizer.py
+++ b/tests/models/test_tokenizer.py
@@ -1,8 +1,14 @@
 import pytest
 
-from outlines.models.tokenizer import Tokenizer
+from outlines.models.tokenizer import Tokenizer, _check_hf_chat_template
 
 
 def test_tokenizer():
     with pytest.raises(TypeError, match="instantiate abstract"):
         Tokenizer()
+
+def test_check_hf_chat_template():
+    from transformers import AutoTokenizer
+
+    assert _check_hf_chat_template(AutoTokenizer.from_pretrained("openai-community/gpt2")) is False
+    assert _check_hf_chat_template(AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")) is True
diff --git a/tests/models/test_transformers_type_adapter.py b/tests/models/test_transformers_type_adapter.py
index da7bbfcd2..e5318f9f3 100644
--- a/tests/models/test_transformers_type_adapter.py
+++ b/tests/models/test_transformers_type_adapter.py
@@ -50,10 +50,18 @@ def test_transformers_type_adapter_format_input(adapter, image):
     with pytest.raises(TypeError, match="is not available."):
         adapter.format_input(["prompt", Image(image)])
 
-    # string
+    # string with chat template
+    # The fixture sets a chat template, so it should be formatted
+    adapter.has_chat_template = True
+    assert adapter.format_input("Hello, world!") == "user: Hello, world!"
+
+    # string without chat template
+    adapter.has_chat_template = False
     assert adapter.format_input("Hello, world!") == "Hello, world!"
 
     # chat
+    # Restore chat template for chat test
+    adapter.has_chat_template = True
     assert isinstance(adapter.format_input(Chat(messages=[
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, world!"},
diff --git a/tests/models/test_vllm_offline_type_adapter.py b/tests/models/test_vllm_offline_type_adapter.py
index c431230e7..bb9b9fc69 100644
--- a/tests/models/test_vllm_offline_type_adapter.py
+++ b/tests/models/test_vllm_offline_type_adapter.py
@@ -65,6 +65,22 @@ def test_vllm_offline_type_adapter_input_text(type_adapter):
     assert result == message
 
 
+def test_vllm_offline_type_adapter_input_text_with_template():
+    adapter = VLLMOfflineTypeAdapter(has_chat_template=True)
+    message = "prompt"
+    result = adapter.format_input(message)
+
+    assert result == [{"role": "user", "content": "prompt"}]
+
+
+def test_vllm_offline_type_adapter_input_text_without_template():
+    adapter = VLLMOfflineTypeAdapter(has_chat_template=False)
+    message = "prompt"
+    result = adapter.format_input(message)
+
+    assert result == "prompt"
+
+
 def test_vllm_offline_type_adapter_input_chat(type_adapter):
     model_input = Chat(messages=[
         {"role": "system", "content": "prompt"},