
Commit 0431333

[Format] Rename two testing functions to expose them. (#434)
This PR renames `_bool_mask_to_bitmask` to `bool_mask_to_bitmask` and `_bitmask_to_bool_mask` to `bitmask_to_bool_mask`, dropping the leading underscores so both helpers are exposed as public testing utilities.

Signed-off-by: Yuchuan <[email protected]>
1 parent 148c3e7 · commit 0431333

2 files changed: 15 additions, 15 deletions

python/xgrammar/testing.py (2 additions, 2 deletions)

@@ -233,7 +233,7 @@ def _is_single_token_bitmask(
     )
 
 
-def _bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
+def bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
     """Get the bitmask from bool mask. If the bool mask does not align with the 32-bit block
     size, it will add extra 1 paddings.
@@ -262,7 +262,7 @@ def _bool_mask_to_bitmask(bool_mask: torch.Tensor) -> torch.Tensor:
     return bitmask.to(torch.int32)
 
 
-def _bitmask_to_bool_mask(bit_mask: torch.Tensor, vocab_size: Optional[int] = None) -> torch.Tensor:
+def bitmask_to_bool_mask(bit_mask: torch.Tensor, vocab_size: Optional[int] = None) -> torch.Tensor:
     """
     Convert a bitmask tensor to a boolean mask tensor.
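The renamed `bool_mask_to_bitmask` packs each run of 32 booleans into one int32 word and, as its docstring states, pads with 1-bits when the vocabulary size is not a multiple of 32. A minimal sketch of that packing, for illustration only: `pack_bool_mask` is a hypothetical name, the real helper is `xgrammar.testing.bool_mask_to_bitmask`, and the LSB-first bit order is inferred from `test_bitmask_to_boolmask` in the test file below.

```python
import torch


def pack_bool_mask(bool_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical re-implementation for illustration; the real helper is
    # xgrammar.testing.bool_mask_to_bitmask.
    batch_size, vocab_size = bool_mask.shape
    padded_size = (vocab_size + 31) // 32 * 32
    # Pad with True so the spare bits come out as 1 ("extra 1 paddings").
    padding = torch.ones(batch_size, padded_size - vocab_size, dtype=torch.bool)
    bits = torch.cat([bool_mask, padding], dim=1).view(batch_size, -1, 32).to(torch.int64)
    # Bit i of word j covers token 32 * j + i (LSB-first, inferred from the tests).
    weights = 2 ** torch.arange(32, dtype=torch.int64)
    return (bits * weights).sum(dim=-1).to(torch.int32)
```

Since a set bit means "token allowed" in the tests (logits are masked where the bool mask is False), padding with 1-bits presumably leaves the out-of-vocabulary slots untouched when the mask is applied.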

tests/python/test_token_bitmask_operations.py (13 additions, 13 deletions)

@@ -9,10 +9,10 @@
 
 import xgrammar as xgr
 from xgrammar.testing import (
-    _bitmask_to_bool_mask,
-    _bool_mask_to_bitmask,
     _get_masked_tokens_from_bitmask,
     _is_single_token_bitmask,
+    bitmask_to_bool_mask,
+    bool_mask_to_bitmask,
 )
 
 _is_cuda_available = torch.cuda.is_available()
@@ -38,7 +38,7 @@ def test_allocate_reset_token_bitmask():
 @pytest.mark.parametrize("index", (0, 1))
 def test_get_masked_tokens_from_bitmask(token_mask_size: int, index: int):
     bool_mask = torch.randint(0, 2, (2, token_mask_size), dtype=torch.bool)
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     expected = torch.where(~bool_mask[index])[0].tolist()
     assert _get_masked_tokens_from_bitmask(bitmask, token_mask_size, index) == expected
 
@@ -50,13 +50,13 @@ def test_is_single_token_bitmask():
     token_id = 100
 
     bool_mask = torch.zeros(batch, vocab_size, dtype=torch.bool)
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (False, -1)
     bool_mask[batch_index, token_id] = True
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (True, token_id)
     bool_mask[batch_index, token_id + 1] = True
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     assert _is_single_token_bitmask(bitmask, vocab_size, batch_index) == (False, -1)
 
 
@@ -229,7 +229,7 @@ def test_apply_token_bitmask_inplace_kernel_large(
     bool_mask.scatter_(1, masked_positions, False)
     assert (bool_mask.sum(dim=-1) + masked_cnt == vocab_size).all().item()
 
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
 
     batch_indices = torch.arange(0, batch_size, stride, dtype=torch.int32)
 
@@ -238,7 +238,7 @@ def test_apply_token_bitmask_inplace_kernel_large(
         logits_expected[batch_indices], ~bool_mask[batch_indices], float("-inf")
     )
 
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
     if impl in ["cuda", "triton", "torch_compile"]:
         logits_gpu = logits.to("cuda")
         bitmask_gpu = bitmask.to("cuda")
@@ -340,7 +340,7 @@ def test_apply_token_bitmask_inplace_indices(
 
     logits = torch.ones(logits_batch_size, vocab_size, dtype=torch.float32)
     bool_mask = torch.zeros(bitmask_batch_size, vocab_size, dtype=torch.bool)
-    bitmask = _bool_mask_to_bitmask(bool_mask)
+    bitmask = bool_mask_to_bitmask(bool_mask)
 
     logits_expected = logits.clone()
     logits_expected[indices] = torch.masked_fill(
@@ -364,10 +364,10 @@ def test_bitmask_to_boolmask():
     expected = torch.tensor(
         [[False] * 16, [True] * 16, [True] * 16, [False] * 16], dtype=torch.bool
     ).reshape(1, -1)
-    bool_mask = _bitmask_to_bool_mask(bitmask)
+    bool_mask = bitmask_to_bool_mask(bitmask)
     assert torch.equal(bool_mask, expected)
 
-    bool_mask_50 = _bitmask_to_bool_mask(bitmask, vocab_size=50)
+    bool_mask_50 = bitmask_to_bool_mask(bitmask, vocab_size=50)
     expected_50 = expected[:, :50]
     assert torch.equal(bool_mask_50, expected_50)
 
@@ -384,8 +384,8 @@ def test_bitmask_to_boolmask():
 @pytest.mark.parametrize("batch_size, vocab_size", batch__size__vocab__size)
 def test_bool_mask_bitmask_roundtrip(batch_size: int, vocab_size: int):
     bool_mask = torch.randint(0, 2, (batch_size, vocab_size), dtype=torch.bool)
-    bitmask = _bool_mask_to_bitmask(bool_mask)
-    bool_mask_converted = _bitmask_to_bool_mask(bitmask, vocab_size=vocab_size)
+    bitmask = bool_mask_to_bitmask(bool_mask)
+    bool_mask_converted = bitmask_to_bool_mask(bitmask, vocab_size=vocab_size)
     assert torch.equal(bool_mask, bool_mask_converted)
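Since the tests now exercise the public names, downstream code can do the same round trip directly. A short usage sketch based on `test_bool_mask_bitmask_roundtrip` above, using the signatures shown in the diff (the 50-token vocabulary is an arbitrary choice):

```python
import torch

from xgrammar.testing import bitmask_to_bool_mask, bool_mask_to_bitmask

bool_mask = torch.randint(0, 2, (1, 50), dtype=torch.bool)
bitmask = bool_mask_to_bitmask(bool_mask)                 # shape (1, 2): two int32 words cover 50 tokens
recovered = bitmask_to_bool_mask(bitmask, vocab_size=50)  # vocab_size trims the 14 padding bits
assert torch.equal(bool_mask, recovered)
```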
