1+ # flake8: noqa
2+ # mypy: ignore-errors
3+ import os
4+ import random
5+ import time
6+
7+ import psutil
8+ from outlines_core import Guide , Index , Vocabulary , create_mask , mask_to_list
9+ from outlines_core .json_schema import build_regex_from_schema
10+
# Export RUST_LOG=debug into this process's environment at import time.
# NOTE(review): presumably meant to enable debug logging in the Rust side of
# outlines_core. It is set *after* outlines_core is imported, so confirm the
# logger reads it late enough to take effect, and that debug logging does not
# distort the timings reported by this benchmark.
os.environ["RUST_LOG"] = "debug"
12+
13+
# Regular expressions exercised by the benchmark. Each entry pairs a short
# display name ("name") with the pattern ("regex") handed to the index.
# All patterns are raw Python strings, so every escape uses a single backslash.
regexes = [
    {
        "name": "email",
        "regex": r"(?:[a-z0-9!#$%&'*+/=?^_`{|}~-]{1,63}(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]{1,63}){0,10})@(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.){1,3}[a-z0-9](?:[a-z0-9-]{0,30}[a-z0-9])?",
    },
    {"name": "simple_phone", "regex": r"\+?[1-9][0-9]{7,14}"},
    {
        "name": "complex_phone",
        "regex": r"\+?\d{1,4}?[-.\s]?\(?\d{1,3}?\)?[-.\s]?\d{1,4}[-.\s]?\d{1,4}[-.\s]?\d{1,9}",
    },
    {"name": "permissive_any", "regex": r".{255}$"},
    {"name": "permissive_words", "regex": r"[a-zA-Z]{100}"},
    {
        "name": "https",
        # Doubled backslashes belong to the JSON-escaped form of this pattern;
        # inside a raw string they would require literal backslashes in the
        # input and turn `\d` into the two characters `\` and `d`.
        "regex": r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
    },
]
# JSON Schema documents exercised by the benchmark. Each entry pairs a display
# name ("name") with the schema source text. Note that the schema text is
# stored under the "regex" key even though it is not a regex itself:
# test_benchmark_v2index reads s["regex"] and converts it with
# build_regex_from_schema before benchmarking.
schemas = [
    # Flat object with a required string "name" and a required integer "age".
    {
        "name": "schema_simple",
        "regex": r'{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, "required": ["name", "age"]}',
    },
    # Same shape plus a required string constrained by a phone-number pattern.
    # The pattern is JSON-escaped (doubled backslashes) inside a raw string.
    {
        "name": "schema_simple_phone",
        "regex": r'{"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "integer"}, "complexe_phone": {"type": "string", "pattern": "\\+?\\d{1,4}?[-. ]?\\(\\d{1,3}\\)?[-. ]?\\d{1,4}[-. ]?\\d{1,4}[-. ]?\\d{1,9}"}}, "required": ["name", "age", "complexe_phone"]}',
    },
    # Nested objects and arrays that reuse a shared "artist" definition via $ref.
    {
        "name": "schema_complexe",
        "regex": """{
"$schema": "http://json-schema.org/draft-04/schema#",
"title": "Schema for a recording",
"type": "object",
"definitions": {
"artist": {
"type": "object",
"properties": {
"id": {"type": "number"},
"name": {"type": "string"},
"functions": {
"type": "array",
"items": {"type": "string"}
}
},
"required": ["id", "name", "functions"]
}
},
"properties": {
"id": {"type": "number"},
"work": {
"type": "object",
"properties": {
"id": {"type": "number"},
"name": {"type": "string"},
"composer": {"$ref": "#/definitions/artist"}
}
},
"recording_artists": {
"type": "array",
"items": {"$ref": "#/definitions/artist"}
}
},
"required": ["id", "work", "recording_artists"]
}"""
    },
    # Curriculum vitae: a "date"-formatted string, pattern-constrained email,
    # phone and website strings, and an array of $ref'd experience entries.
    # "website" is declared but is not listed in "required".
    {
        "name": "schema_curriculum",
        "regex": r'''{
"$schema": "http://json-schema.org/draft-04/schema#",
"title": "Schema for a Curriculum Vitae",
"type": "object",
"definitions": {
"experienceEntry": {
"type": "object",
"properties": {
"date": {
"type": "string",
"format": "date"
},
"position": {
"type": "string"
}
},
"required": ["date", "position"]
}
},
"properties": {
"name": {
"type": "string"
},
"surname": {
"type": "string"
},
"email": {
"type": "string",
"pattern": "[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?"
},
"phone": {
"type": "string",
"pattern": "\\+?\\d{1,4}?[-. ]?\\(\\d{1,3}\\)?[-. ]?\\d{1,4}[-. ]?\\d{1,4}[-. ]?\\d{1,9}"
},
"website": {
"type": "string",
"pattern": "(https?:\\/\\/)?([\\da-z\\.-]+)\\.([a-z\\.]{2,6})([\\/\\w \\.-]*)*\\/?"
},
"resume": {
"type": "array",
"items": {
"$ref": "#/definitions/experienceEntry"
}
}
},
"required": ["name", "surname", "email", "phone", "resume"]
}'''
    }
]
126+
127+
class V2IndexBenchmark:
    """Time a ``Guide`` built on an ``Index`` while it walks a single regex.

    Call :meth:`setup` once to build the vocabulary, index, guide and mask,
    then :meth:`run_benchmark` to drive the guide with randomly chosen tokens
    and print the accumulated time spent inside the guide calls.
    """

    def setup(self, regex, model_name="unsloth/Llama-3.1-8B-Instruct"):
        """Create the objects the benchmark operates on.

        Args:
            regex: Pattern handed to ``Index`` together with the vocabulary.
            model_name: Identifier passed to ``Vocabulary.from_pretrained``.

        Raises:
            AssertionError: If the freshly built guide already reports that it
                is finished.
        """
        self.vocab = Vocabulary.from_pretrained(model_name)
        self.v2_index = Index(regex, self.vocab)
        self.v2_guide = Guide(self.v2_index)

        # NOTE(review): the extra slot beyond len(vocab) is presumably for a
        # special/EOS token id -- confirm against create_mask's contract.
        self.mask = create_mask(len(self.vocab) + 1)

        # Handle on the current process; not read anywhere in this class.
        self.process = psutil.Process()

        assert (
            not self.v2_guide.is_finished()
        ), f"Compressed Guide should not be finished for {regex} "

    def _pick_random_token(self):
        """Return a random element of ``mask_to_list(self.mask)``.

        Raises:
            ValueError: If the list is empty, i.e. there is no token to
                continue with although the guide is not finished.
        """
        allowed = mask_to_list(self.mask)
        if len(allowed) == 0:
            raise ValueError(
                "Guide is not finished but the mask allows no token; "
                "cannot pick a next token"
            )
        return random.choice(allowed)

    def run_benchmark(self, max_iterations=2000):
        """Drive the guide with random tokens and print timing statistics.

        The first step times ``get_tokens``; every following step times
        ``advance`` with the previously chosen token. Only these guide calls
        are timed; converting the mask to a list and choosing the next token
        are excluded. The walk stops when the guide reports it is finished or
        once the iteration count exceeds ``max_iterations``.

        Args:
            max_iterations: Cap on the iteration count; the walk is abandoned
                after an ``advance`` once the count is greater than this.

        Raises:
            ValueError: If the guide is not finished but no token is available
                to continue with.
        """
        guide = self.v2_guide
        iterations = 0
        total_seconds = 0.0

        self.current_token_id = -1

        if not guide.is_finished():
            iterations += 1

            started = time.perf_counter()
            guide.get_tokens(self.mask)
            total_seconds += time.perf_counter() - started

            self.current_token_id = self._pick_random_token()

        while not guide.is_finished():
            iterations += 1

            started = time.perf_counter()
            guide.advance(self.current_token_id, self.mask)
            total_seconds += time.perf_counter() - started

            if guide.is_finished() or iterations > max_iterations:
                break
            self.current_token_id = self._pick_random_token()

        total_us = total_seconds * 1e6
        # A guide that was already finished performs no iteration at all;
        # report 0 instead of dividing by zero.
        per_iteration_us = total_us / iterations if iterations else 0.0

        print(f" Total iterations (Number of tokens): {iterations} ")
        print(
            f" Guide with Compressed Index: {total_us:.2f} µs ({per_iteration_us:.2f} µs per iteration)"
        )
192+
193+
194+
def test_benchmark_v2index():
    """Benchmark every entry of ``regexes``, then every entry of ``schemas``.

    Regex entries are benchmarked as-is. Schema entries are first converted
    to a regex with ``build_regex_from_schema``. Each case gets its own
    ``V2IndexBenchmark`` instance and a header line printed before it runs.
    """

    def _bench(header, pattern):
        # Announce the case, then set up and run a fresh benchmark for it.
        print(header)
        runner = V2IndexBenchmark()
        runner.setup(pattern)
        runner.run_benchmark()

    for entry in regexes:
        _bench(f"> Regex : '{entry['name']} '", entry["regex"])

    for entry in schemas:
        label = entry["name"]
        # The schema text lives under the "regex" key; it is converted before
        # the header is printed, so a conversion failure prints no header.
        pattern = build_regex_from_schema(entry["regex"], None)
        _bench(f"> Schema : '{label} '", pattern)
213+
214+
# Script entry point: lets the benchmark be run directly with the Python
# interpreter instead of through a test runner.
if __name__ == "__main__":
    print("Running main...")
    test_benchmark_v2index()
0 commit comments