-import pytest
+import re
 
 import llama_cpp
-import llguidance.hf
-import numpy as np
-import torch
+import llguidance
+import pytest
 import transformers
 from llguidance import LLTokenizer
 
     LLGuidanceBackend,
     LLGuidanceLogitsProcessor
 )
+from tests.backends.test_backends_utils import simulate_model_calling_processor
 
 try:
-    import mlx.core as mx
     import mlx_lm
     HAS_MLX = True
 except ImportError:
@@ -40,20 +39,6 @@ def model_mlxlm():
         *mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
     )
 
-@pytest.fixture
-def llg_tokenizer():
-    return llguidance.hf.from_tokenizer(
-        transformers.AutoTokenizer.from_pretrained("erwanf/gpt2-mini"),
-    )
-
-@pytest.fixture
-def llg_grammar_spec():
-    return (
-        '{"grammars": [{ "json_schema": {"type": "object", "properties":'
-        + ' {"name": {"type": "string"}, "age": {"type": "integer"}}, "requ'
-        + 'ired": ["name", "age"], "additionalProperties": false} }] }'
-    )
-
 @pytest.fixture
 def json_schema():
     return (
@@ -97,42 +82,61 @@ def cfg_ebnf():
9782"""
9883
9984
100- def test_llguidance_processor_torch (llg_grammar_spec , llg_tokenizer ):
101- processor = LLGuidanceLogitsProcessor (llg_grammar_spec , llg_tokenizer , "torch" )
102- logits = torch .randn (2 , llg_tokenizer .vocab_size )
103- input_ids = torch .randint (0 , llg_tokenizer .vocab_size , (2 , 10 ))
104- output = processor (input_ids , logits )
105- assert output .shape == (2 , llg_tokenizer .vocab_size )
106- processor (input_ids , logits )
107-
85+ def test_llguidance_processor_torch (regex ):
86+ model = model_transformers ()
87+ tokenizer = model .tokenizer
88+ hf_tokenizer = model .hf_tokenizer
89+ llg_tokenizer = LLGuidanceBackend (model ).llg_tokenizer
90+ grammar_spec = llguidance .grammar_from ("regex" , regex )
91+ processor = LLGuidanceLogitsProcessor (grammar_spec , llg_tokenizer , "torch" )
92+ for _ in range (2 ):
93+ input_ids = simulate_model_calling_processor (
94+ processor ,
95+ "torch" ,
96+ len (tokenizer .get_vocab ()),
97+ tokenizer .eos_token_id ,
98+ 2
99+ )
100+ assert re .match (regex , hf_tokenizer .decode (input_ids [0 ]))
101+ assert re .match (regex , hf_tokenizer .decode (input_ids [1 ]))
102+
103+
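The three rewritten processor tests follow one pattern: derive a grammar spec from the regex fixture, wrap it in an LLGuidanceLogitsProcessor for the backend's tensor library, run two full simulated generations (which also verifies that the processor resets its parser state between calls), and assert that each decoded sequence matches the regex. The body of simulate_model_calling_processor is not part of this diff; what follows is a minimal hypothetical sketch of what it presumably does, written numpy-only for brevity (the real helper must build tensors for whichever library name it is given), assuming the processor sets every grammar-forbidden logit to -inf.

import numpy as np

def simulate_model_calling_processor(
    processor, tensor_library_name, vocab_size, eos_token_id, batch_size
):
    # Hypothetical sketch, not the actual helper from test_backends_utils.
    # Greedy decoding driven by random logits: because the processor masks
    # every token the grammar forbids, argmax can only ever pick a
    # grammar-legal token, until EOS becomes the sole legal choice.
    rng = np.random.default_rng(0)
    input_ids = np.empty((batch_size, 0), dtype=np.int64)
    while True:
        logits = rng.standard_normal((batch_size, vocab_size)).astype(np.float32)
        masked = processor(input_ids, logits)
        next_tokens = np.argmax(masked, axis=-1, keepdims=True)
        input_ids = np.concatenate([input_ids, next_tokens], axis=1)
        # Simplification: assumes all rows reach EOS on the same step.
        if np.all(next_tokens[:, 0] == eos_token_id):
            return input_ids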
+def test_llguidance_processor_numpy(regex):
+    model = model_llamacpp()
+    tokenizer = model.tokenizer
+    llg_tokenizer = LLGuidanceBackend(model).llg_tokenizer
+    grammar_spec = llguidance.grammar_from("regex", regex)
+    processor = LLGuidanceLogitsProcessor(grammar_spec, llg_tokenizer, "numpy")
+    for _ in range(2):
+        input_ids = simulate_model_calling_processor(
+            processor,
+            "numpy",
+            len(tokenizer.vocabulary),
+            tokenizer.eos_token_id,
+            2
+        )
+        assert re.match(regex, tokenizer.decode(input_ids[0])[0])
+        assert re.match(regex, tokenizer.decode(input_ids[1])[0])
 
-def test_llguidance_processor_numpy(llg_grammar_spec, llg_tokenizer):
-    processor = LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "numpy")
-    logits = np.random.randn(2, llg_tokenizer.vocab_size)
-    input_ids = np.random.randint(0, llg_tokenizer.vocab_size, (2, 10))
-    output = processor(input_ids, logits)
-    assert output.shape == (2, llg_tokenizer.vocab_size)
-    processor(input_ids, logits)
 
 
 @pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
-def test_llguidance_processor_mlx(llg_grammar_spec, llg_tokenizer):
-    processor = LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "mlx")
-    logits = mx.random.normal((2, llg_tokenizer.vocab_size))
-    input_ids = mx.random.randint(0, llg_tokenizer.vocab_size, (2, 10))
-    output = processor(input_ids, logits)
-    assert output.shape == (2, llg_tokenizer.vocab_size)
-    processor(input_ids, logits)
-
-
-def test_llguidance_processor_tensorflow(llg_grammar_spec, llg_tokenizer):
-    with pytest.raises(TypeError):
-        LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "tensorflow")
-
-
-def test_llguidance_processor_jax(llg_grammar_spec, llg_tokenizer):
-    with pytest.raises(TypeError):
-        LLGuidanceLogitsProcessor(llg_grammar_spec, llg_tokenizer, "jax")
+def test_llguidance_processor_mlx(regex):
+    model = model_mlxlm()
+    tokenizer = model.mlx_tokenizer
+    llg_tokenizer = LLGuidanceBackend(model).llg_tokenizer
+    grammar_spec = llguidance.grammar_from("regex", regex)
+    processor = LLGuidanceLogitsProcessor(grammar_spec, llg_tokenizer, "mlx")
+    for _ in range(2):
+        input_ids = simulate_model_calling_processor(
+            processor,
+            "mlx",
+            len(tokenizer.vocabulary),
+            tokenizer.eos_token_id,
+            2
+        )
+        assert re.match(regex, tokenizer.decode(input_ids[0]))
+        assert re.match(regex, tokenizer.decode(input_ids[1]))
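Note the decode differences across the three paths: the transformers test decodes with hf_tokenizer.decode(input_ids[0]) and gets a string back, while the llama.cpp tokenizer's decode apparently returns a list, hence the extra [0] in the numpy assertions. All three tests also replace the deleted hand-written '{"grammars": [...]}' fixture with llguidance.grammar_from("regex", regex), which serializes a single-constraint grammar spec for the processor. Shown here only as usage (the returned format is an llguidance internal):

spec = llguidance.grammar_from("regex", r"[a-z]{3}")
processor = LLGuidanceLogitsProcessor(spec, llg_tokenizer, "numpy")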
 
 
 models = [
@@ -155,7 +159,6 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_
     generator = outlines.Generator(model, backend="llguidance", processor=processor)
     response = generator("Hello, how are you?")
     assert response[0] == "{"
-    assert "name" in response
 
     # regex
     processor = backend.get_regex_logits_processor(regex)
@@ -184,3 +187,16 @@ def test_llguidance_backend(model, tensor_library_name, json_schema, regex, cfg_
     generator = outlines.Generator(model, backend="llguidance", processor=processor)
     response = generator("Hello, how are you?")
     assert response == "yes" or response == "no"
+
+    # batch + multiple generations
+    processor = backend.get_json_schema_logits_processor(json_schema)
+    generator = outlines.Generator(model, backend="llguidance", processor=processor)
+    for _ in range(2):
+        if tensor_library_name == "torch":
+            response = generator.batch(["Create a character", "Hello, how are you?"], max_new_tokens=200)
+            assert len(response) == 2
+            for r in response:
+                assert r[0] == "{"
+        else:
+            response = generator("Create a character", max_tokens=20)
+            assert response[0] == "{"
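The new block runs the generator twice to confirm the processor survives reuse across generations. Only the torch-backed models exercise generator.batch with two prompts; the llama.cpp and mlx models fall back to a single prompt per call, which is also why the keyword argument differs (max_new_tokens for the transformers batch API, max_tokens elsewhere).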