@@ -97,7 +97,9 @@ def reference_make_row_id_map(
9797
9898 # Compute total tokens per expert and expert offsets
9999 tokens_per_expert = jnp .sum (routing_map , axis = 0 )
100- expert_offsets = jnp .concatenate ([jnp .array ([0 ]), jnp .cumsum (tokens_per_expert )[:- 1 ]])
100+ expert_offsets = jnp .concatenate (
101+ [jnp .array ([0 ], dtype = jnp .int32 ), jnp .cumsum (tokens_per_expert )[:- 1 ].astype (jnp .int32 )]
102+ )
101103
102104 # Compute destination rows for all (token, expert) pairs
103105 # dest_row[i, j] = expert_offsets[j] + cumsum_per_expert[i, j] - 1 if routed, else -1
@@ -115,7 +117,9 @@ def reference_make_row_id_map(
115117
116118 # Gather the sorted destination rows and expert indices using advanced indexing
117119 # Create indices for gathering
118- token_idx = jnp .broadcast_to (jnp .arange (num_tokens )[:, None ], (num_tokens , num_experts ))
120+ token_idx = jnp .broadcast_to (
121+ jnp .arange (num_tokens , dtype = jnp .int32 )[:, None ], (num_tokens , num_experts )
122+ )
119123 sorted_dest_rows = dest_rows_all [token_idx , sorted_expert_indices ]
120124
121125 # Build row_id_map: [dest_row_0, ..., dest_row_{E-1}, expert_idx_0, ..., expert_idx_{E-1}, n_routed]
@@ -373,11 +377,15 @@ def reference_make_chunk_sort_map(
373377 Row ID map for chunk sorting of shape [num_tokens,].
374378 """
375379 # Compute source chunk boundaries (cumulative sum of original split_sizes)
376- src_cumsum = jnp .concatenate ([jnp .array ([0 ]), jnp .cumsum (split_sizes )])
380+ src_cumsum = jnp .concatenate (
381+ [jnp .array ([0 ], dtype = jnp .int32 ), jnp .cumsum (split_sizes ).astype (jnp .int32 )]
382+ )
377383
378384 # Compute destination chunk boundaries based on sorted order
379385 sorted_sizes = split_sizes [sorted_indices ]
380- dest_cumsum = jnp .concatenate ([jnp .array ([0 ]), jnp .cumsum (sorted_sizes )])
386+ dest_cumsum = jnp .concatenate (
387+ [jnp .array ([0 ], dtype = jnp .int32 ), jnp .cumsum (sorted_sizes ).astype (jnp .int32 )]
388+ )
381389
382390 # For each source chunk, compute its destination offset
383391 # inverse_indices[i] = position of chunk i in sorted order
@@ -386,7 +394,7 @@ def reference_make_chunk_sort_map(
386394
387395 # Create row_id_map: for each token position, compute its destination
388396 # First, figure out which chunk each position belongs to
389- position_indices = jnp .arange (num_tokens )
397+ position_indices = jnp .arange (num_tokens , dtype = jnp . int32 )
390398
391399 # chunk_ids[i] = which chunk position i belongs to
392400 chunk_ids = jnp .searchsorted (src_cumsum [1 :], position_indices , side = "right" )
0 commit comments