Alternative implementation of thinking mode #1723
base: main
Conversation
def get_json_schema_logits_processor(
    backend_name: str | None,
    model: SteerableModel,
I'd prefer a separate function that calls get_json_schema_logits_processor instead of the current branching logic.
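For illustration, a minimal sketch of what that separation could look like. The wrapper name and the json_schema/end_thinking_tag/thinking_max_tokens parameters are assumptions about the surrounding signature, not the PR's actual API:

```python
# Hypothetical sketch: the wrapper name and parameter list are assumptions.
def get_thinking_json_schema_logits_processor(
    backend_name: str | None,
    model: SteerableModel,
    json_schema: dict,
    end_thinking_tag: str,
    thinking_max_tokens: int | None,
):
    # Build the plain backend processor first, then wrap it for thinking mode.
    backend_logits_processor = get_json_schema_logits_processor(
        backend_name, model, json_schema
    )
    end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
    return ThinkingLogitsProcessor(
        end_thinking_token_id, thinking_max_tokens, backend_logits_processor
    )
```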
Also, is there a clean way to get the JsonSchema, Regex and CFG objects up to this point? That would allow us to have a single function get_thinking_logits_processor that dispatches depending on the type.
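A rough sketch of that dispatch, assuming the JsonSchema, Regex and CFG types are available here and that each per-type builder accepts the output type as its last argument (both assumptions):

```python
# Hypothetical sketch: assumes JsonSchema, Regex and CFG are in scope and that
# the per-type builders accept the output type as their last argument.
def get_thinking_logits_processor(
    output_type: JsonSchema | Regex | CFG,
    backend_name: str | None,
    model: SteerableModel,
    end_thinking_tag: str,
    thinking_max_tokens: int | None,
):
    # Dispatch on the type of the constraint object.
    if isinstance(output_type, JsonSchema):
        backend_logits_processor = get_json_schema_logits_processor(
            backend_name, model, output_type
        )
    elif isinstance(output_type, Regex):
        backend_logits_processor = get_regex_logits_processor(
            backend_name, model, output_type
        )
    else:
        backend_logits_processor = get_cfg_logits_processor(
            backend_name, model, output_type
        )
    end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
    return ThinkingLogitsProcessor(
        end_thinking_token_id, thinking_max_tokens, backend_logits_processor
    )
```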
    if end_thinking_tag is not None:
        end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
        return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
    return backend_logits_processor
def get_regex_logits_processor(
Idem
    if end_thinking_tag is not None:
        end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
        return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
    return backend_logits_processor
def get_cfg_logits_processor(
Idem
@@ -90,7 +90,7 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
         ]

     def _bias_logits_mlx(  # pragma: no cover
-        self, batch_size: int, logits: TensorType
+        self, batch_size: int, logits: TensorType, skip: list[bool]
If we go with this design, I would consider a different name like passthrough.
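For illustration, a sketch of the renamed signature; the loop body and the _bias_single_sequence helper are assumptions about how the per-sequence flag would be consumed, not the PR's actual implementation:

```python
# Hypothetical sketch of the rename; _bias_single_sequence is an assumed helper.
def _bias_logits_mlx(  # pragma: no cover
    self, batch_size: int, logits: TensorType, passthrough: list[bool]
) -> TensorType:
    for i in range(batch_size):
        if passthrough[i]:
            # Sequence i is still in its thinking phase: leave logits untouched.
            continue
        # Apply the backend bias only to sequences past the thinking phase.
        logits[i] = self._bias_single_sequence(i, logits[i])
    return logits
```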
        if all(self._is_thinking):
            return logits

        return self.logits_processor.process_logits(input_ids, logits)
I'm wondering if we could transform all this into operations on arrays so we don't have to call process_logits for the sequences where the end-of-think token has not been generated. It would go as follows (see the sketch below):
- Extract the sequences where end-of-think is present
- Run process_logits on them
- Re-build the logits array with all the sequences

What do you think?
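A minimal sketch of that approach with NumPy-style boolean indexing; it assumes process_logits is a method of the ThinkingLogitsProcessor shown above and that input_ids and logits are arrays whose first axis is the batch:

```python
import numpy as np

# Hypothetical sketch: assumes NumPy-style arrays with the batch on axis 0.
def process_logits(self, input_ids, logits):
    thinking = np.asarray(self._is_thinking)
    if thinking.all():
        return logits

    # 1. Extract the sequences where the end-of-think token was generated.
    done = ~thinking
    # 2. Run the downstream processor only on those sequences.
    processed = self.logits_processor.process_logits(
        input_ids[done], logits[done]
    )
    # 3. Re-build the full logits array with all sequences.
    logits[done] = processed
    return logits
```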
That would be best, although it means the downstream logits processor needs to be able to handle tensors of different batch sizes, and not always in the same order. I'm going to look into how constraining that is.
No description provided.