 import torch
 from pydantic import BaseModel

-from outlines.fsm.guide import RegexGuide
+from outlines.fsm.guide import Guide, RegexGuide, StopAtEOSGuide
 from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.integrations.utils import adapt_tokenizer, convert_json_schema_to_str

 if TYPE_CHECKING:
     from vllm import LLM

+    from outlines.models.tokenizer import Tokenizer

-class RegexLogitsProcessor:
-    """Bias vLLM generation based on a regular expression.
+
+class FSMLogitsProcessor:
+    """Bias vLLM generation based on an FSM.

     Attributes
     ----------
     fsm
         The finite state machine which is used to bias the logits.
     """

-    def __init__(self, regex_string: str, llm: "LLM"):
+    def __init__(self, fsm: Guide):
         """Compile the FSM that drives the regex-structured generation.

         Parameters
         ----------
-        regex_string
-            A string that represents a regular expression.
-        llm
-            The vLLM model.
+        fsm
+            The finite state machine (a Guide instance) used to bias the logits.

-        Raises
-        ------
-        ValueError
-            If the provided LLM instance in `RegexLogitsProcessor` neither has a
-            `tokenizer` attribute or a `get_tokenizer` method.
         """
+        self.fsm = fsm
+        self.mask_cache: Dict[int, torch.Tensor] = {}
+        self._fsm_state: DefaultDict[int, int] = defaultdict(int)
+
+    @staticmethod
+    def get_llm_tokenizer(llm: "LLM") -> "Tokenizer":
+        """Return the tokenizer attached to the provided LLM."""
         if hasattr(llm, "get_tokenizer"):
             tokenizer = llm.get_tokenizer()
         elif hasattr(llm, "tokenizer"):
@@ -74,13 +76,10 @@ def __init__(self, regex_string: str, llm: "LLM"):
                 tokenizer = llm.tokenizer
         else:
             raise ValueError(
-                "The provided LLM instance in `RegexLogitsProcessor` neither has a "
+                "The provided LLM instance in `FSMLogitsProcessor` neither has a "
                 "`tokenizer` attribute or a `get_tokenizer` method."
             )
-        tokenizer = adapt_tokenizer(tokenizer=tokenizer)
-        self.mask_cache: Dict[int, torch.Tensor] = {}
-        self.fsm = RegexGuide(regex_string, tokenizer)
-        self._fsm_state: DefaultDict[int, int] = defaultdict(int)
+        return adapt_tokenizer(tokenizer=tokenizer)

     def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token.
@@ -125,6 +124,64 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
         return biased_scores


+class TextLogitsProcessor(FSMLogitsProcessor):
+    """Bias vLLM generation for free text (required because of prompt alignment).
+
+    Attributes
+    ----------
+    fsm
+        The finite state machine which is used to bias the logits.
+    """
+
+    def __init__(self, llm: "LLM"):
+        """Compile the FSM that drives free-text generation.
+
+        Parameters
+        ----------
+        llm
+            The vLLM model.
+
+        Raises
+        ------
+        ValueError
+            If the provided LLM instance in `TextLogitsProcessor` has neither a
+            `tokenizer` attribute nor a `get_tokenizer` method.
+        """
+        tokenizer = self.get_llm_tokenizer(llm)
+        fsm = StopAtEOSGuide(tokenizer)
+        super().__init__(fsm=fsm)
+
+
+class RegexLogitsProcessor(FSMLogitsProcessor):
+    """Bias vLLM generation based on a regular expression.
+
+    Attributes
+    ----------
+    fsm
+        The finite state machine which is used to bias the logits.
+    """
+
+    def __init__(self, regex_string: str, llm: "LLM"):
+        """Compile the FSM that drives the regex-structured generation.
+
+        Parameters
+        ----------
+        regex_string
+            A string that represents a regular expression.
+        llm
+            The vLLM model.
+
+        Raises
+        ------
+        ValueError
+            If the provided LLM instance in `RegexLogitsProcessor` has neither a
+            `tokenizer` attribute nor a `get_tokenizer` method.
+        """
+        tokenizer = self.get_llm_tokenizer(llm)
+        fsm = RegexGuide(regex_string, tokenizer)
+        super().__init__(fsm=fsm)
+
+
 class JSONLogitsProcessor(RegexLogitsProcessor):
     """Bias vLLM generation based on a JSON schema.

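For context, a minimal usage sketch of the refactored processors. This assumes the file above is `outlines/integrations/vllm.py` and that the processor is passed through vLLM's `SamplingParams(logits_processors=...)` hook; the model name and prompt are placeholders, not part of this diff:

    from vllm import LLM, SamplingParams

    from outlines.integrations.vllm import RegexLogitsProcessor  # assumed module path

    # Load any vLLM-supported model (placeholder name).
    llm = LLM("mistralai/Mistral-7B-Instruct-v0.2")

    # Constrain generation to a three-digit number via the regex-driven FSM.
    logits_processor = RegexLogitsProcessor(r"[0-9]{3}", llm=llm)

    # vLLM calls the processor on (input_ids, scores) at every decoding step.
    params = SamplingParams(max_tokens=10, logits_processors=[logits_processor])
    outputs = llm.generate("What is 2 + 2? Answer with three digits: ", params)
    print(outputs[0].outputs[0].text)

Because `mask_cache` and `_fsm_state` now live in the shared `FSMLogitsProcessor.__init__`, every subclass (text, regex, JSON) reuses the same per-sequence state tracking and mask caching, and only supplies a different `Guide`.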