Alternative implementation of thinking mode #1723
@@ -8,6 +8,10 @@
 from outlines.backends.outlines_core import OutlinesCoreBackend
 from outlines.backends.xgrammar import XGrammarBackend
 from outlines.models import SteerableModel
+from outlines.processors.thinking_logits_processor import ThinkingLogitsProcessor
+from outlines.models.transformers import Transformers
+from outlines.models.llama_cpp import LlamaCpp
+from outlines.models.mlxlm import MLXLM


 CFG_DEFAULT_BACKEND = "llguidance"
@@ -39,12 +43,30 @@ def _get_backend(backend_name: str, model: SteerableModel) -> BaseBackend:
         return LLGuidanceBackend(model)
     else:
         raise ValueError(f"Backend {backend_name} not supported")


+def _get_end_thinking_token_id(end_thinking_tag: str, model: SteerableModel) -> int:
+    if isinstance(model, Transformers):
+        tokenizer = model.hf_tokenizer
+    elif isinstance(model, LlamaCpp):
+        tokenizer = model.tokenizer
+    elif isinstance(model, MLXLM):
+        tokenizer = model.mlx_tokenizer
+    encoded_end_thinking_tag = tokenizer.encode(end_thinking_tag)
+    if len(encoded_end_thinking_tag) != 1:
+        raise ValueError(
+            "The end_thinking_tag must correspond to a single token in "
+            "the tokenizer vocabulary."
+        )
+    return encoded_end_thinking_tag[0]
+
+
 def get_json_schema_logits_processor(
     backend_name: str | None,
     model: SteerableModel,
     json_schema: str,
+    *,
+    end_thinking_tag: str | None,
+    thinking_max_tokens: int | None,
 ) -> LogitsProcessorType:
     """Create a logits processor from a JSON schema.
@@ -67,13 +89,20 @@ def get_json_schema_logits_processor(
         backend_name or JSON_SCHEMA_DEFAULT_BACKEND,
         model,
     )
-    return backend.get_json_schema_logits_processor(json_schema)
+    backend_logits_processor = backend.get_json_schema_logits_processor(json_schema)
+    if end_thinking_tag is not None:
+        end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
+        return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
+    return backend_logits_processor


 def get_regex_logits_processor(
     backend_name: str | None,
     model: SteerableModel,
     regex: str,
+    *,
+    end_thinking_tag: str | None,
+    thinking_max_tokens: int | None,
 ) -> LogitsProcessorType:
     """Create a logits processor from a regex.

Review comment (on get_regex_logits_processor): Idem
@@ -96,13 +125,20 @@ def get_regex_logits_processor(
         backend_name or REGEX_DEFAULT_BACKEND,
         model,
     )
-    return backend.get_regex_logits_processor(regex)
+    backend_logits_processor = backend.get_regex_logits_processor(regex)
+    if end_thinking_tag is not None:
+        end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
+        return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
+    return backend_logits_processor


 def get_cfg_logits_processor(
     backend_name: str | None,
     model: SteerableModel,
     grammar: str,
+    *,
+    end_thinking_tag: str | None,
+    thinking_max_tokens: int | None,
 ) -> LogitsProcessorType:
     """Create a logits processor from a context-free grammar.

Review comment (on get_cfg_logits_processor): Idem
@@ -125,4 +161,8 @@
         backend_name or CFG_DEFAULT_BACKEND,
         model,
     )
-    return backend.get_cfg_logits_processor(grammar)
+    backend_logits_processor = backend.get_cfg_logits_processor(grammar)
+    if end_thinking_tag is not None:
+        end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
+        return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
+    return backend_logits_processor
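Because ThinkingLogitsProcessor watches for one specific token id, _get_end_thinking_token_id requires the end tag to encode to exactly one token. A rough way to check a candidate tag, assuming a Hugging Face tokenizer; the model name and tag below are examples only, not something this diff prescribes:

# Hypothetical check that a candidate end-of-thinking tag is a single token.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")  # example model
token_ids = tokenizer.encode("</think>", add_special_tokens=False)
print(token_ids)  # a single id means the tag is usable as end_thinking_tag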
@@ -1,6 +1,6 @@
 """Backend class for Outlines Core."""

-from typing import Callable, Dict
+from typing import Callable, Dict, Union

 from outlines_core import Guide, Index, Vocabulary
 # TODO: change this once the import issue is fixed in outlines_core
@@ -90,7 +90,7 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
         ]

     def _bias_logits_mlx(  # pragma: no cover
-        self, batch_size: int, logits: TensorType
+        self, batch_size: int, logits: TensorType, skip: list[bool]
     ) -> TensorType:
         """Bias the logits for MLX tensors."""
         from outlines_core.kernels.mlx import (

Review comment (on the new skip parameter): If we go with this design, I would consider a different name like …
@@ -100,6 +100,9 @@ def _bias_logits_mlx(  # pragma: no cover

         biased_logits_array = []
         for i in range(batch_size):
+            if skip[i]:
+                biased_logits_array.append(logits[i])
+                continue
             fill_next_token_bitmask(self._guides[i], self._bitmasks[i])
             biased_logits = apply_token_bitmask(
                 self.tensor_adapter.unsqueeze(logits[i]), self._bitmasks[i]  # type: ignore
@@ -109,7 +112,7 @@ def _bias_logits_mlx(  # pragma: no cover
         return self.tensor_adapter.concatenate(biased_logits_array)

     def _bias_logits_torch(
-        self, batch_size: int, logits: TensorType
+        self, batch_size: int, logits: TensorType, skip: list[bool]
     ) -> TensorType:
         """Bias the logits for Torch tensors."""
         from outlines_core.kernels.torch import (
@@ -118,6 +121,8 @@ def _bias_logits_torch(
         )

         for i in range(batch_size):
+            if skip[i]:
+                continue
             fill_next_token_bitmask(self._guides[i], self._bitmasks[i])
             self._bitmasks[i] = self.tensor_adapter.to_device(
                 self._bitmasks[i],
@@ -135,7 +140,7 @@ def _bias_logits_torch(
         return logits

     def _bias_logits_numpy(
-        self, batch_size: int, logits: TensorType
+        self, batch_size: int, logits: TensorType, skip: list[bool]
     ) -> TensorType:
         """Bias the logits for Numpy tensors."""
         from outlines_core.kernels.numpy import (
@@ -144,6 +149,8 @@ def _bias_logits_numpy(
         )

         for i in range(batch_size):
+            if skip[i]:
+                continue
             fill_next_token_bitmask(self._guides[i], self._bitmasks[i])
             apply_token_bitmask_inplace(
                 self.tensor_adapter.unsqueeze(logits[i]),  # type: ignore
@@ -153,7 +160,7 @@ def _bias_logits_numpy(
         return logits

     def process_logits(
-        self, input_ids: TensorType, logits: TensorType
+        self, input_ids: TensorType, logits: TensorType, skip: Union[list[bool], None] = None
     ) -> TensorType:
         """Use the guides to bias the logits.
||
|
@@ -173,19 +180,24 @@ def process_logits( | |
batch_size = self.tensor_adapter.shape(input_ids)[0] | ||
vocab_size = self.tensor_adapter.shape(logits)[1] | ||
|
||
if skip is None: | ||
skip = [False] * batch_size | ||
|
||
if self.is_first_token: | ||
self._setup(batch_size, vocab_size) | ||
self.is_first_token = False | ||
else: | ||
for i in range(batch_size): | ||
if skip[i]: | ||
continue | ||
last_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1]) # type: ignore | ||
if not self._guides[i].is_finished(): | ||
self._guides[i].advance( | ||
token_id=last_token_id, | ||
return_tokens=False | ||
) | ||
|
||
return self.bias_logits(batch_size, logits) | ||
return self.bias_logits(batch_size, logits, skip) | ||
|
||
|
||
class OutlinesCoreBackend(BaseBackend): | ||
|
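The effect of the new skip argument: for any row where skip[i] is True, the guide is neither advanced nor used to bias the logits, so that row's scores pass through unchanged. A minimal sketch of how a wrapper holding per-sequence "still thinking" flags might call it; the function below is hypothetical and not part of this PR:

# Hypothetical caller; `processor` is the logits processor modified above and
# `is_thinking` holds one bool per sequence in the batch.
def constrained_step(processor, input_ids, logits, is_thinking):
    # Rows still inside the thinking segment keep their raw logits;
    # the remaining rows are constrained by the structured-generation guides.
    return processor.process_logits(input_ids, logits, skip=is_thinking)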
@@ -0,0 +1,42 @@
+from outlines.processors.base_logits_processor import OutlinesLogitsProcessor, TensorType
+
+
+class ThinkingLogitsProcessor(OutlinesLogitsProcessor):
+
+    def __init__(self, end_thinking_token_id: int, thinking_max_tokens: int, logits_processor: OutlinesLogitsProcessor):
+        super().__init__(logits_processor.tensor_library_name)
+        self.logits_processor = logits_processor
+        self.end_thinking_token_id = end_thinking_token_id
+        self.thinking_max_tokens = thinking_max_tokens
+        self.is_first_token = True
+
+    def reset(self) -> None:
+        self.is_first_token = True
+        self.logits_processor.reset()
+
+    def setup(self, batch_size: int) -> None:
+        self._is_thinking = [self.end_thinking_token_id is not None] * batch_size
+        self._num_tokens_generated = 0
+
+    def process_logits(self, input_ids: TensorType, logits: TensorType) -> TensorType:
+
+        batch_size = self.tensor_adapter.shape(input_ids)[0]
+
+        if self.is_first_token:
+            self.setup(batch_size)
+            self.is_first_token = False
+        else:
+            self._num_tokens_generated += 1
+            for i in range(batch_size):
+                if not self._is_thinking[i]:
+                    continue
+                latest_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1])
+                if latest_token_id == self.end_thinking_token_id:
+                    self._is_thinking[i] = False
+                elif self._num_tokens_generated >= self.thinking_max_tokens:
+                    logits[i][self.end_thinking_token_id] = float("inf")
+
+        if all(self._is_thinking):
+            return logits
+
+        return self.logits_processor.process_logits(input_ids, logits)
Review comment (on ThinkingLogitsProcessor.process_logits): I'm wondering if we could transform all this into operations on arrays so we don't have to call … What do you think?

Reply: That would be the best, although it means the downstream logits processor needs to be able to handle tensors of different batch sizes and not always in the same order. I'm going to look into how constraining it is.
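For illustration, here is a rough torch-only sketch of what the per-step bookkeeping could look like as array operations instead of a Python loop; it is not part of the PR and ignores the multi-library tensor adapter, so it is a sketch of the idea rather than a drop-in replacement:

import torch

def thinking_step(input_ids, logits, is_thinking, num_generated, end_id, max_thinking):
    # A sequence stops thinking once its last sampled token is the end-of-thinking token.
    is_thinking = is_thinking & (input_ids[:, -1] != end_id)
    # When the thinking budget is exhausted, force the end-of-thinking token
    # for every sequence that is still thinking.
    if num_generated >= max_thinking:
        logits[is_thinking, end_id] = float("inf")
    return logits, is_thinking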
Review comment: I'd prefer a separate function that calls get_json_schema_logits_processor instead of the current branching logic.

Review comment: Also, is there a clean way to get the JsonSchema, Regex and CFG objects up to this point? That would allow us to have a single function get_thinking_logits_processor that dispatches depending on the type.
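One possible shape for the separate-function suggestion, sketched with a hypothetical name and assuming the existing factory functions keep their current signatures; this is not code from the PR:

def get_thinking_json_schema_logits_processor(
    backend_name, model, json_schema, end_thinking_tag, thinking_max_tokens
):
    # Build the plain structured processor, then wrap it for thinking mode.
    backend_logits_processor = get_json_schema_logits_processor(backend_name, model, json_schema)
    end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
    return ThinkingLogitsProcessor(
        end_thinking_token_id, thinking_max_tokens, backend_logits_processor
    )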