5
5
@triton .jit
6
6
def streaming_topk (X , stride_xm , n_expts_tot , offs_m , mask_m , N_EXPTS_PAD : tl .constexpr , N_EXPTS_ACT : tl .constexpr ,
7
7
BLOCK_N : tl .constexpr ):
8
+ x_nbits : tl .constexpr = X .dtype .element_ty .primitive_bitwidth
9
+ x_utype : tl .constexpr = tl .dtype (f"uint{ x_nbits } " )
10
+ x_ultype : tl .constexpr = tl .dtype (f"uint{ 2 * x_nbits } " )
11
+ x_dbtype : tl .constexpr = tl .dtype (f"fp{ 2 * x_nbits } " )
8
12
9
13
# subtract 1 from loop iterations because we peel the first (masked) iteration:
10
14
loop_iterations : tl .constexpr = N_EXPTS_PAD // BLOCK_N - 1
@@ -15,8 +19,8 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
15
19
# first iteration:
16
20
X_ptrs = X + offs_m [:, None ] * stride_xm + offs_x_n [None , :]
17
21
x = tl .load (X_ptrs , mask = (mask_m & mask_n ), other = float ("-inf" ))
18
- x = (x .to (tl . uint16 , bitcast = True ).to (tl . int32 ) << 16 ) | offs_x_n [None , :]
19
- x = x .to (tl . float32 , bitcast = True )
22
+ x = (x .to (x_utype , bitcast = True ).to (x_ultype ) << x_nbits ) | offs_x_n [None , :]
23
+ x = x .to (x_dbtype , bitcast = True )
20
24
21
25
acc = tl .topk (x , N_EXPTS_ACT , dim = 1 )
22
26
@@ -26,8 +30,8 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
26
30
X_ptrs -= BLOCK_N
27
31
offs_x_n -= BLOCK_N
28
32
x = tl .load (X_ptrs , mask = mask_m , other = float ("-inf" ))
29
- x = (x .to (tl . uint16 , bitcast = True ).to (tl . int32 ) << 16 ) | offs_x_n [None , :]
30
- x = x .to (tl . float32 , bitcast = True )
33
+ x = (x .to (x_utype , bitcast = True ).to (x_ultype ) << x_nbits ) | offs_x_n [None , :]
34
+ x = x .to (x_dbtype , bitcast = True )
31
35
acc = tl .maximum (acc , tl .topk (x , N_EXPTS_ACT , dim = 1 ))
32
36
33
37
return acc
@@ -43,18 +47,21 @@ def _topk(X, stride_xm, # inputs
43
47
tl .static_assert (BLOCK_N % 32 == 0 )
44
48
tl .static_assert (N_EXPTS_PAD % BLOCK_N == 0 )
45
49
x_dtype : tl .constexpr = X .dtype .element_ty
50
+ x_nbits : tl .constexpr = X .dtype .element_ty .primitive_bitwidth
51
+ x_utype : tl .constexpr = tl .dtype (f"uint{ x_nbits } " )
52
+ x_ultype : tl .constexpr = tl .dtype (f"uint{ 2 * x_nbits } " )
46
53
47
54
# load logits
48
55
offs_m = tl .program_id (0 ) * BLOCK_M + tl .arange (0 , BLOCK_M )
49
56
mask_m = offs_m [:, None ] < n_rows
50
57
y = streaming_topk (X , stride_xm , n_expts_tot , offs_m , mask_m , N_EXPTS_PAD , N_EXPTS_ACT , BLOCK_N )
51
- y = y .to (tl . uint32 , bitcast = True )
58
+ y = y .to (x_ultype , bitcast = True )
52
59
53
60
# sort result in direction of ascending expert index
54
- y = (y << 16 ) | (y >> 16 )
61
+ y = (y << x_nbits ) | (y >> x_nbits )
55
62
y = tl .sort (y , dim = 1 )
56
- y_indices = y >> 16
57
- y_values = (y & 0x0000FFFF ) .to (tl . uint16 ).to (x_dtype , bitcast = True )
63
+ y_indices = y >> x_nbits
64
+ y_values = (y & (( 1 << x_nbits ) - 1 )) .to (x_utype ).to (x_dtype , bitcast = True )
58
65
y_values = tl .softmax (y_values .to (tl .float32 ), dim = 1 , keep_dims = True ).to (x_dtype )
59
66
60
67
# write back
0 commit comments