Skip to content

Commit 08dc786

Browse files
authored
Fix 50% comparison mismatch in sort_chunks_by_index (Cont.) (#2575)
* force initialization to int32 Signed-off-by: tdophung <[email protected]> * address greptile comment Signed-off-by: tdophung <[email protected]> * del useless comments, add more restriction to int32 Signed-off-by: tdophung <[email protected]> --------- Signed-off-by: tdophung <[email protected]>
1 parent de51c96 commit 08dc786

File tree

1 file changed

+2
-11
lines changed

1 file changed

+2
-11
lines changed

tests/jax/test_permutation.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,6 @@
1919
from utils import assert_allclose, pytest_parametrize_wrapper
2020

2121

22-
# =============================================================================
23-
# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels
24-
# =============================================================================
25-
26-
# All dispatch/combine test cases
2722
ALL_DISPATCH_COMBINE_CASES = [
2823
(128, 5, 128, 3),
2924
(1024, 8, 128, 8),
@@ -35,7 +30,6 @@
3530
"L2": ALL_DISPATCH_COMBINE_CASES,
3631
}
3732

38-
# All sort chunks test cases
3933
ALL_SORT_CHUNKS_CASES = [
4034
(8, 4096, 1280),
4135
(64, 4096, 4096),
@@ -46,7 +40,6 @@
4640
"L2": ALL_SORT_CHUNKS_CASES,
4741
}
4842

49-
# All dispatch/combine with padding test cases
5043
ALL_DISPATCH_COMBINE_PADDING_CASES = [
5144
(128, 5, 128, 3, 8),
5245
(1024, 8, 128, 8, 16),
@@ -58,14 +51,12 @@
5851
"L2": ALL_DISPATCH_COMBINE_PADDING_CASES,
5952
}
6053

61-
# Dtypes for testing
6254
ALL_DTYPES = [jnp.float32, jnp.bfloat16]
6355
DTYPES = {
6456
"L0": ALL_DTYPES,
6557
"L2": ALL_DTYPES,
6658
}
6759

68-
# With probs options
6960
ALL_WITH_PROBS = [True, False]
7061
WITH_PROBS = {
7162
"L0": [True],
@@ -389,15 +380,15 @@ def reference_make_chunk_sort_map(
389380

390381
# For each source chunk, compute its destination offset
391382
# inverse_indices[i] = position of chunk i in sorted order
392-
inverse_indices = jnp.argsort(sorted_indices)
383+
inverse_indices = jnp.argsort(sorted_indices).astype(jnp.int32)
393384
dest_offsets = dest_cumsum[inverse_indices]
394385

395386
# Create row_id_map: for each token position, compute its destination
396387
# First, figure out which chunk each position belongs to
397388
position_indices = jnp.arange(num_tokens, dtype=jnp.int32)
398389

399390
# chunk_ids[i] = which chunk position i belongs to
400-
chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right")
391+
chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right").astype(jnp.int32)
401392

402393
# within_chunk_offset[i] = position i's offset within its chunk
403394
within_chunk_offset = position_indices - src_cumsum[chunk_ids]

0 commit comments

Comments
 (0)