diff --git a/example_serialization.py b/example_serialization.py
new file mode 100644
index 000000000..7c20fd327
--- /dev/null
+++ b/example_serialization.py
@@ -0,0 +1,21 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from outlines.types.dsl import Regex
+from outlines.models.transformers import from_transformers
+import outlines
+
+
+model = from_transformers(
+    AutoModelForCausalLM.from_pretrained("erwanf/gpt2-mini"),
+    AutoTokenizer.from_pretrained("erwanf/gpt2-mini"),
+)
+
+generator = outlines.Generator(model, Regex(r"\d{3}"), backend="xgrammar")
+# picklable both before and after a generation
+generator.to_disk("logits_processor.pkl")
+result = generator("Generate a 3-digit number:")
+generator.to_disk("logits_processor.pkl")
+
+new_generator = outlines.Generator(model, processor="logits_processor.pkl")
+result = new_generator("Generate a 3-digit number:")
+print(result)
diff --git a/outlines/backends/llguidance.py b/outlines/backends/llguidance.py
index 5b6d168b1..371e23e22 100644
--- a/outlines/backends/llguidance.py
+++ b/outlines/backends/llguidance.py
@@ -176,6 +176,12 @@ def process_logits(
 
         return self._bias_logits(input_ids, logits)
 
+    def __getstate__(self):
+        """Create a picklable representation of the processor."""
+        raise NotImplementedError(
+            "Serializing the logits processor is not supported for LLGuidance"
+        )
+
 
 class LLGuidanceBackend(BaseBackend):
     """Backend for LLGuidance."""
diff --git a/outlines/backends/outlines_core.py b/outlines/backends/outlines_core.py
index 23457fd5a..fb4ef1570 100644
--- a/outlines/backends/outlines_core.py
+++ b/outlines/backends/outlines_core.py
@@ -35,12 +35,16 @@ def __init__(
         """
         self.index = index
         self.tensor_library_name = tensor_library_name
-        self.is_first_token = True
+        self.reset()
         super().__init__(tensor_library_name)
 
     def reset(self) -> None:
-        """Reset the logits processor."""
+        """Reset the logits processor to prepare for a new generation."""
         self.is_first_token = True
+        self._guides = None
+        self._bitmasks = None
+        self.bias_logits = None
+        self.allocate_token_bitmask = None
 
     def _setup(self, batch_size: int, vocab_size: int) -> None:
         """Set the guides, bitmasks and some functions used in the
@@ -193,6 +197,18 @@ def process_logits(
 
         return self.bias_logits(batch_size, logits)
 
+    def __getstate__(self):
+        """Create a picklable representation of the processor."""
+        self.reset()
+        state = self.__dict__.copy()
+        del state["tensor_adapter"]
+        return state
+
+    def __setstate__(self, state):
+        """Restore the processor from a pickled state."""
+        self.__dict__.update(state)
+        super().__init__(self.tensor_library_name)
+
 
 class OutlinesCoreBackend(BaseBackend):
     """Backend for Outlines Core."""
diff --git a/outlines/backends/xgrammar.py b/outlines/backends/xgrammar.py
index 4b6a59742..b4a02f203 100644
--- a/outlines/backends/xgrammar.py
+++ b/outlines/backends/xgrammar.py
@@ -23,20 +23,22 @@ def __init__(self, compiled_grammar: str, tensor_library_name: str,):
             The name of the tensor library used by the model
 
         """
-        import xgrammar as xgr
-
-        self.xgr = xgr
-        self.is_first_token = True
         self.compiled_grammar = compiled_grammar
         self.tensor_library_name = tensor_library_name
+        self.reset()
         super().__init__(tensor_library_name)
 
     def reset(self):
         """Ensure self._setup is called again for the next generation."""
         self.is_first_token = True
+        self._matchers = None
+        self._bitmask = None
+        self._bias_logits = None
 
     def _setup(self, batch_size: int, vocab_size: int) -> None:
         """Setup the logits processor for a new generation."""
+        import xgrammar as xgr
+
         if self.tensor_library_name == "torch":
             self._bias_logits = self._bias_logits_torch
         elif self.tensor_library_name == "mlx":  # pragma: no cover
@@ -47,15 +49,17 @@ def _setup(self, batch_size: int, vocab_size: int) -> None:
             )
 
         self._matchers = [
-            self.xgr.GrammarMatcher(self.compiled_grammar)
+            xgr.GrammarMatcher(self.compiled_grammar)
             for _ in range(batch_size)
         ]
-        self._bitmask = self.xgr.allocate_token_bitmask(batch_size, vocab_size)
+        self._bitmask = xgr.allocate_token_bitmask(batch_size, vocab_size)
 
     def _bias_logits_torch(
         self, input_ids: TensorType, logits: TensorType
     ) -> TensorType:
         """Bias the logits for Torch tensors."""
+        import xgrammar as xgr
+
         for i in range(self.tensor_adapter.shape(input_ids)[0]):
             if not self._matchers[i].is_terminated():
                 self._matchers[i].fill_next_token_bitmask(self._bitmask, i)
@@ -64,7 +68,7 @@ def _bias_logits_torch(
             self._bitmask, self.tensor_adapter.get_device(logits)
         )
 
-        self.xgr.apply_token_bitmask_inplace(logits, self._bitmask)
+        xgr.apply_token_bitmask_inplace(logits, self._bitmask)
 
         self._bitmask = self.tensor_adapter.to_device(
             self._bitmask, "cpu"
@@ -109,6 +113,27 @@ def process_logits(
 
         return self._bias_logits(input_ids, logits)
 
+    def __getstate__(self):
+        """Create a picklable representation of the processor."""
+        self.reset()
+        state = self.__dict__.copy()
+        del state["tensor_adapter"]
+        compiled_grammar = state.pop("compiled_grammar")
+        state["serialized_compiled_grammar"] = compiled_grammar.serialize_json()
+        state["serialized_tokenizer_info"] = compiled_grammar.tokenizer_info.serialize_json()
+        return state
+
+    def __setstate__(self, state):
+        """Restore the processor from a pickled state."""
+        import xgrammar as xgr
+
+        serialized_tokenizer_info = state.pop("serialized_tokenizer_info", None)
+        tokenizer_info = xgr.TokenizerInfo.deserialize_json(serialized_tokenizer_info)
+        serialized_compiled_grammar = state.pop("serialized_compiled_grammar")
+        compiled_grammar = xgr.CompiledGrammar.deserialize_json(serialized_compiled_grammar, tokenizer_info)
+        self.__dict__.update({**state, "compiled_grammar": compiled_grammar})
+        super().__init__(self.tensor_library_name)
+
 
 class XGrammarBackend(BaseBackend):
     """Backend for XGrammar."""
diff --git a/outlines/generator.py b/outlines/generator.py
index f2e669d8f..6f04a03a3 100644
--- a/outlines/generator.py
+++ b/outlines/generator.py
@@ -1,5 +1,6 @@
 """Encapsulate a model and an output type into a reusable object."""
 
+import pickle
 from typing import (
     Any,
     AsyncIterator,
@@ -276,6 +277,11 @@ def from_processor(
 
         return instance
 
+    def to_disk(self, path: str):
+        """Save the generator's logits processor to a local file."""
+        with open(path, "wb") as f:
+            pickle.dump(self.logits_processor, f)
+
     def __call__(self, prompt: Any, **inference_kwargs) -> Any:
         """Generate a response from the model.
 
@@ -348,13 +354,14 @@ def Generator(
     output_type: Optional[Any] = None,
     backend: Optional[str] = None,
     *,
-    processor: Optional[LogitsProcessorType] = None,
+    processor: Optional[LogitsProcessorType | str] = None,
 ) -> Union[SteerableGenerator, BlackBoxGenerator, AsyncBlackBoxGenerator]:
     """Create a generator for the given model and output parameters.
 
-    The 2 parameters output_type and processor are mutually exclusive. The
-    parameters processor is only supported for SteerableModel instances
-    (typically local models) and is intended to be only used by advanced users.
+    The parameters output_type and backend on one side, and processor on
+    the other, are mutually exclusive. The processor parameter is only
+    supported for SteerableModel instances (typically local models) and
+    is intended to be used only by advanced users.
 
     Parameters
     ----------
@@ -368,7 +375,8 @@ def Generator(
         used for steerable models if there is an output type and `processor`
         is not provided.
     processor
-        An instance of a logits processor.
+        An instance of a logits processor or the path to a file containing a
+        pickled logits processor (output of `SteerableGenerator.to_disk`).
 
     Returns
     -------
@@ -387,6 +395,9 @@ def Generator(
 
     if isinstance(model, SteerableModel):  # type: ignore
         if processor is not None:
+            if isinstance(processor, str):
+                with open(processor, "rb") as f:
+                    processor = pickle.load(f)
             return SteerableGenerator.from_processor(model, processor)  # type: ignore
         else:
             return SteerableGenerator(model, output_type, backend)  # type: ignore
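
For reference, a minimal sketch (not part of the patch) of what Generator.to_disk and the processor= argument do under the hood: because of the new __getstate__/__setstate__ hooks, the processor round-trips through plain pickle, so it can also be serialized in memory without touching the filesystem. The model setup assumes the same erwanf/gpt2-mini pair used in example_serialization.py above.

import pickle

from transformers import AutoModelForCausalLM, AutoTokenizer

import outlines
from outlines.models.transformers import from_transformers
from outlines.types.dsl import Regex

model = from_transformers(
    AutoModelForCausalLM.from_pretrained("erwanf/gpt2-mini"),
    AutoTokenizer.from_pretrained("erwanf/gpt2-mini"),
)
generator = outlines.Generator(model, Regex(r"\d{3}"), backend="xgrammar")

# __getstate__ resets the processor, drops the unpicklable tensor_adapter,
# and stores the compiled grammar and tokenizer info as JSON strings.
payload = pickle.dumps(generator.logits_processor)

# __setstate__ deserializes the grammar and re-runs the base-class __init__
# to rebuild the tensor adapter.
restored = pickle.loads(payload)

# Passing the restored instance is equivalent to passing the path of a file
# written by to_disk; only the processor is serialized, never the model.
new_generator = outlines.Generator(model, processor=restored)
print(new_generator("Generate a 3-digit number:"))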