44from ._remote import RemoteEngine
55
66class MockEngine (Engine ):
7- def __init__ (self , tokenizer , byte_patterns , compute_log_probs ):
7+ def __init__ (self , tokenizer , byte_patterns , compute_log_probs , force ):
88 super ().__init__ (tokenizer , compute_log_probs = compute_log_probs )
99
1010 self ._valid_mask = np .zeros (len (tokenizer .tokens ))
@@ -14,6 +14,7 @@ def __init__(self, tokenizer, byte_patterns, compute_log_probs):
1414 self ._valid_mask [i ] = 1.0
1515 except :
1616 pass
17+ self .force = force
1718
1819 # allow a single byte pattern to be passed
1920 if isinstance (byte_patterns , (bytes , str )):
@@ -23,8 +24,10 @@ def __init__(self, tokenizer, byte_patterns, compute_log_probs):
2324 for i ,pattern in enumerate (byte_patterns ):
2425 if isinstance (pattern , str ):
2526 byte_patterns [i ] = pattern .encode ("utf8" )
26-
27+
2728 self .byte_patterns = byte_patterns
29+
30+ # seed the random number generator
2831 self ._rand_generator = np .random .default_rng (seed = 42 )
2932
3033 def get_logits (self , token_ids , forced_bytes , current_temp ):
@@ -34,8 +37,13 @@ def get_logits(self, token_ids, forced_bytes, current_temp):
3437 # build the byte strings
3538 byte_string = b"" .join (self .tokenizer .tokens [i ] for i in token_ids )
3639
37- # we randomly generate valid unicode bytes
38- logits = self ._rand_generator .standard_normal (len (self .tokenizer .tokens )) * self ._valid_mask
40+ # if we are forcing the byte patterns then don't allow other tokens
41+ if self .force :
42+ logits = np .ones (len (self .tokenizer .tokens )) * - np .inf
43+
44+ # otherwise we randomly generate valid unicode bytes
45+ else :
46+ logits = self ._rand_generator .standard_normal (len (self .tokenizer .tokens )) * self ._valid_mask
3947
4048 # if we have a pattern that matches then force the next token
4149 bias = 100.0
@@ -55,7 +63,7 @@ def _get_next_tokens(self, byte_string):
5563 yield i
5664
5765class Mock (Model ):
58- def __init__ (self , byte_patterns = [], echo = True , compute_log_probs = False , ** kwargs ):
66+ def __init__ (self , byte_patterns = [], echo = True , compute_log_probs = False , force = False , ** kwargs ):
5967 '''Build a new Mock model object that represents a model in a given state.'''
6068
6169 if isinstance (byte_patterns , str ) and byte_patterns .startswith ("http" ):
@@ -67,7 +75,7 @@ def __init__(self, byte_patterns=[], echo=True, compute_log_probs=False, **kwarg
6775 0 ,
6876 0
6977 )
70- engine = MockEngine (tokenizer , byte_patterns , compute_log_probs )
78+ engine = MockEngine (tokenizer , byte_patterns , compute_log_probs , force )
7179
7280
7381 super ().__init__ (
0 commit comments