 
 import xgrammar as xgr
 from xgrammar.testing import (
-    _bitmask_to_bool_mask,
-    _bool_mask_to_bitmask,
     _get_masked_tokens_from_bitmask,
     _is_single_token_bitmask,
+    bitmask_to_bool_mask,
+    bool_mask_to_bitmask,
 )
 
 _is_cuda_available = torch.cuda.is_available()
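
The renamed pair is the only surface this commit touches: the two conversion helpers are promoted from private (`_bool_mask_to_bitmask`, `_bitmask_to_bool_mask`) to public names. For readers outside the codebase, here is a minimal sketch of what the pair plausibly computes, assuming the usual layout where bit i of int32 word j covers token `j * 32 + i` and a set bit means the token is allowed. The `_sketch` names are hypothetical; the real implementations live in `xgrammar.testing`.

```python
import torch

def bool_mask_to_bitmask_sketch(bool_mask: torch.Tensor) -> torch.Tensor:
    """Pack a (batch, vocab) bool mask into (batch, ceil(vocab/32)) int32 words."""
    batch, vocab = bool_mask.shape
    n_words = (vocab + 31) // 32
    padded = torch.zeros(batch, n_words * 32, dtype=torch.int64)
    padded[:, :vocab] = bool_mask.to(torch.int64)
    weights = 2 ** torch.arange(32, dtype=torch.int64)  # bit i -> 2**i
    packed = (padded.view(batch, n_words, 32) * weights).sum(dim=-1)
    # Two's-complement wrap on the cast keeps the 32-bit pattern intact.
    return packed.to(torch.int32)

def bitmask_to_bool_mask_sketch(bitmask: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Unpack int32 words back into a (batch, vocab_size) bool mask."""
    batch, n_words = bitmask.shape
    shifts = torch.arange(32, dtype=torch.int32)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1  # bit i of word j -> token j*32+i
    return bits.view(batch, n_words * 32)[:, :vocab_size].to(torch.bool)
```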
@@ -38,7 +38,7 @@ def test_allocate_reset_token_bitmask():
 @pytest.mark.parametrize("index", (0, 1))
 def test_get_masked_tokens_from_bitmask(token_mask_size: int, index: int):
     bool_mask = torch.randint(0, 2, (2, token_mask_size), dtype=torch.bool)
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     expected = torch.where(~bool_mask[index])[0].tolist()
     assert _get_masked_tokens_from_bitmask(bitmask, token_mask_size, index) == expected
 
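
The `expected` expression above already documents the contract: the "masked" tokens for a row are exactly the positions that are `False` in the bool mask. A standalone illustration:

```python
import torch

bool_mask = torch.tensor([[True, False, True, False]])
masked = torch.where(~bool_mask[0])[0].tolist()  # positions the mask rejects
assert masked == [1, 3]
```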
@@ -50,13 +50,13 @@ def test_is_single_token_bitmask():
     token_id = 100
 
     bool_mask = torch.zeros(batch, vocab_size, dtype=torch.bool)
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (False, -1)
     bool_mask[batch_index, token_id] = True
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (True, token_id)
     bool_mask[batch_index, token_id + 1] = True
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (False, -1)
 
 
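
The three asserts pin down the single-token contract: `(False, -1)` when no token is allowed, `(True, token_id)` when exactly one is, and `(False, -1)` again once a second token is allowed. A hedged reference check on the unpacked bool mask (the real `_is_single_token_bitmask` operates on the packed form):

```python
import torch

def is_single_token(bool_mask_row: torch.Tensor) -> tuple[bool, int]:
    # Exactly one True in the row <=> the bitmask admits a single token.
    allowed = torch.nonzero(bool_mask_row).flatten()
    if allowed.numel() == 1:
        return True, int(allowed.item())
    return False, -1
```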
@@ -229,7 +229,7 @@ def test_apply_token_bitmask_inplace_kernel_large(
     bool_mask.scatter_(1, masked_positions, False)
     assert (bool_mask.sum(dim=-1) + masked_cnt == vocab_size).all().item()
 
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
 
     batch_indices = torch.arange(0, batch_size, stride, dtype=torch.int32)
 
@@ -238,7 +238,7 @@ def test_apply_token_bitmask_inplace_kernel_large(
         logits_expected[batch_indices], ~bool_mask[batch_indices], float("-inf")
     )
 
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     if impl in ["cuda", "triton", "torch_compile"]:
         logits_gpu = logits.to("cuda")
         bitmask_gpu = bitmask.to("cuda")
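
The GPU branch above is the pattern downstream callers need to mirror: the bitmask is packed on CPU and must be moved to the logits' device before it is applied. A hedged sketch, assuming the public `xgr.apply_token_bitmask_inplace` entry point:

```python
import torch
import xgrammar as xgr
from xgrammar.testing import bool_mask_to_bitmask

if torch.cuda.is_available():
    bool_mask = torch.randint(0, 2, (1, 128), dtype=torch.bool)
    logits_gpu = torch.randn(1, 128).to("cuda")
    bitmask_gpu = bool_mask_to_bitmask(bool_mask).to("cuda")  # pack on CPU, then move
    xgr.apply_token_bitmask_inplace(logits_gpu, bitmask_gpu)
```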
@@ -340,7 +340,7 @@ def test_apply_token_bitmask_inplace_indices(
 
     logits = torch.ones(logits_batch_size, vocab_size, dtype=torch.float32)
     bool_mask = torch.zeros(bitmask_batch_size, vocab_size, dtype=torch.bool)
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
 
     logits_expected = logits.clone()
     logits_expected[indices] = torch.masked_fill(
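
This test exercises the case where the logits batch and the bitmask batch differ in size and `indices` selects which logits rows get masked; unlisted rows must stay untouched. A hedged sketch of that behavior (the `indices=` keyword is assumed from the test's name and usage, not spelled out in this diff):

```python
import torch
import xgrammar as xgr
from xgrammar.testing import bool_mask_to_bitmask

logits = torch.ones(4, 64)
bool_mask = torch.zeros(4, 64, dtype=torch.bool)  # an all-masking bitmask
xgr.apply_token_bitmask_inplace(logits, bool_mask_to_bitmask(bool_mask), indices=[0, 2])
assert torch.isinf(logits[0]).all()  # row 0 was listed, masked to -inf
assert (logits[1] == 1).all()        # row 1 was not listed, left intact
```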
@@ -364,10 +364,10 @@ def test_bitmask_to_boolmask():
     expected = torch.tensor(
         [[False] * 16, [True] * 16, [True] * 16, [False] * 16], dtype=torch.bool
     ).reshape(1, -1)
-    bool_mask = _bitmask_to_bool_mask(bitmask)
+    bool_mask = bitmask_to_bool_mask(bitmask)
     assert torch.equal(bool_mask, expected)
 
-    bool_mask_50 = _bitmask_to_bool_mask(bitmask, vocab_size=50)
+    bool_mask_50 = bitmask_to_bool_mask(bitmask, vocab_size=50)
     expected_50 = expected[:, :50]
     assert torch.equal(bool_mask_50, expected_50)
 
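
The `vocab_size=50` case covers truncation: the packed form always spans a multiple of 32 bit slots, so for vocab sizes that are not multiples of 32 the trailing pad bits must be sliced off when unpacking. For example (the word count of 2 assumes the 32-tokens-per-int32 layout):

```python
import torch
from xgrammar.testing import bitmask_to_bool_mask, bool_mask_to_bitmask

mask = torch.ones(1, 50, dtype=torch.bool)  # 50-token vocab
packed = bool_mask_to_bitmask(mask)         # 2 int32 words = 64 bit slots
assert bitmask_to_bool_mask(packed, vocab_size=50).shape == (1, 50)
```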
@@ -384,8 +384,8 @@ def test_bitmask_to_boolmask():
 @pytest.mark.parametrize("batch_size, vocab_size", batch_size__vocab_size)
 def test_bool_mask_bitmask_roundtrip(batch_size: int, vocab_size: int):
     bool_mask = torch.randint(0, 2, (batch_size, vocab_size), dtype=torch.bool)
-    bitmask = _bool_mask_to_bitmask(bool_mask)
-    bool_mask_converted = _bitmask_to_bool_mask(bitmask, vocab_size=vocab_size)
+    bitmask = bool_mask_to_bitmask(bool_mask)
+    bool_mask_converted = bitmask_to_bool_mask(bitmask, vocab_size=vocab_size)
     assert torch.equal(bool_mask, bool_mask_converted)
 
 
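
Taken together with the roundtrip test, the renamed helpers give a concise end-to-end smoke test of the masking step in constrained decoding. A hedged sketch, again assuming the public `xgr.apply_token_bitmask_inplace` API:

```python
import torch
import xgrammar as xgr
from xgrammar.testing import bool_mask_to_bitmask

vocab_size = 128
logits = torch.zeros(1, vocab_size)
bool_mask = torch.zeros(1, vocab_size, dtype=torch.bool)
bool_mask[0, 7] = True  # allow exactly one token
xgr.apply_token_bitmask_inplace(logits, bool_mask_to_bitmask(bool_mask))
assert logits[0, 7] == 0                 # allowed token untouched
assert torch.isinf(logits[0, :7]).all()  # everything else pushed to -inf
```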