
Conversation

RobinPicard
Contributor

No description provided.

@RobinPicard RobinPicard requested a review from rlouf August 11, 2025 16:39
Comment on lines 63 to 65
def get_json_schema_logits_processor(
backend_name: str | None,
model: SteerableModel,
Member

I'd prefer a separate function that calls get_json_schema_logits_processor instead of the current branching logic

Member

Also, is there a clean way to get the JsonSchema, Regex and CFG objects up to this point? That would allow us to have a single function get_thinking_logits_processor that dispatches depending on the type.

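A rough, self-contained sketch of that single dispatching function. The `JsonSchema`, `Regex`, and `CFG` classes and the per-type getters below are stand-ins for the real outlines types and functions, which may differ:

```python
from dataclasses import dataclass

# Stand-in output types; the real outlines classes may differ.
@dataclass
class JsonSchema:
    schema: str

@dataclass
class Regex:
    pattern: str

@dataclass
class CFG:
    definition: str

# Stubbed per-type getters standing in for the real backend functions.
def get_json_schema_logits_processor(backend_name, model, output_type):
    return ("json_schema", output_type.schema)

def get_regex_logits_processor(backend_name, model, output_type):
    return ("regex", output_type.pattern)

def get_cfg_logits_processor(backend_name, model, output_type):
    return ("cfg", output_type.definition)

def get_thinking_logits_processor(output_type, backend_name=None, model=None):
    # Single entry point dispatching on the type of the output object.
    dispatch = {
        JsonSchema: get_json_schema_logits_processor,
        Regex: get_regex_logits_processor,
        CFG: get_cfg_logits_processor,
    }
    getter = dispatch.get(type(output_type))
    if getter is None:
        raise TypeError(f"Unsupported output type: {type(output_type)!r}")
    return getter(backend_name, model, output_type)
```

With this shape, the thinking-mode wrapping can live in one place instead of being repeated in each getter.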
if end_thinking_tag is not None:
end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
return backend_logits_processor
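A minimal sketch of the separate-function suggestion, assuming a hypothetical wrapper name `get_json_schema_thinking_logits_processor`; the stubs stand in for the real implementations touched by this PR:

```python
class ThinkingLogitsProcessor:
    def __init__(self, end_thinking_token_id, thinking_max_tokens, inner):
        self.end_thinking_token_id = end_thinking_token_id
        self.thinking_max_tokens = thinking_max_tokens
        self.inner = inner  # the wrapped backend logits processor


def get_json_schema_logits_processor(backend_name, model, schema):
    # Stub: the real function builds a backend-specific processor.
    return ("json_schema", backend_name, schema)


def _get_end_thinking_token_id(end_thinking_tag, model):
    # Stub: the real helper looks the tag up in the model's vocabulary.
    return 0


def get_json_schema_thinking_logits_processor(
    backend_name, model, schema, end_thinking_tag, thinking_max_tokens
):
    # Thin wrapper: the base getter stays free of thinking-mode branching.
    processor = get_json_schema_logits_processor(backend_name, model, schema)
    if end_thinking_tag is None:
        return processor
    token_id = _get_end_thinking_token_id(end_thinking_tag, model)
    return ThinkingLogitsProcessor(token_id, thinking_max_tokens, processor)
```

The regex and CFG getters would get identical wrappers, which is what the "Idem" comments below refer to.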


def get_regex_logits_processor(
Member

Idem

if end_thinking_tag is not None:
end_thinking_token_id = _get_end_thinking_token_id(end_thinking_tag, model)
return ThinkingLogitsProcessor(end_thinking_token_id, thinking_max_tokens, backend_logits_processor)
return backend_logits_processor


def get_cfg_logits_processor(
Member

Idem

@@ -90,7 +90,7 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
]

def _bias_logits_mlx( # pragma: no cover
-        self, batch_size: int, logits: TensorType
+        self, batch_size: int, logits: TensorType, skip: list[bool]
Member

If we go with this design, I would consider a different name like passthrough

if all(self._is_thinking):
return logits

return self.logits_processor.process_logits(input_ids, logits)
Member

I'm wondering if we could transform all this into operations on arrays so we don't have to call process_logits for the sequences where the end-of-think token has not been generated. It would go as:

  1. Extract sequences where end-of-think is present
  2. Run process-logits on them
  3. Re-build the logits array with all sequences.

What do you think?
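The three steps above could be sketched with boolean-mask indexing. This is a hypothetical helper using NumPy for illustration; the real processors operate on framework tensors, and `process_logits` here is any callable with the same `(input_ids, logits)` signature:

```python
import numpy as np

def process_logits_selectively(process_logits, input_ids, logits, is_thinking):
    # is_thinking[i] is True while sequence i has not yet generated the
    # end-of-think token; those rows pass through unconstrained.
    done = ~np.asarray(is_thinking)
    if not done.any():
        return logits  # every sequence still thinking: pure passthrough
    out = logits.copy()
    # 1. extract finished rows, 2. constrain them, 3. scatter back
    out[done] = process_logits(input_ids[done], logits[done])
    return out
```

Note that the wrapped processor then sees sub-batches of varying size, which is exactly the constraint raised in the reply below.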

Contributor Author

That would be best, although it means the downstream logits processor needs to be able to handle tensors of varying batch sizes, and not always in the same order. I'm going to look into how constraining that is.
