Skip to content

Commit 702fc5e

Browse files
authored
Fix 50% comparison mismatch in sort_chunks_by_index (#2566)
* force initialization to int32

  Signed-off-by: tdophung <[email protected]>

* address greptile comment

  Signed-off-by: tdophung <[email protected]>

---------

Signed-off-by: tdophung <[email protected]>
1 parent 404a3ee commit 702fc5e

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

tests/jax/test_permutation.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,9 @@ def reference_make_row_id_map(
9797

9898
# Compute total tokens per expert and expert offsets
9999
tokens_per_expert = jnp.sum(routing_map, axis=0)
100-
expert_offsets = jnp.concatenate([jnp.array([0]), jnp.cumsum(tokens_per_expert)[:-1]])
100+
expert_offsets = jnp.concatenate(
101+
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(tokens_per_expert)[:-1].astype(jnp.int32)]
102+
)
101103

102104
# Compute destination rows for all (token, expert) pairs
103105
# dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
@@ -115,7 +117,9 @@ def reference_make_row_id_map(
115117

116118
# Gather the sorted destination rows and expert indices using advanced indexing
117119
# Create indices for gathering
118-
token_idx = jnp.broadcast_to(jnp.arange(num_tokens)[:, None], (num_tokens, num_experts))
120+
token_idx = jnp.broadcast_to(
121+
jnp.arange(num_tokens, dtype=jnp.int32)[:, None], (num_tokens, num_experts)
122+
)
119123
sorted_dest_rows = dest_rows_all[token_idx, sorted_expert_indices]
120124

121125
# Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
@@ -373,11 +377,15 @@ def reference_make_chunk_sort_map(
373377
Row ID map for chunk sorting of shape [num_tokens,].
374378
"""
375379
# Compute source chunk boundaries (cumulative sum of original split_sizes)
376-
src_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(split_sizes)])
380+
src_cumsum = jnp.concatenate(
381+
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(split_sizes).astype(jnp.int32)]
382+
)
377383

378384
# Compute destination chunk boundaries based on sorted order
379385
sorted_sizes = split_sizes[sorted_indices]
380-
dest_cumsum = jnp.concatenate([jnp.array([0]), jnp.cumsum(sorted_sizes)])
386+
dest_cumsum = jnp.concatenate(
387+
[jnp.array([0], dtype=jnp.int32), jnp.cumsum(sorted_sizes).astype(jnp.int32)]
388+
)
381389

382390
# For each source chunk, compute its destination offset
383391
# inverse_indices[i] = position of chunk i in sorted order
@@ -386,7 +394,7 @@ def reference_make_chunk_sort_map(
386394

387395
# Create row_id_map: for each token position, compute its destination
388396
# First, figure out which chunk each position belongs to
389-
position_indices = jnp.arange(num_tokens)
397+
position_indices = jnp.arange(num_tokens, dtype=jnp.int32)
390398

391399
# chunk_ids[i] = which chunk position i belongs to
392400
chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right")

0 commit comments

Comments
 (0)