21 changes: 21 additions & 0 deletions example_serialization.py
@@ -0,0 +1,21 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from outlines.types.dsl import Regex
from outlines.models.transformers import from_transformers
import outlines


model = from_transformers(
AutoModelForCausalLM.from_pretrained("erwanf/gpt2-mini"),
AutoTokenizer.from_pretrained("erwanf/gpt2-mini"),
)

generator = outlines.Generator(model, Regex(r"\d{3}"), backend="xgrammar")
# picklable both before and after a generation
generator.to_disk("logits_processor.pkl")
result = generator("Generate a 3-digit number:")
generator.to_disk("logits_processor.pkl")

# Recreate a generator from the pickled processor
new_generator = outlines.Generator(model, processor="logits_processor.pkl")
result = new_generator("Generate a 3-digit number:")
print(result)
6 changes: 6 additions & 0 deletions outlines/backends/llguidance.py
@@ -176,6 +176,12 @@ def process_logits(

return self._bias_logits(input_ids, logits)

def __getstate__(self):
"""Create a picklable representation of the processor."""
raise NotImplementedError(
"Serializing the logits processor is not supported for LLGuidance"
)
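
For context, a minimal standalone sketch of what a raising __getstate__ does to pickling (illustrative class only, not the actual Outlines processor):

import pickle

class Unpicklable:
    def __getstate__(self):
        raise NotImplementedError("Serializing is not supported")

try:
    pickle.dumps(Unpicklable())
except NotImplementedError as exc:
    print(exc)  # Serializing is not supported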


class LLGuidanceBackend(BaseBackend):
"""Backend for LLGuidance."""
20 changes: 18 additions & 2 deletions outlines/backends/outlines_core.py
@@ -35,12 +35,16 @@ def __init__(
"""
self.index = index
self.tensor_library_name = tensor_library_name
self.is_first_token = True
self.reset()
super().__init__(tensor_library_name)

def reset(self) -> None:
"""Reset the logits processor."""
"""Reset the logits processor to prepare for a new generation."""
self.is_first_token = True
self._guides = None
self._bitmasks = None
self.bias_logits = None
self.allocate_token_bitmask = None

def _setup(self, batch_size: int, vocab_size: int) -> None:
"""Set the guides, bitmasks and some functions used in the
@@ -193,6 +197,18 @@ def process_logits(

return self.bias_logits(batch_size, logits)

def __getstate__(self):
"""Create a picklable representation of the processor."""
self.reset()
state = self.__dict__.copy()
del state["tensor_adapter"]
return state

def __setstate__(self, state):
"""Restore the processor from a pickled state."""
self.__dict__.update(state)
super().__init__(self.tensor_library_name)
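
The drop-and-rebuild pattern used by this pair of methods, as a standalone sketch (simplified stand-in names, not the real Outlines attributes):

import pickle

class Processor:
    def __init__(self, tensor_library_name):
        self.tensor_library_name = tensor_library_name
        self.tensor_adapter = object()  # stand-in for the non-picklable helper

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["tensor_adapter"]  # dropped before pickling
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.tensor_adapter = object()  # rebuilt from the restored state

restored = pickle.loads(pickle.dumps(Processor("torch")))
assert restored.tensor_library_name == "torch"
assert hasattr(restored, "tensor_adapter")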


class OutlinesCoreBackend(BaseBackend):
"""Backend for Outlines Core."""
39 changes: 32 additions & 7 deletions outlines/backends/xgrammar.py
@@ -23,20 +23,22 @@ def __init__(self, compiled_grammar: str, tensor_library_name: str,):
The name of the tensor library used by the model

"""
import xgrammar as xgr

self.xgr = xgr
self.is_first_token = True
self.compiled_grammar = compiled_grammar
self.tensor_library_name = tensor_library_name
self.reset()
super().__init__(tensor_library_name)

def reset(self):
"""Ensure self._setup is called again for the next generation."""
self.is_first_token = True
self._matchers = None
self._bitmask = None
self._bias_logits = None

def _setup(self, batch_size: int, vocab_size: int) -> None:
"""Setup the logits processor for a new generation."""
import xgrammar as xgr

if self.tensor_library_name == "torch":
self._bias_logits = self._bias_logits_torch
elif self.tensor_library_name == "mlx": # pragma: no cover
@@ -47,15 +49,17 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
)

self._matchers = [
self.xgr.GrammarMatcher(self.compiled_grammar)
xgr.GrammarMatcher(self.compiled_grammar)
for _ in range(batch_size)
]
self._bitmask = self.xgr.allocate_token_bitmask(batch_size, vocab_size)
self._bitmask = xgr.allocate_token_bitmask(batch_size, vocab_size)

def _bias_logits_torch(
self, input_ids: TensorType, logits: TensorType
) -> TensorType:
"""Bias the logits for Torch tensors."""
import xgrammar as xgr

for i in range(self.tensor_adapter.shape(input_ids)[0]):
if not self._matchers[i].is_terminated():
self._matchers[i].fill_next_token_bitmask(self._bitmask, i)
@@ -64,7 +68,7 @@ def _bias_logits_torch(
self._bitmask,
self.tensor_adapter.get_device(logits)
)
self.xgr.apply_token_bitmask_inplace(logits, self._bitmask)
xgr.apply_token_bitmask_inplace(logits, self._bitmask)
self._bitmask = self.tensor_adapter.to_device(
self._bitmask,
"cpu"
@@ -109,6 +113,27 @@ def process_logits(

return self._bias_logits(input_ids, logits)

def __getstate__(self):
"""Create a picklable representation of the processor."""
self.reset()
state = self.__dict__.copy()
del state["tensor_adapter"]
compiled_grammar = state.pop("compiled_grammar")
state["serialized_compiled_grammar"] = compiled_grammar.serialize_json()
state["serialized_tokenizer_info"] = compiled_grammar.tokenizer_info.serialize_json()
return state

def __setstate__(self, state):
"""Restore the processor from a pickled state."""
import xgrammar as xgr

serialized_tokenizer_info = state.pop("serialized_tokenizer_info", None)
tokenizer_info = xgr.TokenizerInfo.deserialize_json(serialized_tokenizer_info)
serialized_compiled_grammar = state.pop("serialized_compiled_grammar")
compiled_grammar = xgr.CompiledGrammar.deserialize_json(serialized_compiled_grammar, tokenizer_info)
self.__dict__.update({**state, "compiled_grammar": compiled_grammar})
super().__init__(self.tensor_library_name)
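
The grammar round-trip these two methods perform, shown in isolation (a sketch that assumes an existing compiled_grammar produced by xgrammar's grammar compiler; the serialize_json/deserialize_json calls are the same ones used in the diff above):

import xgrammar as xgr

# Dump the grammar and its tokenizer info to JSON strings...
grammar_json = compiled_grammar.serialize_json()
tokenizer_json = compiled_grammar.tokenizer_info.serialize_json()

# ...then rebuild the compiled grammar from both on load.
tokenizer_info = xgr.TokenizerInfo.deserialize_json(tokenizer_json)
restored = xgr.CompiledGrammar.deserialize_json(grammar_json, tokenizer_info)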


class XGrammarBackend(BaseBackend):
"""Backend for XGrammar."""
21 changes: 16 additions & 5 deletions outlines/generator.py
@@ -1,5 +1,6 @@
"""Encapsulate a model and an output type into a reusable object."""

import pickle
from typing import (
Any,
AsyncIterator,
@@ -276,6 +277,11 @@ def from_processor(

return instance

def to_disk(self, path: str):
"""Save the logits processor of the generator in a local file."""
with open(path, "wb") as f:
pickle.dump(self.logits_processor, f)
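
Loading the file back by hand is equivalent to passing its path to Generator (a sketch reusing model from example_serialization.py above; it mirrors the isinstance(processor, str) branch added below):

import pickle

import outlines

with open("logits_processor.pkl", "rb") as f:
    processor = pickle.load(f)

# Same result as outlines.Generator(model, processor="logits_processor.pkl")
generator = outlines.Generator(model, processor=processor)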

def __call__(self, prompt: Any, **inference_kwargs) -> Any:
"""Generate a response from the model.

@@ -348,13 +354,14 @@ def Generator(
output_type: Optional[Any] = None,
backend: Optional[str] = None,
*,
processor: Optional[LogitsProcessorType] = None,
processor: Optional[LogitsProcessorType | str] = None,
) -> Union[SteerableGenerator, BlackBoxGenerator, AsyncBlackBoxGenerator]:
"""Create a generator for the given model and output parameters.

The 2 parameters output_type and processor are mutually exclusive. The
parameters processor is only supported for SteerableModel instances
(typically local models) and is intended to be only used by advanced users.
The parameters output_type and backend on one side, and processor on the
other, are mutually exclusive. The processor parameter is only supported
for SteerableModel instances (typically local models) and is intended
for advanced users only.

Parameters
----------
@@ -368,7 +375,8 @@ def Generator(
used for steerable models if there is an output type and `processor` is
not provided.
processor
An instance of a logits processor.
An instance of a logits processor, or the path to a file containing a
pickled logits processor (the output of `SteerableGenerator.to_disk`).

Returns
-------
@@ -387,6 +395,9 @@

if isinstance(model, SteerableModel): # type: ignore
if processor is not None:
if isinstance(processor, str):
with open(processor, "rb") as f:
processor = pickle.load(f)
return SteerableGenerator.from_processor(model, processor) # type: ignore
else:
return SteerableGenerator(model, output_type, backend) # type: ignore