11import sys
22import time
3- from typing import Any , Dict , List , Tuple
3+ from typing import Any , Dict , List , Optional , Tuple , Union
44
55import pytest
66from transformers import AutoTokenizer
77
88import xgrammar as xgr
9+ from xgrammar .structural_tag import StructuralTag
910from xgrammar .testing import _is_grammar_accept_string
1011
11- PROFILER_ON = True
12- tokenizer_id = "meta-llama/Llama-3.1-8B-Instruct"
13-
1412
1513class Profiler :
1614 def __init__ (self , tokenizer_id : str ):
@@ -22,8 +20,13 @@ def __init__(self, tokenizer_id: str):
2220 self .tokenizer_info , max_threads = 16 , cache_enabled = False
2321 )
2422
25- def profile_stag (self , structural_tag_format : Dict [str , Any ], instance : str ):
26- structural_tag = {"type" : "structural_tag" , "format" : structural_tag_format }
23+ def profile_stag (
24+ self , structural_tag_format : Union [Dict [str , Any ], StructuralTag ], instance : str
25+ ):
26+ if isinstance (structural_tag_format , StructuralTag ):
27+ structural_tag = structural_tag_format
28+ else :
29+ structural_tag = {"type" : "structural_tag" , "format" : structural_tag_format }
2730 time_begin = time .monotonic_ns ()
2831 compiled_grammar = self .compiler .compile_structural_tag (structural_tag )
2932 time_end = time .monotonic_ns ()
@@ -45,8 +48,23 @@ def profile_stag(self, structural_tag_format: Dict[str, Any], instance: str):
4548 print (f"Time to generate mask: { duration / 1000 } us, Character: '{ char } '" )
4649
4750
48- if PROFILER_ON :
49- profiler = Profiler (tokenizer_id )
51+ profiler : Optional [Profiler ] = None
52+ PROFILER_ON = True
53+ tokenizer_id = "meta-llama/Llama-3.1-8B-Instruct"
54+
55+
@pytest.fixture(autouse=True)
def disable_profiler(request):
    """Switch profiling off, or make the shared profiler available, before a test.

    Reads pytest's mark expression (the ``markexpr`` option). When it contains
    ``"not hf_token_required"`` the module-level ``PROFILER_ON`` flag is set to
    ``False``. Otherwise the module-level ``profiler`` is built from
    ``tokenizer_id`` the first time it is needed and reused afterwards.

    Args:
        request: pytest's built-in ``request`` fixture, used to reach the
            run configuration.
    """
    global PROFILER_ON
    global profiler
    # Same option read two ways: the attribute on the parsed options first,
    # then ``getoption`` with a default in case that attribute is empty/absent.
    markexpr = getattr(request.config.option, "markexpr", "") or request.config.getoption(
        "markexpr", ""
    )
    if "not hf_token_required" in (markexpr or ""):
        PROFILER_ON = False
    elif profiler is None:
        # The fixture is autouse with the default (per-test) scope, so it runs
        # before every test; construct the Profiler once instead of each time.
        profiler = Profiler(tokenizer_id)
5068
5169
5270def check_stag_with_grammar (structural_tag_format : Dict [str , Any ], expected_grammar_ebnf : str ):
@@ -56,13 +74,16 @@ def check_stag_with_grammar(structural_tag_format: Dict[str, Any], expected_gram
5674
5775
5876def check_stag_with_instance (
59- structural_tag_format : Dict [str , Any ],
77+ structural_tag_format : Union [ Dict [str , Any ], StructuralTag ],
6078 instance : str ,
6179 is_accepted : bool = True ,
6280 debug_print : bool = False ,
6381):
64- structural_tag = {"type" : "structural_tag" , "format" : structural_tag_format }
65- stag_grammar = xgr .Grammar .from_structural_tag (structural_tag )
82+ if isinstance (structural_tag_format , StructuralTag ):
83+ stag_grammar = xgr .Grammar .from_structural_tag (structural_tag_format )
84+ else :
85+ structural_tag = {"type" : "structural_tag" , "format" : structural_tag_format }
86+ stag_grammar = xgr .Grammar .from_structural_tag (structural_tag )
6687 accepted = _is_grammar_accept_string (stag_grammar , instance , debug_print = debug_print )
6788 assert accepted == is_accepted
6889 if PROFILER_ON :
@@ -1955,8 +1976,7 @@ def test_from_structural_tag_with_structural_tag_instance(
19551976 stag_format : xgr .structural_tag .Format , instance : str , is_accepted : bool
19561977):
19571978 stag = xgr .StructuralTag (format = stag_format )
1958- grammar = xgr .Grammar .from_structural_tag (stag )
1959- assert _is_grammar_accept_string (grammar , instance ) == is_accepted
1979+ check_stag_with_instance (stag , instance , is_accepted )
19601980
19611981
19621982if __name__ == "__main__" :
0 commit comments