|
19 | 19 | from utils import assert_allclose, pytest_parametrize_wrapper |
20 | 20 |
|
21 | 21 |
|
22 | | -# ============================================================================= |
23 | | -# Test parameter definitions with L0 (fast) and L2 (comprehensive) levels |
24 | | -# ============================================================================= |
25 | | - |
26 | | -# All dispatch/combine test cases |
27 | 22 | ALL_DISPATCH_COMBINE_CASES = [ |
28 | 23 | (128, 5, 128, 3), |
29 | 24 | (1024, 8, 128, 8), |
|
35 | 30 | "L2": ALL_DISPATCH_COMBINE_CASES, |
36 | 31 | } |
37 | 32 |
|
38 | | -# All sort chunks test cases |
39 | 33 | ALL_SORT_CHUNKS_CASES = [ |
40 | 34 | (8, 4096, 1280), |
41 | 35 | (64, 4096, 4096), |
|
46 | 40 | "L2": ALL_SORT_CHUNKS_CASES, |
47 | 41 | } |
48 | 42 |
|
49 | | -# All dispatch/combine with padding test cases |
50 | 43 | ALL_DISPATCH_COMBINE_PADDING_CASES = [ |
51 | 44 | (128, 5, 128, 3, 8), |
52 | 45 | (1024, 8, 128, 8, 16), |
|
58 | 51 | "L2": ALL_DISPATCH_COMBINE_PADDING_CASES, |
59 | 52 | } |
60 | 53 |
|
61 | | -# Dtypes for testing |
62 | 54 | ALL_DTYPES = [jnp.float32, jnp.bfloat16] |
63 | 55 | DTYPES = { |
64 | 56 | "L0": ALL_DTYPES, |
65 | 57 | "L2": ALL_DTYPES, |
66 | 58 | } |
67 | 59 |
|
68 | | -# With probs options |
69 | 60 | ALL_WITH_PROBS = [True, False] |
70 | 61 | WITH_PROBS = { |
71 | 62 | "L0": [True], |
@@ -389,15 +380,15 @@ def reference_make_chunk_sort_map( |
389 | 380 |
|
390 | 381 | # For each source chunk, compute its destination offset |
391 | 382 | # inverse_indices[i] = position of chunk i in sorted order |
392 | | - inverse_indices = jnp.argsort(sorted_indices) |
| 383 | + inverse_indices = jnp.argsort(sorted_indices).astype(jnp.int32) |
393 | 384 | dest_offsets = dest_cumsum[inverse_indices] |
394 | 385 |
|
395 | 386 | # Create row_id_map: for each token position, compute its destination |
396 | 387 | # First, figure out which chunk each position belongs to |
397 | 388 | position_indices = jnp.arange(num_tokens, dtype=jnp.int32) |
398 | 389 |
|
399 | 390 | # chunk_ids[i] = which chunk position i belongs to |
400 | | - chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right") |
| 391 | + chunk_ids = jnp.searchsorted(src_cumsum[1:], position_indices, side="right").astype(jnp.int32) |
401 | 392 |
|
402 | 393 | # within_chunk_offset[i] = position i's offset within its chunk |
403 | 394 | within_chunk_offset = position_indices - src_cumsum[chunk_ids] |
|
0 commit comments