
Commit f43b36b

[Feature] Add a function to convert bitmasks to bool masks for development usage. (#426)
This PR provides a function to convert a bitmask to a boolean mask. The API is:

```python
def _bitmask_to_bool_mask(bit_mask: torch.Tensor, vocab_size: Optional[int] = None) -> torch.Tensor: ...
```

---------

Signed-off-by: Yuchuan <[email protected]>
1 parent 7037810 commit f43b36b
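For context: a token bitmask packs one allow/deny bit per vocabulary token into int32 words, 32 tokens per word, so token `t` lives in word `t // 32` at bit `t % 32` (this convention is read directly off the conversion loop in the diff below). A tiny illustration of the layout, including why the test below writes `0xFFFF0000` as `-65536`:

```python
import torch

# Token t lives in int32 word t // 32, at bit t % 32.
t = 40
word, bit = t // 32, t % 32  # -> word 1, bit 8

# 0xFFFF0000 (tokens 16..31 of a word allowed) does not fit in a signed
# int32, so it is stored as its two's-complement value -65536.
assert torch.tensor([0xFFFF0000]).to(torch.int32).item() == -65536
```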

File tree

2 files changed, 68 insertions(+), 0 deletions(-)


python/xgrammar/testing.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -262,6 +262,42 @@ def _bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
     return bitmask.to(torch.int32)
 
 
+def _bitmask_to_bool_mask(bit_mask: torch.Tensor, vocab_size: Optional[int] = None) -> torch.Tensor:
+    """
+    Convert a bitmask tensor to a boolean mask tensor.
+
+    Parameters
+    ----------
+    bit_mask : torch.Tensor
+        The bitmask tensor to convert. Should be on CPU and of type int32.
+    vocab_size : Optional[int], default: None
+        The size of the vocabulary. If provided, the output mask will be cut to this size.
+
+    Returns
+    -------
+    bool_mask : torch.Tensor
+        The converted boolean mask tensor.
+    """
+
+    # Validate input.
+    if bit_mask.device.type != "cpu":
+        raise ValueError("bit_mask should be on CPU.")
+    if bit_mask.dtype != bitmask_dtype:
+        raise ValueError("bit_mask should be of type torch.int32.")
+
+    if vocab_size is None:
+        vocab_size = bit_mask.shape[1] * 32
+    if vocab_size > bit_mask.shape[1] * 32:
+        raise ValueError(
+            "vocab_size should be less than or equal to the size represented by bit_mask."
+        )
+
+    bool_mask = torch.zeros((bit_mask.shape[0], vocab_size), dtype=torch.bool)
+    for i in range(vocab_size):
+        bool_mask[:, i] = (bit_mask[:, i // 32] & (1 << (i % 32))) != 0
+    return bool_mask
+
+
 def _get_matcher_from_grammar_and_tokenizer_info(
     grammar: Union[Grammar, str], tokenizer_info: Optional[TokenizerInfo] = None, **kwargs
 ) -> GrammarMatcher:
```
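A minimal usage sketch pairing the new helper with the existing `_bool_mask_to_bitmask` (the packed shape of `ceil(vocab_size / 32)` int32 words per row is inferred from the conversion loop above, not stated explicitly in the diff):

```python
import torch

from xgrammar.testing import _bitmask_to_bool_mask, _bool_mask_to_bitmask

# One batch row, 40-token vocabulary: allow only the first 8 tokens.
bool_mask = torch.zeros((1, 40), dtype=torch.bool)
bool_mask[:, :8] = True

# 40 tokens pack into ceil(40 / 32) = 2 int32 words per row.
bitmask = _bool_mask_to_bitmask(bool_mask)
assert bitmask.shape == (1, 2) and bitmask.dtype == torch.int32

# Without vocab_size, the output covers all 2 * 32 = 64 bits;
# passing vocab_size=40 trims the padding bits back off.
assert _bitmask_to_bool_mask(bitmask).shape == (1, 64)
assert torch.equal(_bitmask_to_bool_mask(bitmask, vocab_size=40), bool_mask)
```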

tests/python/test_token_bitmask_operations.py

Lines changed: 32 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@
 
 import xgrammar as xgr
 from xgrammar.testing import (
+    _bitmask_to_bool_mask,
     _bool_mask_to_bitmask,
     _get_masked_tokens_from_bitmask,
     _is_single_token_bitmask,
@@ -357,5 +358,36 @@ def test_apply_token_bitmask_inplace_indices(
     torch.testing.assert_close(logits, logits_expected)
 
 
+def test_bitmask_to_boolmask():
+    # 0xFFFF0000, 0x0000FFFF
+    bitmask = torch.tensor([[-65536, 65535]], dtype=torch.int32)
+    expected = torch.tensor(
+        [[False] * 16, [True] * 16, [True] * 16, [False] * 16], dtype=torch.bool
+    ).reshape(1, -1)
+    bool_mask = _bitmask_to_bool_mask(bitmask)
+    assert torch.equal(bool_mask, expected)
+
+    bool_mask_50 = _bitmask_to_bool_mask(bitmask, vocab_size=50)
+    expected_50 = expected[:, :50]
+    assert torch.equal(bool_mask_50, expected_50)
+
+
+batch__size__vocab__size = [
+    (4, 1000),
+    (1, 1024),
+    (16, 1024),
+    # not a multiple of 16.
+    (3, 817),
+]
+
+
+@pytest.mark.parametrize("batch_size, vocab_size", batch__size__vocab__size)
+def test_bool_mask_bitmask_roundtrip(batch_size: int, vocab_size: int):
+    bool_mask = torch.randint(0, 2, (batch_size, vocab_size), dtype=torch.bool)
+    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bool_mask_converted = _bitmask_to_bool_mask(bitmask, vocab_size=vocab_size)
+    assert torch.equal(bool_mask, bool_mask_converted)
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
```
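The per-bit Python loop in `_bitmask_to_bool_mask` is deliberately simple, which suits a development/testing helper. For larger vocabularies a vectorized equivalent is possible; the following is an illustrative sketch (not part of this PR), relying on the fact that `(x >> s) & 1` extracts bit `s` of an int32 word even when the word is negative:

```python
import torch

def bitmask_to_bool_mask_vectorized(bit_mask: torch.Tensor, vocab_size=None) -> torch.Tensor:
    """Hypothetical loop-free variant with the same semantics as _bitmask_to_bool_mask."""
    if vocab_size is None:
        vocab_size = bit_mask.shape[1] * 32
    shifts = torch.arange(32, dtype=torch.int32)  # bit positions 0..31
    # (batch, words, 1) >> (32,) broadcasts to (batch, words, 32); & 1 keeps each bit.
    bits = (bit_mask.unsqueeze(-1) >> shifts) & 1
    return bits.reshape(bit_mask.shape[0], -1)[:, :vocab_size].to(torch.bool)
```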

0 commit comments