From 83129cca390befd7babc3c951b8a203e51588754 Mon Sep 17 00:00:00 2001 From: RobinPicard Date: Wed, 6 Aug 2025 13:55:42 +0200 Subject: [PATCH] Add support for reasoning models --- outlines/backends/__init__.py | 70 +++++++- outlines/backends/llguidance.py | 165 ++++++++++++++---- outlines/backends/outlines_core.py | 158 ++++++++++++++--- outlines/backends/xgrammar.py | 151 +++++++++++++--- outlines/generator.py | 33 +++- outlines/processors/base_logits_processor.py | 8 +- outlines/processors/tensor_adapters/base.py | 26 +++ outlines/processors/tensor_adapters/jax.py | 10 ++ outlines/processors/tensor_adapters/mlx.py | 11 ++ outlines/processors/tensor_adapters/numpy.py | 10 ++ .../processors/tensor_adapters/tensorflow.py | 14 ++ outlines/processors/tensor_adapters/torch.py | 11 ++ 12 files changed, 572 insertions(+), 95 deletions(-) diff --git a/outlines/backends/__init__.py b/outlines/backends/__init__.py index 680e54959..7b0ea879d 100644 --- a/outlines/backends/__init__.py +++ b/outlines/backends/__init__.py @@ -15,7 +15,13 @@ REGEX_DEFAULT_BACKEND = "outlines_core" -def _get_backend(backend_name: str, model: SteerableModel) -> BaseBackend: +def _get_backend( + backend_name: str, + model: SteerableModel, + *, + end_thinking_tag: str | None, + thinking_max_tokens: int | None, +) -> BaseBackend: """Create a Backend instance. Parameters @@ -24,6 +30,13 @@ def _get_backend(backend_name: str, model: SteerableModel) -> BaseBackend: The name of the backend to get. model: Model The Outlines model of the user. + end_thinking_tag: str | None + The tag the model uses to indicate the end of the thinking process. + Only used when running a thinking model. + thinking_max_tokens: int | None + The maximum number of tokens the model can think about. Only used when + running a thinking model. The end_thinking_tag argument must also be + provided to use this parameter. Returns ------- @@ -32,11 +45,23 @@ def _get_backend(backend_name: str, model: SteerableModel) -> BaseBackend: """ if backend_name == "outlines_core": - return OutlinesCoreBackend(model) + return OutlinesCoreBackend( + model, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, + ) elif backend_name == "xgrammar": - return XGrammarBackend(model) + return XGrammarBackend( + model, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, + ) elif backend_name == "llguidance": - return LLGuidanceBackend(model) + return LLGuidanceBackend( + model, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, + ) else: raise ValueError(f"Backend {backend_name} not supported") @@ -45,6 +70,9 @@ def get_json_schema_logits_processor( backend_name: str | None, model: SteerableModel, json_schema: str, + *, + end_thinking_tag: str | None, + thinking_max_tokens: int | None, ) -> LogitsProcessorType: """Create a logits processor from a JSON schema. @@ -56,6 +84,13 @@ def get_json_schema_logits_processor( The Outlines model of the user. json_schema: str The JSON schema to create a logits processor from. + end_thinking_tag: str | None + The tag the model uses to indicate the end of the thinking process. + Only used when running a thinking model. + thinking_max_tokens: int | None + The maximum number of tokens the model can think about. Only used when + running a thinking model. The end_thinking_tag argument must also be + provided to use this parameter. 
Returns ------- @@ -66,6 +101,8 @@ def get_json_schema_logits_processor( backend = _get_backend( backend_name or JSON_SCHEMA_DEFAULT_BACKEND, model, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, ) return backend.get_json_schema_logits_processor(json_schema) @@ -74,6 +111,9 @@ def get_regex_logits_processor( backend_name: str | None, model: SteerableModel, regex: str, + *, + end_thinking_tag: str | None, + thinking_max_tokens: int | None, ) -> LogitsProcessorType: """Create a logits processor from a regex. @@ -85,6 +125,13 @@ def get_regex_logits_processor( The Outlines model of the user. regex: str The regex to create a logits processor from. + end_thinking_tag: str | None + The tag the model uses to indicate the end of the thinking process. + Only used when running a thinking model. + thinking_max_tokens: int | None + The maximum number of tokens the model can think about. Only used when + running a thinking model. The end_thinking_tag argument must also be + provided to use this parameter. Returns ------- @@ -95,6 +142,8 @@ def get_regex_logits_processor( backend = _get_backend( backend_name or REGEX_DEFAULT_BACKEND, model, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, ) return backend.get_regex_logits_processor(regex) @@ -103,6 +152,9 @@ def get_cfg_logits_processor( backend_name: str | None, model: SteerableModel, grammar: str, + *, + end_thinking_tag: str | None, + thinking_max_tokens: int | None, ) -> LogitsProcessorType: """Create a logits processor from a context-free grammar. @@ -114,7 +166,13 @@ def get_cfg_logits_processor( The Outlines model of the user. grammar: str The context-free grammar to create a logits processor from. - + end_thinking_tag: str | None + The tag the model uses to indicate the end of the thinking process. + Only used when running a thinking model. + thinking_max_tokens: int | None + The maximum number of tokens the model can think about. Only used when + running a thinking model. The end_thinking_tag argument must also be + provided to use this parameter. 
Returns ------- LogitsProcessorType @@ -124,5 +182,7 @@ def get_cfg_logits_processor( backend = _get_backend( backend_name or CFG_DEFAULT_BACKEND, model, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, ) return backend.get_cfg_logits_processor(grammar) diff --git a/outlines/backends/llguidance.py b/outlines/backends/llguidance.py index 2f9ea59f5..1b37ecddd 100644 --- a/outlines/backends/llguidance.py +++ b/outlines/backends/llguidance.py @@ -25,6 +25,9 @@ def __init__( grammar: str, llg_tokenizer, tensor_library_name: str, + *, + end_thinking_token_id: int | None, + thinking_max_tokens: int | None, ) -> None: """ Parameters @@ -35,6 +38,10 @@ def __init__( The LLGuidance tokenizer tensor_library_name: str The name of the tensor library used by the model + end_thinking_token_id: int | None + The id of the end thinking token + thinking_max_tokens: int | None + The maximum number of tokens the model can think about """ if tensor_library_name not in SUPPORTED_TENSOR_LIBRARIES: @@ -44,6 +51,8 @@ def __init__( self.grammar = grammar self.llg_tokenizer = llg_tokenizer self.tensor_library_name = tensor_library_name + self.end_thinking_token_id = end_thinking_token_id + self.thinking_max_tokens = thinking_max_tokens or float("inf") super().__init__(tensor_library_name) def reset(self): @@ -65,31 +74,39 @@ def _setup(self, batch_size: int) -> None: """ from llguidance import LLMatcher - self.ll_matchers = [ + self._ll_matchers = [ LLMatcher(self.llg_tokenizer, self.grammar) for _ in range(batch_size) ] + self._is_thinking = [self.end_thinking_token_id is not None] * batch_size + self._generate_end_thinking_token = [False] * batch_size + self._num_tokens_generated = 0 # we must adapt the bitmask creation and the bias function to the # tensor library used by the model if self.tensor_library_name == "torch": import llguidance.torch - self.bitmask = llguidance.torch.allocate_token_bitmask(batch_size, self.llg_tokenizer.vocab_size) + self.allocate_token_bitmask = llguidance.torch.allocate_token_bitmask self._bias_logits = self._bias_logits_torch elif self.tensor_library_name == "numpy": import llguidance.numpy - self.bitmask = llguidance.numpy.allocate_token_bitmask(batch_size, self.llg_tokenizer.vocab_size) + self.allocate_token_bitmask = llguidance.numpy.allocate_token_bitmask self._bias_logits = self._bias_logits_numpy elif self.tensor_library_name == "mlx": # pragma: no cover import llguidance.numpy - self.bitmask = llguidance.numpy.allocate_token_bitmask(batch_size, self.llg_tokenizer.vocab_size) + self.allocate_token_bitmask = llguidance.numpy.allocate_token_bitmask self._bias_logits = self._bias_logits_mlx else: # pragma: no cover raise ValueError(f"Unsupported tensor library: {self.tensor_library_name}") + self._bitmasks = [ + self.allocate_token_bitmask(1, self.llg_tokenizer.vocab_size) + for _ in range(batch_size) + ] + def _bias_logits_mlx( # pragma: no cover self, input_ids: TensorType, logits: TensorType ) -> TensorType: @@ -99,10 +116,19 @@ def _bias_logits_mlx( # pragma: no cover biased_logits_array = [] for i in range(self.tensor_adapter.shape(input_ids)[0]): - llguidance.numpy.fill_next_token_bitmask(self.ll_matchers[i], self.bitmask, i) - biased_logits = llguidance.mlx.apply_token_bitmask( - logits[i], self.bitmask[i] # type: ignore - ) + if not self._is_thinking[i] and not self._generate_end_thinking_token[i]: + llguidance.numpy.fill_next_token_bitmask(self._ll_matchers[i], self._bitmasks[i], 0) + biased_logits = llguidance.mlx.apply_token_bitmask( + 
logits[i], self._bitmasks[i] # type: ignore + ) + elif self._generate_end_thinking_token[i]: + self._generate_end_thinking_token[i] = False + biased_logits = llguidance.mlx.apply_token_bitmask( + logits[i], self._bitmasks[i] # type: ignore + ) + else: + biased_logits = logits[i] + biased_logits_array.append(biased_logits) return self.tensor_adapter.concatenate(biased_logits_array) @@ -114,10 +140,16 @@ def _bias_logits_torch( import llguidance.torch for i in range(self.tensor_adapter.shape(input_ids)[0]): - llguidance.torch.fill_next_token_bitmask(self.ll_matchers[i], self.bitmask, i) - llguidance.torch.apply_token_bitmask_inplace( - logits[i], self.bitmask[i] # type: ignore - ) + if not self._is_thinking[i] and not self._generate_end_thinking_token[i]: + llguidance.torch.fill_next_token_bitmask(self._ll_matchers[i], self._bitmasks[i], 0) + llguidance.torch.apply_token_bitmask_inplace( + logits[i], self._bitmasks[i] # type: ignore + ) + elif self._generate_end_thinking_token[i]: + self._generate_end_thinking_token[i] = False + llguidance.torch.apply_token_bitmask_inplace( + logits[i], self._bitmasks[i] # type: ignore + ) return logits @@ -128,10 +160,16 @@ def _bias_logits_numpy( import llguidance.numpy for i in range(self.tensor_adapter.shape(input_ids)[0]): - llguidance.numpy.fill_next_token_bitmask(self.ll_matchers[i], self.bitmask, i) - llguidance.numpy.apply_token_bitmask_inplace( - logits[i], self.bitmask[i] # type: ignore - ) + if not self._is_thinking[i] and not self._generate_end_thinking_token[i]: + llguidance.numpy.fill_next_token_bitmask(self._ll_matchers[i], self._bitmasks[i], 0) + llguidance.numpy.apply_token_bitmask_inplace( + logits[i], self._bitmasks[i] # type: ignore + ) + elif self._generate_end_thinking_token[i]: + self._generate_end_thinking_token[i] = False + llguidance.numpy.apply_token_bitmask_inplace( + logits[i], self._bitmasks[i] # type: ignore + ) return logits @@ -153,20 +191,43 @@ def process_logits( The biased logits. 
""" + batch_size = self.tensor_adapter.shape(input_ids)[0] + vocab_size = self.tensor_adapter.shape(logits)[1] + if self.is_first_token: - self._setup(self.tensor_adapter.shape(input_ids)[0]) + self._setup(batch_size) self.is_first_token = False - - # we do not make the matchers consume the last token during the first - # generation step because no tokens have been generated yet else: - for i in range(self.tensor_adapter.shape(input_ids)[0]): - sequence = input_ids[i] # type: ignore - last_token = sequence[-1].item() - self.ll_matchers[i].consume_token(last_token) - error = self.ll_matchers[i].get_error() - if error: - warnings.warn(f"Error in LLMatcher: {error}") + self._num_tokens_generated += 1 + for i in range(batch_size): + latest_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1]) # type: ignore + if not self._is_thinking[i]: + self._ll_matchers[i].consume_token(latest_token_id) + error = self._ll_matchers[i].get_error() + if error: + warnings.warn(f"Error in LLMatcher: {error}") + else: + # If the end of thinking token was generated at the + # previous step, we set thinking to False to start + # biasing the logits according to the guide + if latest_token_id == self.end_thinking_token_id: + self._is_thinking[i] = False + # If the max number of tokens has been generated, we + # modify the bitmask to only allow the end of thinking + # token to be generated and set generate_end_thinking_token + # to True to skip filling the bitmask (as we did it + # manually ourselves) + elif ( + self._num_tokens_generated >= self.thinking_max_tokens + ): + updated_bitmask = self.tensor_adapter.create_end_thinking_bitmask( + vocab_size, + self.end_thinking_token_id, + ) + self._bitmasks[i] = self.tensor_adapter.unsqueeze( + updated_bitmask # type: ignore + ) + self._generate_end_thinking_token[i] = True return self._bias_logits(input_ids, logits) @@ -174,12 +235,25 @@ def process_logits( class LLGuidanceBackend(BaseBackend): """Backend for LLGuidance.""" - def __init__(self, model: SteerableModel): + def __init__( + self, + model: SteerableModel, + *, + end_thinking_tag: str | None, + thinking_max_tokens: int | None, + ): """ Parameters ---------- model The Outlines model of the user. + end_thinking_tag + The tag the model uses to indicate the end of the thinking process. + Only used when running a thinking model. + thinking_max_tokens + The maximum number of tokens the model can think about. Only used + when running a thinking model. The end_thinking_tag argument must + also be provided to use this parameter. """ import llguidance as llg @@ -187,6 +261,23 @@ def __init__(self, model: SteerableModel): self.llg = llg self.tensor_library_name = model.tensor_library_name self.llg_tokenizer = self._create_llg_tokenizer(model) + encoded_end_thinking_tag = ( + self.llg_tokenizer.tokenize_str(end_thinking_tag) + if end_thinking_tag + else None + ) + if ( + encoded_end_thinking_tag is not None + and len(encoded_end_thinking_tag) != 1 + ): + raise ValueError( + "The end_thinking_tag must correspond to a single token in" + + "the tokenizer vocabulary." + ) + self.end_thinking_token_id = ( + encoded_end_thinking_tag[0] if encoded_end_thinking_tag else None + ) + self.thinking_max_tokens = thinking_max_tokens def _create_llg_tokenizer(self, model: SteerableModel) -> "LLGTokenizer": """Create an llg tokenizer from the Outlines model's tokenizer. 
@@ -246,7 +337,11 @@ def get_json_schema_logits_processor( """ grammar_spec = self.llg.grammar_from("json_schema", json_schema) return LLGuidanceLogitsProcessor( - grammar_spec, self.llg_tokenizer, self.tensor_library_name + grammar_spec, + self.llg_tokenizer, + self.tensor_library_name, + end_thinking_token_id=self.end_thinking_token_id, + thinking_max_tokens=self.thinking_max_tokens ) def get_regex_logits_processor( @@ -267,7 +362,11 @@ def get_regex_logits_processor( """ grammar_spec = self.llg.grammar_from("regex", regex) return LLGuidanceLogitsProcessor( - grammar_spec, self.llg_tokenizer, self.tensor_library_name + grammar_spec, + self.llg_tokenizer, + self.tensor_library_name, + end_thinking_token_id=self.end_thinking_token_id, + thinking_max_tokens=self.thinking_max_tokens ) def get_cfg_logits_processor( @@ -292,5 +391,9 @@ def get_cfg_logits_processor( except ValueError: grammar_spec = self.llg.grammar_from("lark", grammar) return LLGuidanceLogitsProcessor( - grammar_spec, self.llg_tokenizer, self.tensor_library_name + grammar_spec, + self.llg_tokenizer, + self.tensor_library_name, + end_thinking_token_id=self.end_thinking_token_id, + thinking_max_tokens=self.thinking_max_tokens ) diff --git a/outlines/backends/outlines_core.py b/outlines/backends/outlines_core.py index 4da664897..2997829e9 100644 --- a/outlines/backends/outlines_core.py +++ b/outlines/backends/outlines_core.py @@ -21,7 +21,12 @@ class OutlinesCoreLogitsProcessor(OutlinesLogitsProcessor): """Logits processor for Outlines Core.""" def __init__( - self, index: Index, tensor_library_name: str + self, + index: Index, + tensor_library_name: str, + *, + end_thinking_token_id: int | None, + thinking_max_tokens: int | None, ): """ Parameters @@ -31,11 +36,20 @@ def __init__( Core `Guide` instances that will be used to bias the logits tensor_library_name: str The tensor library name to use for the logits processor. + end_thinking_token_id: int | None + The token ID of the end of the thinking process. Only used when + running a thinking model. + thinking_max_tokens: int | None + The maximum number of tokens the model can think about. Only used + when running a thinking model. The end_thinking_token_id argument + must also be provided to use this parameter. """ self.index = index self.tensor_library_name = tensor_library_name self.is_first_token = True + self.end_thinking_token_id = end_thinking_token_id + self.thinking_max_tokens = thinking_max_tokens or float("inf") super().__init__(tensor_library_name) def reset(self) -> None: @@ -50,17 +64,21 @@ def _setup(self, batch_size: int, vocab_size: int) -> None: at initialization because we need to know the batch size. 
""" + self._is_thinking = [self.end_thinking_token_id is not None] * batch_size + self._generate_end_thinking_token = [False] * batch_size + self._num_tokens_generated = 0 + if self.tensor_library_name == "torch": from outlines_core.kernels.torch import allocate_token_bitmask self.allocate_token_bitmask = allocate_token_bitmask - self.bias_logits = self._bias_logits_torch + self._bias_logits = self._bias_logits_torch elif self.tensor_library_name == "numpy": from outlines_core.kernels.numpy import allocate_token_bitmask self.allocate_token_bitmask = allocate_token_bitmask - self.bias_logits = self._bias_logits_numpy + self._bias_logits = self._bias_logits_numpy elif self.tensor_library_name == "mlx": from outlines_core.kernels.mlx import ( @@ -68,7 +86,7 @@ def _setup(self, batch_size: int, vocab_size: int) -> None: ) self.allocate_token_bitmask = allocate_token_bitmask - self.bias_logits = self._bias_logits_mlx + self._bias_logits = self._bias_logits_mlx else: raise ValueError( @@ -92,10 +110,19 @@ def _bias_logits_mlx( # pragma: no cover biased_logits_array = [] for i in range(batch_size): - fill_next_token_bitmask(self._guides[i], self._bitmasks[i]) - biased_logits = apply_token_bitmask( - self.tensor_adapter.unsqueeze(logits[i]), self._bitmasks[i] # type: ignore - ) + if not self._is_thinking[i] and not self._generate_end_thinking_token[i]: + fill_next_token_bitmask(self._guides[i], self._bitmasks[i]) + biased_logits = apply_token_bitmask( + self.tensor_adapter.unsqueeze(logits[i]), self._bitmasks[i] # type: ignore + ) + elif self._generate_end_thinking_token[i]: + self._generate_end_thinking_token[i] = False + biased_logits = apply_token_bitmask( + self.tensor_adapter.unsqueeze(logits[i]), self._bitmasks[i] # type: ignore + ) + else: + biased_logits = self.tensor_adapter.unsqueeze(logits[i]) + biased_logits_array.append(biased_logits) return self.tensor_adapter.concatenate(biased_logits_array) @@ -110,11 +137,18 @@ def _bias_logits_torch( ) for i in range(batch_size): - fill_next_token_bitmask(self._guides[i], self._bitmasks[i]) - apply_token_bitmask_inplace( - self.tensor_adapter.unsqueeze(logits[i]), # type: ignore - self._bitmasks[i] - ) + if not self._is_thinking[i] and not self._generate_end_thinking_token[i]: + fill_next_token_bitmask(self._guides[i], self._bitmasks[i]) + apply_token_bitmask_inplace( + self.tensor_adapter.unsqueeze(logits[i]), # type: ignore + self._bitmasks[i] + ) + elif self._generate_end_thinking_token[i]: + self._generate_end_thinking_token[i] = False + apply_token_bitmask_inplace( + self.tensor_adapter.unsqueeze(logits[i]), # type: ignore + self._bitmasks[i] + ) return logits @@ -128,11 +162,18 @@ def _bias_logits_numpy( ) for i in range(batch_size): - fill_next_token_bitmask(self._guides[i], self._bitmasks[i]) - apply_token_bitmask_inplace( - self.tensor_adapter.unsqueeze(logits[i]), # type: ignore - self._bitmasks[i] - ) + if not self._is_thinking[i] and not self._generate_end_thinking_token[i]: + fill_next_token_bitmask(self._guides[i], self._bitmasks[i]) + apply_token_bitmask_inplace( + self.tensor_adapter.unsqueeze(logits[i]), # type: ignore + self._bitmasks[i] + ) + elif self._generate_end_thinking_token[i]: + self._generate_end_thinking_token[i] = False + apply_token_bitmask_inplace( + self.tensor_adapter.unsqueeze(logits[i]), # type: ignore + self._bitmasks[i] + ) return logits @@ -161,26 +202,63 @@ def process_logits( self._setup(batch_size, vocab_size) self.is_first_token = False else: + self._num_tokens_generated += 1 for i in range(batch_size): 
- last_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1]) # type: ignore - if not self._guides[i].is_finished(): - self._guides[i].advance( - token_id=last_token_id, - return_tokens=False - ) - - return self.bias_logits(batch_size, logits) + latest_token_id = self.tensor_adapter.to_scalar(input_ids[i][-1]) # type: ignore + if not self._is_thinking[i]: + if not self._guides[i].is_finished(): + self._guides[i].advance( + token_id=latest_token_id, + return_tokens=False + ) + else: + # If the end of thinking token was generated at the + # previous step, we set thinking to False to start + # biasing the logits according to the guide + if latest_token_id == self.end_thinking_token_id: + self._is_thinking[i] = False + # If the max number of tokens has been generated, we + # modify the bitmask to only allow the end of thinking + # token to be generated and set generate_end_thinking_token + # to True to skip filling the bitmask (as we did it + # manually ourselves) + elif ( + self._num_tokens_generated >= self.thinking_max_tokens + ): + updated_bitmask = self.tensor_adapter.create_end_thinking_bitmask( + vocab_size, + self.end_thinking_token_id, + ) + self._bitmasks[i] = self.tensor_adapter.unsqueeze( + updated_bitmask # type: ignore + ) + self._generate_end_thinking_token[i] = True + + return self._bias_logits(batch_size, logits) class OutlinesCoreBackend(BaseBackend): """Backend for Outlines Core.""" - def __init__(self, model: SteerableModel): + def __init__( + self, + model: SteerableModel, + *, + end_thinking_tag: str | None, + thinking_max_tokens: int | None, + ): """ Parameters ---------- model The Outlines model of the user. + end_thinking_tag + The tag the model uses to indicate the end of the thinking process. + Only used when running a thinking model. + thinking_max_tokens + The maximum number of tokens the model can think about. Only used + when running a thinking model. The end_thinking_tag argument must + also be provided to use this parameter. """ if isinstance(model, Transformers): @@ -209,6 +287,25 @@ def __init__(self, model: SteerableModel): vocabulary, eos_token_id, eos_token, token_to_str ) self.tensor_library_name = model.tensor_library_name + encoded_end_thinking_tag = ( + tokenizer.encode(end_thinking_tag) + if end_thinking_tag + else None + ) + if ( + encoded_end_thinking_tag is not None + and len(encoded_end_thinking_tag) != 1 + ): + raise ValueError( + "The end_thinking_tag must correspond to a single token in" + + "the tokenizer vocabulary." 
+ ) + self.end_thinking_token_id = ( + encoded_end_thinking_tag[0] + if encoded_end_thinking_tag is not None + else None + ) + self.thinking_max_tokens = thinking_max_tokens def get_json_schema_logits_processor( self, json_schema: str @@ -244,7 +341,12 @@ def get_regex_logits_processor(self, regex: str): """ index = Index(regex, self.vocabulary) - return OutlinesCoreLogitsProcessor(index, self.tensor_library_name) + return OutlinesCoreLogitsProcessor( + index, + self.tensor_library_name, + end_thinking_token_id=self.end_thinking_token_id, + thinking_max_tokens=self.thinking_max_tokens, + ) def get_cfg_logits_processor(self, grammar): raise NotImplementedError( diff --git a/outlines/backends/xgrammar.py b/outlines/backends/xgrammar.py index c9c935fd7..ac437cafb 100644 --- a/outlines/backends/xgrammar.py +++ b/outlines/backends/xgrammar.py @@ -13,7 +13,14 @@ class XGrammarLogitsProcessor(OutlinesLogitsProcessor): """Logits processor for XGrammar.""" - def __init__(self, compiled_grammar: str, tensor_library_name: str,): + def __init__( + self, + compiled_grammar: str, + tensor_library_name: str, + *, + end_thinking_token_id: int | None, + thinking_max_tokens: int | None, + ): """ Parameters ---------- @@ -21,6 +28,10 @@ def __init__(self, compiled_grammar: str, tensor_library_name: str,): The compiled grammar to use to create the logits processor. tensor_library_name: str The name of the tensor library used by the model + end_thinking_token_id: int | None + The id of the end thinking token + thinking_max_tokens: int | None + The maximum number of tokens the model can think about """ import xgrammar as xgr @@ -29,6 +40,8 @@ def __init__(self, compiled_grammar: str, tensor_library_name: str,): self.is_first_token = True self.compiled_grammar = compiled_grammar self.tensor_library_name = tensor_library_name + self.end_thinking_token_id = end_thinking_token_id + self.thinking_max_tokens = thinking_max_tokens or float("inf") super().__init__(tensor_library_name) def reset(self): @@ -37,6 +50,18 @@ def reset(self): def _setup(self, batch_size: int, vocab_size: int) -> None: """Setup the logits processor for a new generation.""" + self._matchers = [ + self.xgr.GrammarMatcher(self.compiled_grammar) + for _ in range(batch_size) + ] + self._bitmasks = [ + self.xgr.allocate_token_bitmask(1, vocab_size) + for _ in range(batch_size) + ] + self._is_thinking = [self.end_thinking_token_id is not None] * batch_size + self._generate_end_thinking_token = [False] * batch_size + self._num_tokens_generated = 0 + if self.tensor_library_name == "torch": self._bias_logits = self._bias_logits_torch elif self.tensor_library_name == "mlx": @@ -46,21 +71,18 @@ def _setup(self, batch_size: int, vocab_size: int) -> None: f"Unsupported tensor library: {self.tensor_library_name}" ) - self._matchers = [ - self.xgr.GrammarMatcher(self.compiled_grammar) - for _ in range(batch_size) - ] - self._bitmask = self.xgr.allocate_token_bitmask(batch_size, vocab_size) - def _bias_logits_torch( self, input_ids: TensorType, logits: TensorType ) -> TensorType: """Bias the logits for Torch tensors.""" for i in range(self.tensor_adapter.shape(input_ids)[0]): - if not self._matchers[i].is_terminated(): - self._matchers[i].fill_next_token_bitmask(self._bitmask, i) - - self.xgr.apply_token_bitmask_inplace(logits, self._bitmask) + if not self._is_thinking[i] and not self._generate_end_thinking_token[i]: + if not self._matchers[i].is_terminated(): + self._matchers[i].fill_next_token_bitmask(self._bitmasks[i], 0) + 
self.xgr.apply_token_bitmask_inplace(logits[i], self._bitmasks[i]) + elif self._generate_end_thinking_token[i]: + self._generate_end_thinking_token[i] = False + self.xgr.apply_token_bitmask_inplace(logits[i], self._bitmasks[i]) return logits @@ -71,15 +93,26 @@ def _bias_logits_mlx( # pragma: no cover import mlx.core as mx from xgrammar.kernels.apply_token_bitmask_mlx import apply_token_bitmask_mlx - for i in range(self.tensor_adapter.shape(input_ids)[0]): - if not self._matchers[i].is_terminated(): - self._matchers[i].fill_next_token_bitmask(self._bitmask, i) - - biased_logits = apply_token_bitmask_mlx( - mx.array(self._bitmask.numpy()), logits, self.tensor_adapter.shape(logits)[1] - ) + biased_logits_array = [] - return biased_logits + for i in range(self.tensor_adapter.shape(input_ids)[0]): + if not self._is_thinking[i] and not self._generate_end_thinking_token[i]: + if not self._matchers[i].is_terminated(): + self._matchers[i].fill_next_token_bitmask(self._bitmasks[i], 0) + biased_logits = apply_token_bitmask_mlx( + mx.array(self._bitmasks[i].numpy()), logits[i], 1 + ) + elif self._generate_end_thinking_token[i]: + self._generate_end_thinking_token[i] = False + biased_logits = apply_token_bitmask_mlx( + mx.array(self._bitmasks[i].numpy()), logits[i], 1 + ) + else: + biased_logits = logits[i] + + biased_logits_array.append(biased_logits) + + return self.tensor_adapter.concatenate(biased_logits_array) def process_logits( self, input_ids: TensorType, logits: TensorType @@ -92,12 +125,36 @@ def process_logits( self._setup(batch_size, vocab_size) self.is_first_token = False else: + self._num_tokens_generated += 1 for i in range(batch_size): - if not self._matchers[i].is_terminated(): - last_token_id = self.tensor_adapter.to_scalar( - input_ids[i][-1] # type: ignore - ) - assert self._matchers[i].accept_token(last_token_id) + latest_token_id = self.tensor_adapter.to_scalar( + input_ids[i][-1] # type: ignore + ) + if not self._is_thinking[i]: + if not self._matchers[i].is_terminated(): + assert self._matchers[i].accept_token(latest_token_id) + else: + # If the end of thinking token was generated at the + # previous step, we set thinking to False to start + # biasing the logits according to the guide + if latest_token_id == self.end_thinking_token_id: + self._is_thinking[i] = False + # If the max number of tokens has been generated, we + # modify the bitmask to only allow the end of thinking + # token to be generated and set generate_end_thinking_token + # to True to skip filling the bitmask (as we did it + # manually ourselves) + elif ( + self._num_tokens_generated >= self.thinking_max_tokens + ): + updated_bitmask = self.tensor_adapter.create_end_thinking_bitmask( + vocab_size, + self.end_thinking_token_id, + ) + self._bitmasks[i] = self.tensor_adapter.unsqueeze( + updated_bitmask # type: ignore + ) + self._generate_end_thinking_token[i] = True return self._bias_logits(input_ids, logits) @@ -105,12 +162,25 @@ def process_logits( class XGrammarBackend(BaseBackend): """Backend for XGrammar.""" - def __init__(self, model: SteerableModel): + def __init__( + self, + model: SteerableModel, + *, + end_thinking_tag: str | None, + thinking_max_tokens: int | None, + ): """ Parameters ---------- model The Outlines model of the user. + end_thinking_tag + The tag the model uses to indicate the end of the thinking process. + Only used when running a thinking model. + thinking_max_tokens + The maximum number of tokens the model can think about. Only used + when running a thinking model. 
The end_thinking_tag argument must + also be provided to use this parameter. """ import xgrammar as xgr @@ -131,6 +201,25 @@ def __init__(self, model: SteerableModel): ) self.grammar_compiler = xgr.GrammarCompiler(tokenizer_info) self.tensor_library_name = model.tensor_library_name + encoded_end_thinking_tag = ( + tokenizer.encode(end_thinking_tag) + if end_thinking_tag + else None + ) + if ( + encoded_end_thinking_tag is not None + and len(encoded_end_thinking_tag) != 1 + ): + raise ValueError( + "The end_thinking_tag must correspond to a single token in" + + "the tokenizer vocabulary." + ) + self.end_thinking_token_id = ( + encoded_end_thinking_tag[0] + if encoded_end_thinking_tag is not None + else None + ) + self.thinking_max_tokens = thinking_max_tokens def get_json_schema_logits_processor( self, json_schema: str @@ -153,7 +242,9 @@ def get_json_schema_logits_processor( ) return XGrammarLogitsProcessor( compiled_grammar, - self.tensor_library_name + self.tensor_library_name, + end_thinking_token_id=self.end_thinking_token_id, + thinking_max_tokens=self.thinking_max_tokens ) def get_regex_logits_processor( @@ -175,7 +266,9 @@ def get_regex_logits_processor( compiled_grammar = self.grammar_compiler.compile_regex(regex) return XGrammarLogitsProcessor( compiled_grammar, - self.tensor_library_name + self.tensor_library_name, + end_thinking_token_id=self.end_thinking_token_id, + thinking_max_tokens=self.thinking_max_tokens ) def get_cfg_logits_processor( @@ -197,5 +290,7 @@ def get_cfg_logits_processor( compiled_grammar = self.grammar_compiler.compile_grammar(grammar) return XGrammarLogitsProcessor( compiled_grammar, - self.tensor_library_name + self.tensor_library_name, + end_thinking_token_id=self.end_thinking_token_id, + thinking_max_tokens=self.thinking_max_tokens ) diff --git a/outlines/generator.py b/outlines/generator.py index f2e669d8f..b8604bc26 100644 --- a/outlines/generator.py +++ b/outlines/generator.py @@ -218,6 +218,9 @@ def __init__( model: SteerableModel, output_type: Optional[Any], backend_name: Optional[str] = None, + *, + end_thinking_tag: Optional[str] = None, + thinking_max_tokens: Optional[int] = None, ): """ Parameters @@ -228,6 +231,13 @@ def __init__( The output type expressed as a Python type backend_name The name of the backend to use to create the logits processor. + end_thinking_tag + The tag the model uses to indicate the end of the thinking process. + Only used when running a thinking model. + thinking_max_tokens + The maximum number of tokens the model can think about. Only used + when running a thinking model. The end_thinking_tag argument must + also be provided to use this parameter. 
""" self.model = model @@ -241,12 +251,16 @@ def __init__( backend_name, model, cfg_string, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, ) elif isinstance(term, JsonSchema): self.logits_processor = get_json_schema_logits_processor( backend_name, model, term.schema, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, ) else: regex_string = to_regex(term) @@ -254,6 +268,8 @@ def __init__( backend_name, model, regex_string, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, ) @classmethod @@ -349,6 +365,8 @@ def Generator( backend: Optional[str] = None, *, processor: Optional[LogitsProcessorType] = None, + end_thinking_tag: Optional[str] = None, + thinking_max_tokens: Optional[int] = None, ) -> Union[SteerableGenerator, BlackBoxGenerator, AsyncBlackBoxGenerator]: """Create a generator for the given model and output parameters. @@ -369,6 +387,13 @@ def Generator( not provided. processor An instance of a logits processor. + end_thinking_tag + The tag the model uses to indicate the end of the thinking process. + Only used for steerable models running a thinking model. + thinking_max_tokens + The maximum number of tokens the model can think about. Only used for + steerable models running a thinking model. The end_thinking_tag + argument must also be provided to use this parameter. Returns ------- @@ -389,7 +414,13 @@ def Generator( if processor is not None: return SteerableGenerator.from_processor(model, processor) # type: ignore else: - return SteerableGenerator(model, output_type, backend) # type: ignore + return SteerableGenerator( + model, # type: ignore + output_type, + backend, + end_thinking_tag=end_thinking_tag, + thinking_max_tokens=thinking_max_tokens, + ) else: if processor is not None: raise NotImplementedError( diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index 9b4e4fe50..0639987e8 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -1,14 +1,18 @@ """Base class for logits processors.""" from abc import abstractmethod -from typing import TypeVar +from typing import TypeVar, Any, Protocol from outlines.processors.tensor_adapters import ( TensorAdapterImplementation, tensor_adapters, ) -TensorType = TypeVar('TensorType') + +class Indexable(Protocol): + def __getitem__(self, key: Any) -> Any: ... + +TensorType = TypeVar('TensorType', bound=Indexable) class OutlinesLogitsProcessor: diff --git a/outlines/processors/tensor_adapters/base.py b/outlines/processors/tensor_adapters/base.py index ffaac58df..a30b43661 100644 --- a/outlines/processors/tensor_adapters/base.py +++ b/outlines/processors/tensor_adapters/base.py @@ -255,3 +255,29 @@ def argsort_descending( """ ... + + @abstractmethod + def create_end_thinking_bitmask( + self, + size: int, + end_thinking_token_id: int + ) -> TensorType: # type: ignore + """Create a 1D bitmask tensor with 0 everywhere except at the position + of the end thinking token. + + Parameters + ---------- + size + The size of the tensor. + end_thinking_token_id + The id of the end thinking token. It corresponds to the position + of the tensor for which the bitmask will be 1. + + Returns + ------- + TensorType + A 1D bitmask tensor with a 1 at the position of the end thinking + token. + + """ + ... 
diff --git a/outlines/processors/tensor_adapters/jax.py b/outlines/processors/tensor_adapters/jax.py index 30cd2ecc9..7c7fa48cc 100644 --- a/outlines/processors/tensor_adapters/jax.py +++ b/outlines/processors/tensor_adapters/jax.py @@ -48,3 +48,13 @@ def apply_mask(self, tensor, mask, value): def argsort_descending(self, tensor): return self.jax.numpy.argsort(-tensor) + + def create_end_thinking_bitmask(self, size, end_thinking_token_id): + bitmask = self.jax.numpy.zeros( + (size + 31) // 32, + dtype=self.jax.numpy.int32 + ) + byte_index = end_thinking_token_id // 32 + bit_index = end_thinking_token_id % 32 + bitmask = bitmask.at[byte_index].set(1 << bit_index) + return bitmask diff --git a/outlines/processors/tensor_adapters/mlx.py b/outlines/processors/tensor_adapters/mlx.py index 5714808d4..4249aeb3a 100644 --- a/outlines/processors/tensor_adapters/mlx.py +++ b/outlines/processors/tensor_adapters/mlx.py @@ -58,3 +58,14 @@ def apply_mask(self, tensor, mask, value): def argsort_descending(self, tensor): return self.mlx.argsort(-tensor) + + def create_end_thinking_bitmask(self, size, end_thinking_token_id): + bitmask = self.mlx.zeros( + (size + 31) // 32, + dtype=self.mlx.int32 + ) + byte_index = end_thinking_token_id // 32 + bit_index = end_thinking_token_id % 32 + bitmask = self.mlx.array(bitmask) + bitmask[byte_index] = 1 << bit_index + return bitmask diff --git a/outlines/processors/tensor_adapters/numpy.py b/outlines/processors/tensor_adapters/numpy.py index 831220444..257d605a1 100644 --- a/outlines/processors/tensor_adapters/numpy.py +++ b/outlines/processors/tensor_adapters/numpy.py @@ -48,3 +48,13 @@ def apply_mask(self, tensor, mask, value): def argsort_descending(self, tensor): return self.numpy.argsort(-tensor) + + def create_end_thinking_bitmask(self, size, end_thinking_token_id): + bitmask = self.numpy.zeros( + (size + 31) // 32, + dtype=self.numpy.int32 + ) + byte_index = end_thinking_token_id // 32 + bit_index = end_thinking_token_id % 32 + bitmask[byte_index] = 1 << bit_index + return bitmask diff --git a/outlines/processors/tensor_adapters/tensorflow.py b/outlines/processors/tensor_adapters/tensorflow.py index 382f847b6..a27fbd1f6 100644 --- a/outlines/processors/tensor_adapters/tensorflow.py +++ b/outlines/processors/tensor_adapters/tensorflow.py @@ -48,3 +48,17 @@ def apply_mask(self, tensor, mask, value): def argsort_descending(self, tensor): return self.tf.argsort(tensor, direction='DESCENDING') + + def create_end_thinking_bitmask(self, size, end_thinking_token_id): + bitmask = self.tf.zeros( + (size + 31) // 32, + dtype=self.tf.int32 + ) + byte_index = end_thinking_token_id // 32 + bit_index = end_thinking_token_id % 32 + bitmask = self.tf.tensor_scatter_nd_update( + bitmask, + [[byte_index]], + [1 << bit_index] + ) + return bitmask diff --git a/outlines/processors/tensor_adapters/torch.py b/outlines/processors/tensor_adapters/torch.py index 4b624e960..82b0633bd 100644 --- a/outlines/processors/tensor_adapters/torch.py +++ b/outlines/processors/tensor_adapters/torch.py @@ -46,3 +46,14 @@ def apply_mask(self, tensor, mask, value): def argsort_descending(self, tensor): return self.torch.argsort(tensor, descending=True) + + def create_end_thinking_bitmask(self, size, end_thinking_token_id): + bitmask = self.torch.zeros( + (size + 31) // 32, + dtype=self.torch.int32, + pin_memory=self.torch.cuda.is_available() + ) + byte_index = end_thinking_token_id // 32 + bit_index = end_thinking_token_id % 32 + bitmask[byte_index] = 1 << bit_index + return bitmask
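For reference, a rough usage sketch of the two parameters this patch adds to Generator. The model name, the outlines.from_transformers constructor, and the "</think>" tag are assumptions: the tag is model specific and must encode to exactly one token, otherwise the backends above raise a ValueError. While the tag has not yet been generated the logits pass through unchanged; once thinking_max_tokens is reached the processor masks everything except the end-of-thinking token; after that token appears the structured-generation guide takes over.

from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

import outlines
from outlines.generator import Generator

class Answer(BaseModel):
    city: str
    population: int

# Assumed setup: a reasoning model whose tokenizer has "</think>" as a single token.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
model = outlines.from_transformers(
    AutoModelForCausalLM.from_pretrained(model_name),
    AutoTokenizer.from_pretrained(model_name),
)

generator = Generator(
    model,
    Answer,                        # structured output enforced once thinking ends
    end_thinking_tag="</think>",   # must map to a single token in the vocabulary
    thinking_max_tokens=512,       # after 512 thinking tokens, only the end tag is allowed
)
result = generator("Think step by step, then answer as JSON: which is the largest French city?")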