48 changes: 44 additions & 4 deletions outlines/backends/__init__.py
@@ -8,6 +8,10 @@
from outlines.backends.outlines_core import OutlinesCoreBackend
from outlines.backends.xgrammar import XGrammarBackend
from outlines.models import SteerableModel
from outlines.processors.thinking_logits_processor import ThinkingLogitsProcessor
from outlines.models.transformers import Transformers
from outlines.models.llama_cpp import LlamaCpp
from outlines.models.mlxlm import MLXLM


CFG_DEFAULT_BACKEND = "llguidance"
@@ -39,12 +43,30 @@ def _get_backend(backend_name: str, model: SteerableModel) -> BaseBackend:
return LLGuidanceBackend(model)
else:
raise ValueError(f"Backend {backend_name} not supported")



def _get_end_thinking_token_id(end_thinking_tag: str, model: SteerableModel) -> int:
if isinstance(model, Transformers):
tokenizer = model.hf_tokenizer
elif isinstance(model, LlamaCpp):
tokenizer = model.tokenizer
elif isinstance(model, MLXLM):
tokenizer = model.mlx_tokenizer
encoded_end_thinking_tag = tokenizer.encode(end_thinking_tag)
if len(encoded_end_thinking_tag) != 1:
raise ValueError(
"The end_thinking_tag must correspond to a single token in"
+ "the tokenizer vocabulary."
)
return encoded_end_thinking_tag[0]

def get_json_schema_logits_processor(
backend_name: str | None,
model: SteerableModel,
Comment on lines 63 to 65 (Member):
I'd prefer a separate function that calls get_json_schema_logits_processor instead of the current branching logic.

Member:
Also, is there a clean way to get the JsonSchema, Regex and CFG objects up to this point? That would allow us to have a single function get_thinking_logits_processor that dispatches depending on the type.

json_schema: str,
*,
end_thinking_tag: str | None,
thinking_max_tokens: int | None,
) -> LogitsProcessorType:
"""Create a logits processor from a JSON schema.

@@ -67,13 +89,20 @@ def get_json_schema_logits_processor(
backend_name or JSON_SCHEMA_DEFAULT_BACKEND,
model,
)
return backend.get_json_schema_logits_processor(json_schema)
backend_logits_processor = backend.get_json_schema_logits_processor(json_schema)
if end_thinking_tag is not None:
end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
return backend_logits_processor


def get_regex_logits_processor(
Member: Idem.

backend_name: str | None,
model: SteerableModel,
regex: str,
*,
end_thinking_tag: str | None,
thinking_max_tokens: int | None,
) -> LogitsProcessorType:
"""Create a logits processor from a regex.

@@ -96,13 +125,20 @@ def get_regex_logits_processor(
backend_name or REGEX_DEFAULT_BACKEND,
model,
)
return backend.get_regex_logits_processor(regex)
backend_logits_processor = backend.get_regex_logits_processor(regex)
if end_thinking_tag is not None:
end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
return backend_logits_processor


def get_cfg_logits_processor(
Member: Idem.

backend_name: str | None,
model: SteerableModel,
grammar: str,
*,
end_thinking_tag: str | None,
thinking_max_tokens: int | None,
) -> LogitsProcessorType:
"""Create a logits processor from a context-free grammar.

@@ -125,4 +161,8 @@ def get_cfg_logits_processor(
backend_name or CFG_DEFAULT_BACKEND,
model,
)
return backend.get_cfg_logits_processor(grammar)
backend_logits_processor = backend.get_cfg_logits_processor(grammar)
if end_thinking_tag is not None:
end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
return backend_logits_processor
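
One possible shape of the refactor suggested in the review comments above, sketched for reference only: a separate function that calls the plain get_json_schema_logits_processor and wraps its result, so the end_thinking_tag branching does not live inside the base functions. The function name and the assumption that the base function would drop the end_thinking_tag / thinking_max_tokens keywords are hypothetical, not part of this diff.

    # Hypothetical sketch of the suggested refactor (not part of this PR).
    def get_thinking_json_schema_logits_processor(
        backend_name: str | None,
        model: SteerableModel,
        json_schema: str,
        end_thinking_tag: str,
        thinking_max_tokens: int | None,
    ) -> LogitsProcessorType:
        # Build the plain structured-output processor first...
        backend_logits_processor = get_json_schema_logits_processor(
            backend_name, model, json_schema
        )
        # ...then wrap it so sampling stays unconstrained until the end-of-thinking
        # token is generated (or thinking_max_tokens is reached).
        end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
        return ThinkingLogitsProcessor(
            end_thinking_token_id, thinking_max_tokens, backend_logits_processor
        )

Equivalent wrappers for the regex and CFG variants would follow the same pattern, or a single get_thinking_logits_processor could dispatch on the output-type object as suggested in the second comment.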
24 changes: 18 additions & 6 deletions outlines/backends/outlines_core.py
@@ -1,6 +1,6 @@
"""Backend class for Outlines Core."""

from typing import Callable, Dict
from typing import Callable, Dict, Union

from outlines_core import Guide, Index, Vocabulary
# TODO: change this once the import issue is fixed in outlines_core
@@ -90,7 +90,7 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
]

def _bias_logits_mlx( # pragma: no cover
self, batch_size: int, logits: TensorType
self, batch_size: int, logits: TensorType, skip: list[bool]
Member:
If we go with this design, I would consider a different name like passthrough.

) -> TensorType:
"""Bias the logits for MLX tensors."""
from outlines_core.kernels.mlx import (
@@ -100,6 +100,9 @@ def _bias_logits_mlx( # pragma: no cover

biased_logits_array = []
for i in range(batch_size):
if skip[i]:
biased_logits_array.append(logits[i])
continue
fill_next_token_bitmask(self._guides[i], self._bitmasks[i])
biased_logits = apply_token_bitmask(
self.tensor_adapter.unsqueeze(logits[i]), self._bitmasks[i] # type: ignore
@@ -109,7 +112,7 @@ def _bias_logits_mlx( # pragma: no cover
return self.tensor_adapter.concatenate(biased_logits_array)

def _bias_logits_torch(
self, batch_size: int, logits: TensorType
self, batch_size: int, logits: TensorType, skip: list[bool]
) -> TensorType:
"""Bias the logits for Torch tensors."""
from outlines_core.kernels.torch import (
@@ -118,6 +121,8 @@ def _bias_logits_torch(
)

for i in range(batch_size):
if skip[i]:
continue
fill_next_token_bitmask(self._guides[i], self._bitmasks[i])
self._bitmasks[i] = self.tensor_adapter.to_device(
self._bitmasks[i],
@@ -135,7 +140,7 @@ def _bias_logits_torch(
return logits

def _bias_logits_numpy(
self, batch_size: int, logits: TensorType
self, batch_size: int, logits: TensorType, skip: list[bool]
) -> TensorType:
"""Bias the logits for Numpy tensors."""
from outlines_core.kernels.numpy import (
@@ -144,6 +149,8 @@ def _bias_logits_numpy(
)

for i in range(batch_size):
if skip[i]:
continue
fill_next_token_bitmask(self._guides[i], self._bitmasks[i])
apply_token_bitmask_inplace(
self.tensor_adapter.unsqueeze(logits[i]), # type: ignore
@@ -153,7 +160,7 @@ def _bias_logits_numpy(
return logits

def process_logits(
self, input_ids: TensorType, logits: TensorType
self, input_ids: TensorType, logits: TensorType, skip: Union[list[bool], None] = None
) -> TensorType:
"""Use the guides to bias the logits.

@@ -173,19 +180,24 @@ def process_logits(
batch_size = self.tensor_adapter.shape(input_ids)[0]
vocab_size = self.tensor_adapter.shape(logits)[1]

if skip is None:
skip = [False] * batch_size

if self.is_first_token:
self._setup(batch_size, vocab_size)
self.is_first_token = False
else:
for i in range(batch_size):
if skip[i]:
continue
last_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1]) # type: ignore
if not self._guides[i].is_finished():
self._guides[i].advance(
token_id=last_token_id,
return_tokens=False
)

return self.bias_logits(batch_size, logits)
return self.bias_logits(batch_size, logits, skip)


class OutlinesCoreBackend(BaseBackend):
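
To make the new skip argument concrete, a minimal illustration of its intended semantics (processor, input_ids and logits are placeholders, not code from the diff). Rows whose flag is True are passed through untouched, which is also why the reviewer above suggests a name like passthrough:

    # Sketch: only rows with skip[i] == False are advanced and masked by the guide.
    skip = [True, False]          # sequence 0 still in its thinking section
    biased = processor.process_logits(input_ids, logits, skip=skip)
    # biased[0] is identical to logits[0] (no bitmask applied, guide not advanced)
    # biased[1] has tokens disallowed by the guide masked out as before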
19 changes: 18 additions & 1 deletion outlines/generator.py
@@ -218,6 +218,9 @@ def __init__(
model: SteerableModel,
output_type: Optional[Any],
backend_name: Optional[str] = None,
*,
end_thinking_tag: Optional[str] = None,
thinking_max_tokens: Optional[int] = None,
):
"""
Parameters
@@ -241,19 +244,25 @@
backend_name,
model,
cfg_string,
end_thinking_tag=end_thinking_tag,
thinking_max_tokens=thinking_max_tokens,
)
elif isinstance(term, JsonSchema):
self.logits_processor = get_json_schema_logits_processor(
backend_name,
model,
term.schema,
end_thinking_tag=end_thinking_tag,
thinking_max_tokens=thinking_max_tokens,
)
else:
regex_string = to_regex(term)
self.logits_processor = get_regex_logits_processor(
backend_name,
model,
regex_string,
end_thinking_tag=end_thinking_tag,
thinking_max_tokens=thinking_max_tokens,
)

@classmethod
@@ -349,6 +358,8 @@ def Generator(
backend: Optional[str] = None,
*,
processor: Optional[LogitsProcessorType] = None,
end_thinking_tag: Optional[str] = None,
thinking_max_tokens: Optional[int] = None,
) -> Union[SteerableGenerator, BlackBoxGenerator, AsyncBlackBoxGenerator]:
"""Create a generator for the given model and output parameters.

@@ -389,7 +400,13 @@
if processor is not None:
return SteerableGenerator.from_processor(model, processor) # type: ignore
else:
return SteerableGenerator(model, output_type, backend) # type: ignore
return SteerableGenerator(
model,
output_type,
backend,
end_thinking_tag=end_thinking_tag,
thinking_max_tokens=thinking_max_tokens
)
else:
if processor is not None:
raise NotImplementedError(
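
A short usage sketch for the new keyword arguments (the output type, tag and token budget are illustrative; end_thinking_tag must encode to a single token in the model's tokenizer, as enforced by _get_end_thinking_token_id):

    # Usage sketch, assuming a steerable model instance is already constructed.
    from pydantic import BaseModel
    from outlines import Generator

    class Answer(BaseModel):
        answer: int

    model = ...  # e.g. a Transformers / LlamaCpp / MLXLM steerable model
    generator = Generator(
        model,
        Answer,
        end_thinking_tag="</think>",   # must be a single token in the vocabulary
        thinking_max_tokens=256,       # force </think> after at most 256 thinking tokens
    )
    result = generator("How many legs do three spiders have? Think step by step.")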
33 changes: 30 additions & 3 deletions outlines/models/base.py
@@ -82,6 +82,9 @@ def __call__(
model_input: Any,
output_type: Optional[Any] = None,
backend: Optional[str] = None,
*,
end_thinking_tag: Optional[str] = None,
thinking_max_tokens: Optional[int] = None,
**inference_kwargs: Any
) -> Any:
"""Call the model.
@@ -119,13 +122,22 @@
"""
from outlines.generator import Generator

return Generator(self, output_type, backend)(model_input, **inference_kwargs)
return Generator(
self,
output_type,
backend,
end_thinking_tag=end_thinking_tag,
thinking_max_tokens=thinking_max_tokens
)(model_input, **inference_kwargs)

def batch(
self,
model_input: List[Any],
output_type: Optional[Any] = None,
backend: Optional[str] = None,
*,
end_thinking_tag: Optional[str] = None,
thinking_max_tokens: Optional[int] = None,
**inference_kwargs: Any
) -> List[Any]:
"""Make a batch call to the model (several inputs at once).
@@ -164,14 +176,23 @@ def batch(
"""
from outlines import Generator

generator = Generator(self, output_type, backend)
generator = Generator(
self,
output_type,
backend,
end_thinking_tag=end_thinking_tag,
thinking_max_tokens=thinking_max_tokens
)
return generator.batch(model_input, **inference_kwargs) # type: ignore

def stream(
self,
model_input: Any,
output_type: Optional[Any] = None,
backend: Optional[str] = None,
*,
end_thinking_tag: Optional[str] = None,
thinking_max_tokens: Optional[int] = None,
**inference_kwargs: Any
) -> Iterator[Any]:
"""Stream a response from the model.
Expand Down Expand Up @@ -212,7 +233,13 @@ def stream(
"""
from outlines import Generator

generator = Generator(self, output_type, backend)
generator = Generator(
self,
output_type,
backend,
end_thinking_tag=end_thinking_tag,
thinking_max_tokens=thinking_max_tokens
)
return generator.stream(model_input, **inference_kwargs) # type: ignore

@abstractmethod
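
The same options are forwarded when calling the model directly; a brief sketch (the model object, output type and kwargs are placeholders, and extra inference kwargs still pass through unchanged):

    # Sketch: __call__, batch and stream all forward the thinking options to Generator.
    result = model(
        "Solve the riddle, then answer.",
        Answer,
        end_thinking_tag="</think>",
        thinking_max_tokens=512,
        max_new_tokens=1024,   # regular inference kwarg, passed to the underlying model
    )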
42 changes: 42 additions & 0 deletions outlines/processors/thinking_logits_processor.py
@@ -0,0 +1,42 @@
from outlines.processors.base_logits_processor import OutlinesLogitsProcessor, TensorType


class ThinkingLogitsProcessor(OutlinesLogitsProcessor):

def __init__(self, end_thinking_token_id: int, thinking_max_tokens: int, logits_processor: OutlinesLogitsProcessor):
super().__init__(logits_processor.tensor_library_name)
self.logits_processor = logits_processor
self.end_thinking_token_id = end_thinking_token_id
self.thinking_max_tokens = thinking_max_tokens
self.is_first_token = True

def reset(self) -> None:
self.is_first_token = True
self.logits_processor.reset()

def setup(self, batch_size: int) -> None:
self._is_thinking = [self.end_thinking_token_id is not None] * batch_size
self._num_tokens_generated = 0

def process_logits(self, input_ids: TensorType, logits: TensorType) -> TensorType:

batch_size = self.tensor_adapter.shape(input_ids)[0]

if self.is_first_token:
self.setup(batch_size)
self.is_first_token = False
else:
self._num_tokens_generated += 1
for i in range(batch_size):
if not self._is_thinking[i]:
continue
latest_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1])
if latest_token_id == self.end_thinking_token_id:
self._is_thinking[i] = False
elif self._num_tokens_generated >= self.thinking_max_tokens:
logits[i][self.end_thinking_token_id] = float("inf")

if all(self._is_thinking):
return logits

return self.logits_processor.process_logits(input_ids, logits)
Member:
I'm wondering if we could transform all this into operations on arrays so we don't have to call process_logits for the sequences where the end-of-think token has not been generated. It would go as:

  1. Extract sequences where end-of-think is present
  2. Run process-logits on them
  3. Re-build the logits array with all sequences.

What do you think?

Contributor (author):
That would be the best although it means the downstream logits processor needs to be able to handle tensors of different batch sizes and not always in the same order. I'm going to look into how constraining it is.
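
For concreteness, a rough sketch of the sub-batch approach being discussed, using torch-style indexing; the names and the in-place rebuild are hypothetical, and it carries exactly the constraint the author points out (the downstream processor must handle a sub-batch whose size and ordering change between steps):

    # Hypothetical sketch of the proposal (not part of this PR).
    import torch

    def process_logits(self, input_ids, logits):
        # ... update self._is_thinking from the latest tokens as in the current code ...
        done = torch.tensor([not thinking for thinking in self._is_thinking])
        if not done.any():
            return logits
        # 1. Extract the sequences whose end-of-thinking token has been generated.
        sub_input_ids = input_ids[done]
        sub_logits = logits[done]
        # 2. Run the structured-output processor on that sub-batch only.
        sub_logits = self.logits_processor.process_logits(sub_input_ids, sub_logits)
        # 3. Re-build the full logits array, leaving thinking sequences untouched.
        logits[done] = sub_logits
        return logits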
