|
32 | 32 | }) |
33 | 33 | @triton.autotune( |
34 | 34 | configs=[ |
35 | | - triton.Config({}, num_warps=num_warps) |
36 | | - for num_warps in [1, 2, 4] |
| 35 | + triton.Config({}, num_warps=num_warps, num_stages=num_stages) |
| 36 | + for num_warps in [1, 2, 4, 8] |
| 37 | + for num_stages in [2, 3, 4, 5] |
37 | 38 | ], |
38 | 39 | key=['BS', 'BK', 'BV'], |
39 | 40 | ) |
@@ -131,8 +132,9 @@ def parallel_nsa_compression_fwd_kernel( |
131 | 132 | }) |
132 | 133 | @triton.autotune( |
133 | 134 | configs=[ |
134 | | - triton.Config({}, num_warps=num_warps) |
| 135 | + triton.Config({}, num_warps=num_warps, num_stages=num_stages) |
135 | 136 | for num_warps in [1, 2, 4, 8] |
| 137 | + for num_stages in [2, 3, 4, 5] |
136 | 138 | ], |
137 | 139 | key=['BS', 'BK', 'BV'], |
138 | 140 | ) |
@@ -327,8 +329,9 @@ def parallel_nsa_compression_bwd_kernel_dkv( |
327 | 329 | }) |
328 | 330 | @triton.autotune( |
329 | 331 | configs=[ |
330 | | - triton.Config({}, num_warps=num_warps) |
331 | | - for num_warps in [1, 2, 4] |
| 332 | + triton.Config({}, num_warps=num_warps, num_stages=num_stages) |
| 333 | + for num_warps in [1, 2, 4, 8] |
| 334 | + for num_stages in [2, 3, 4, 5] |
332 | 335 | ], |
333 | 336 | key=['BS', 'BK'], |
334 | 337 | ) |
@@ -603,8 +606,9 @@ def parallel_nsa_bwd_kernel_preprocess( |
603 | 606 | }) |
604 | 607 | @triton.autotune( |
605 | 608 | configs=[ |
606 | | - triton.Config({}, num_warps=num_warps) |
| 609 | + triton.Config({}, num_warps=num_warps, num_stages=num_stages) |
607 | 610 | for num_warps in [1, 2, 4, 8] |
| 611 | + for num_stages in [2, 3, 4, 5] |
608 | 612 | ], |
609 | 613 | key=['BS', 'BK', 'BV'], |
610 | 614 | ) |
@@ -710,8 +714,9 @@ def parallel_nsa_bwd_kernel_dq( |
710 | 714 | }) |
711 | 715 | @triton.autotune( |
712 | 716 | configs=[ |
713 | | - triton.Config({}, num_warps=num_warps) |
| 717 | + triton.Config({}, num_warps=num_warps, num_stages=num_stages) |
714 | 718 | for num_warps in [1, 2, 4, 8] |
| 719 | + for num_stages in [2, 3, 4, 5] |
715 | 720 | ], |
716 | 721 | key=['BS', 'BK', 'BV'], |
717 | 722 | ) |
@@ -1030,7 +1035,7 @@ def parallel_nsa_topk( |
1030 | 1035 | BC = BS = block_size |
1031 | 1036 | BK = triton.next_power_of_2(K) |
1032 | 1037 |
|
1033 | | - block_indices = torch.zeros(B, T, H, S, dtype=torch.long, device=q.device) |
| 1038 | + block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=q.device) |
1034 | 1039 | token_indices = prepare_token_indices(offsets) if offsets is not None else None |
1035 | 1040 | grid = (T, B * H) |
1036 | 1041 | parallel_nsa_kernel_topk[grid]( |
|
0 commit comments