19 changes: 18 additions & 1 deletion docs/features/models/llamacpp.md
@@ -15,7 +15,9 @@ Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/l

## Model Initialization

To load the model, you can use the `from_llamacpp` function. The single argument of the function is a `Llama` model instance from the `llama_cpp` library. Consult the [Llama class API reference](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) for detailed information on how to create a model instance and on the various available parameters.
To load the model, you can use the `from_llamacpp` function. The first argument of the function is a `Llama` model instance from the `llama_cpp` library. Consult the [Llama class API reference](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama) for detailed information on how to create a model instance and on the various available parameters.

You can also pass a `chat_mode` argument to `from_llamacpp`. If `True` (the default), the model treats all `str` inputs as user messages in a chat conversation. If `False`, it treats them as plain text prompts.

For instance:

@@ -31,6 +33,21 @@ model = outlines.from_llamacpp(
)
```

You can also disable chat mode:

```python
import outlines
from llama_cpp import Llama

model = outlines.from_llamacpp(
    Llama.from_pretrained(
        repo_id="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
        filename="mistral-7b-instruct-v0.2.Q5_K_M.gguf",
    ),
    chat_mode=False,
)
```

## Text Generation

To generate text, you can simply call the model with a prompt.
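
For example, a minimal sketch reusing the `model` instance created above (the prompt text is only an illustration):

```python
# With chat_mode=True (the default) the string is sent as a user chat message;
# with chat_mode=False it is passed through as a plain completion prompt.
result = model("Describe llama.cpp in one sentence.")
print(result)
```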
43 changes: 35 additions & 8 deletions outlines/models/llamacpp.py
@@ -152,6 +152,15 @@ class LlamaCppTypeAdapter(ModelTypeAdapter):

"""

def __init__(self, has_chat_template: bool = False):
"""
Parameters
----------
has_chat_template
Whether the model has a chat template defined.
"""
self.has_chat_template = has_chat_template

@singledispatchmethod
def format_input(self, model_input):
"""Generate the prompt argument to pass to the model.
@@ -173,7 +182,9 @@ def format_input(self, model_input):
)

@format_input.register(str)
def format_str_input(self, model_input: str) -> str:
def format_str_input(self, model_input: str) -> str | list:
if self.has_chat_template:
return [{"role": "user", "content": model_input}]
return model_input

@format_input.register(Chat)
@@ -227,17 +238,26 @@ class LlamaCpp(Model):

tensor_library_name = "numpy"

def __init__(self, model: "Llama"):
def __init__(self, model: "Llama", chat_mode: bool = True):
"""
Parameters
----------
model
A `llama_cpp.Llama` model instance.
chat_mode
Whether to enable chat mode. If `False`, the model treats all
`str` inputs as plain text prompts. If `True`, it treats all
`str` inputs as user messages in a chat conversation.

"""
self.model = model
self.tokenizer = LlamaCppTokenizer(self.model)
self.type_adapter = LlamaCppTypeAdapter()

# Note: llama-cpp-python provides a default chat-template fallback even when
# the user hasn't explicitly configured one:
# https://github.com/abetlen/llama-cpp-python/blob/c37132b/llama_cpp/llama.py#L540-L545
# We keep the default as True because the upstream library generally favors chat-style usage.
self.type_adapter = LlamaCppTypeAdapter(has_chat_template=chat_mode)

def generate(
self,
@@ -273,13 +293,15 @@ def generate(
**inference_kwargs,
)
result = completion["choices"][0]["text"]
elif isinstance(prompt, list): # pragma: no cover
elif isinstance(prompt, list):
completion = self.model.create_chat_completion(
prompt,
logits_processor=self.type_adapter.format_output_type(output_type),
**inference_kwargs,
)
result = completion["choices"][0]["message"]["content"]
else: # Never reached # pragma: no cover
raise ValueError("Unexpected prompt type.")

self.model.reset()

@@ -330,7 +352,7 @@ def generate_stream(
for chunk in generator:
yield chunk["choices"][0]["text"]

elif isinstance(prompt, list): # pragma: no cover
elif isinstance(prompt, list):
generator = self.model.create_chat_completion(
prompt,
logits_processor=self.type_adapter.format_output_type(output_type),
@@ -339,21 +361,26 @@
)
for chunk in generator:
yield chunk["choices"][0]["delta"].get("content", "")
else: # Never reached # pragma: no cover
raise ValueError("Unexpected prompt type.")


def from_llamacpp(model: "Llama"):
def from_llamacpp(model: "Llama", chat_mode: bool = True) -> LlamaCpp:
"""Create an Outlines `LlamaCpp` model instance from a
`llama_cpp.Llama` instance.

Parameters
----------
model
A `llama_cpp.Llama` instance.
chat_mode
Whether to enable chat mode. If `False`, the model treats all
`str` inputs as plain text prompts. If `True`, it treats all
`str` inputs as user messages in a chat conversation.

Returns
-------
LlamaCpp
An Outlines `LlamaCpp` model instance.

"""
return LlamaCpp(model)
return LlamaCpp(model, chat_mode=chat_mode)
15 changes: 11 additions & 4 deletions outlines/models/mlxlm.py
@@ -5,6 +5,7 @@

from outlines.inputs import Chat
from outlines.models.base import Model, ModelTypeAdapter
from outlines.models.tokenizer import _check_hf_chat_template
from outlines.models.transformers import TransformerTokenizer
from outlines.processors import OutlinesLogitsProcessor

@@ -18,8 +19,9 @@
class MLXLMTypeAdapter(ModelTypeAdapter):
"""Type adapter for the `MLXLM` model."""

def __init__(self, **kwargs):
self.tokenizer = kwargs.get("tokenizer")
def __init__(self, tokenizer: "PreTrainedTokenizer", has_chat_template: bool = False):
self.tokenizer = tokenizer
self.has_chat_template = has_chat_template

@singledispatchmethod
def format_input(self, model_input):
@@ -42,7 +44,9 @@ def format_input(self, model_input):
)

@format_input.register(str)
def format_str_input(self, model_input: str):
def format_str_input(self, model_input: str) -> str:
if self.has_chat_template:
return self.format_chat_input(Chat([{"role": "user", "content": model_input}]))
return model_input

@format_input.register(Chat)
@@ -113,7 +117,10 @@ def __init__(
self.mlx_tokenizer = tokenizer
# self.tokenizer is used by the logits processor
self.tokenizer = TransformerTokenizer(tokenizer._tokenizer)
self.type_adapter = MLXLMTypeAdapter(tokenizer=tokenizer)
self.type_adapter = MLXLMTypeAdapter(
tokenizer=tokenizer,
has_chat_template=_check_hf_chat_template(tokenizer)
)

def generate(
self,
10 changes: 10 additions & 0 deletions outlines/models/tokenizer.py
@@ -4,6 +4,7 @@
if TYPE_CHECKING:
import numpy as np
from numpy.typing import NDArray
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


class Tokenizer(Hashable, Protocol):
@@ -31,3 +32,12 @@ def convert_token_to_string(self, token: str) -> str:
token that includes `Ġ` with a string.
"""
...


def _check_hf_chat_template(tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast") -> bool:
"""Check if the HuggingFace tokenizer has a chat template."""
try:
tokenizer.get_chat_template()
return True
except ValueError:
return False
14 changes: 10 additions & 4 deletions outlines/models/transformers.py
@@ -8,7 +8,7 @@

from outlines.inputs import Audio, Chat, Image, Video
from outlines.models.base import Model, ModelTypeAdapter
from outlines.models.tokenizer import Tokenizer
from outlines.models.tokenizer import Tokenizer, _check_hf_chat_template
from outlines.processors import OutlinesLogitsProcessor

if TYPE_CHECKING:
@@ -136,8 +136,9 @@ def __setstate__(self, state):
class TransformersTypeAdapter(ModelTypeAdapter):
"""Type adapter for the `Transformers` model."""

def __init__(self, **kwargs):
self.tokenizer = kwargs.get("tokenizer")
def __init__(self, tokenizer: "PreTrainedTokenizer", has_chat_template: bool = False):
self.tokenizer = tokenizer
self.has_chat_template = has_chat_template

@singledispatchmethod
def format_input(self, model_input):
@@ -161,6 +162,8 @@ def format_input(self, model_input):

@format_input.register(str)
def format_str_input(self, model_input: str) -> str:
if self.has_chat_template:
return self.format_chat_input(Chat([{"role": "user", "content": model_input}]))
return model_input

@format_input.register(Chat)
@@ -243,7 +246,10 @@ def __init__(
self.hf_tokenizer = tokenizer
self.tokenizer = TransformerTokenizer(tokenizer)
self.device_dtype = device_dtype
self.type_adapter = TransformersTypeAdapter(tokenizer=tokenizer)
self.type_adapter = TransformersTypeAdapter(
tokenizer=tokenizer,
has_chat_template=_check_hf_chat_template(tokenizer)
)

if (
FlaxPreTrainedModel is not None
62 changes: 48 additions & 14 deletions outlines/models/vllm_offline.py
@@ -19,6 +19,9 @@
class VLLMOfflineTypeAdapter(ModelTypeAdapter):
"""Type adapter for the `VLLMOffline` model."""

def __init__(self, has_chat_template: bool = False):
self.has_chat_template = has_chat_template

@singledispatchmethod
def format_input(self, model_input):
"""Generate the prompt argument to pass to the model.
@@ -36,10 +39,12 @@ def format_input(self, model_input):
)

@format_input.register(str)
def format_input_str(self, model_input: str) -> str:
def format_input_str(self, model_input: str) -> str | list:
"""Format a `str` input.

"""
if self.has_chat_template:
return self.format_input_chat(Chat([{"role": "user", "content": model_input}]))
return model_input

@format_input.register(Chat)
@@ -109,7 +114,8 @@ def __init__(self, model: "LLM"):

"""
self.model = model
self.type_adapter = VLLMOfflineTypeAdapter()
self.tokenizer = self.model.get_tokenizer()
self.type_adapter = VLLMOfflineTypeAdapter(has_chat_template=self._check_chat_template())

def _build_generation_args(
self,
@@ -163,15 +169,17 @@ def generate(
output_type,
)

if isinstance(model_input, Chat):
model_input = self.type_adapter.format_input(model_input)

if isinstance(model_input, list):
results = self.model.chat(
messages=self.type_adapter.format_input(model_input),
messages=model_input,
sampling_params=sampling_params,
**inference_kwargs,
)
else:
results = self.model.generate(
prompts=self.type_adapter.format_input(model_input),
prompts=model_input,
sampling_params=sampling_params,
**inference_kwargs,
)
@@ -213,16 +221,20 @@ def generate_batch(
output_type,
)

if any(isinstance(item, Chat) for item in model_input):
raise TypeError(
"Batch generation is not available for the `Chat` input type."
)
model_inputs = [self.type_adapter.format_input(item) for item in model_input]

results = self.model.generate(
prompts=[self.type_adapter.format_input(item) for item in model_input],
sampling_params=sampling_params,
**inference_kwargs,
)
if model_inputs and isinstance(model_inputs[0], list):
results = self.model.chat(
messages=model_inputs,
sampling_params=sampling_params,
**inference_kwargs,
)
else:
results = self.model.generate(
prompts=model_inputs,
sampling_params=sampling_params,
**inference_kwargs,
)
return [[sample.text for sample in batch.outputs] for batch in results]

def generate_stream(self, model_input, output_type, **inference_kwargs):
@@ -235,6 +247,28 @@ def generate_stream(self, model_input, output_type, **inference_kwargs):
"Streaming is not available for the vLLM offline integration."
)

def _check_chat_template(self) -> bool:
"""Check if the tokenizer has a chat template."""
from vllm.transformers_utils.tokenizer import (
PreTrainedTokenizer,
PreTrainedTokenizerFast,
TokenizerBase
)
from outlines.models.tokenizer import _check_hf_chat_template

if isinstance(self.tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
return _check_hf_chat_template(self.tokenizer)
elif isinstance(self.tokenizer, TokenizerBase):
# vLLM defines its own TokenizerBase class, and only provides
# limited compatibility with HuggingFace tokenizers. So we
# need to check for chat template support differently.
try:
self.tokenizer.apply_chat_template([{"role": "user", "content": "test"}])
return True
except Exception:
return False
else: # Never reached # pragma: no cover
return False

def from_vllm_offline(model: "LLM") -> VLLMOffline:
"""Create an Outlines `VLLMOffline` model instance from a `vllm.LLM`