@@ -56,162 +56,87 @@ def gating_softmax_topk(  # pylint: disable=too-many-statements
     index_dtype = "int32"
 
     TX = 1024
-    SCAN_LEN_2 = 2
-    SCAN_LEN_4 = 4
 
-    # specialized kernel for top 2 case
-    @T.prim_func(private=True)
-    def top2_softmax_norm_func(
-        var_x: T.handle,
-        var_out: T.handle,
-        var_out_index: T.handle,
-    ) -> None:
-        T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
-        batch_size = T.int64()
-        x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
-        out = T.match_buffer(var_out, (batch_size, SCAN_LEN_2), dtype)
-        out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_2), index_dtype)
-        local_top_k = T.alloc_buffer((SCAN_LEN_2,), dtype=dtype, scope="local")
-        local_top_k_index = T.alloc_buffer((SCAN_LEN_2,), dtype=index_dtype, scope="local")
-        local_top_k_f32 = T.alloc_buffer((SCAN_LEN_2,), dtype="float32", scope="local")
-        local_top_k_max = T.alloc_buffer((1,), dtype="float32", scope="local")
-        for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
-            for ii in T.thread_binding(0, TX, "threadIdx.x"):
-                with T.block("top_k"):
-                    vi = T.axis.spatial(batch_size, io * TX + ii)
-                    T.where(io * TX + ii < batch_size)
-                    with T.block("init"):
-                        local_top_k[0] = T.min_value(dtype)
-                        local_top_k[1] = T.min_value(dtype)
-                        local_top_k_index[0] = 0
-                        local_top_k_index[1] = 1
-                    for k in range(num_local_experts):
-                        with T.block("update"):
-                            vk = T.axis.remap("S", [k])
-                            # N.B. This snippet is specialized for k = 2
-                            if x[vi, vk] > local_top_k[0]:
-                                local_top_k[1] = local_top_k[0]
-                                local_top_k_index[1] = local_top_k_index[0]
-                                local_top_k[0] = x[vi, vk]
-                                local_top_k_index[0] = vk
-                            elif x[vi, vk] > local_top_k[1]:
-                                local_top_k[1] = x[vi, vk]
-                                local_top_k_index[1] = vk
-                    for j in T.unroll(SCAN_LEN_2):
-                        with T.block("cast"):
-                            vj = T.axis.remap("S", [j])
-                            local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32")
-                    with T.block("max"):
-                        local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1])
-                    for j in T.unroll(SCAN_LEN_2):
-                        with T.block("output"):
-                            vj = T.axis.remap("S", [j])
-                            out[vi, vj] = T.cast(
-                                T.exp(local_top_k_f32[vj] - local_top_k_max[0])
-                                / (
-                                    T.exp(local_top_k_f32[0] - local_top_k_max[0])
-                                    + T.exp(local_top_k_f32[1] - local_top_k_max[0])
-                                ),
-                                dtype,
-                            )
-                            out_index[vi, vj] = local_top_k_index[vj]
-
-    # specialized kernel for top 4 case
-    @T.prim_func(private=True)
-    def top4_softmax_norm_func(
-        var_x: T.handle,
-        var_out: T.handle,
-        var_out_index: T.handle,
-    ) -> None:
-        T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
-        batch_size = T.int64()
-        x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
-        out = T.match_buffer(var_out, (batch_size, SCAN_LEN_4), dtype)
-        out_index = T.match_buffer(var_out_index, (batch_size, SCAN_LEN_4), index_dtype)
-        local_top_k = T.alloc_buffer((SCAN_LEN_4,), dtype=dtype, scope="local")
-        local_top_k_index = T.alloc_buffer((SCAN_LEN_4,), dtype=index_dtype, scope="local")
-        for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
-            for ii in T.thread_binding(0, TX, "threadIdx.x"):
-                with T.block("top_k"):
-                    vi = T.axis.spatial(batch_size, io * TX + ii)
-                    T.where(io * TX + ii < batch_size)
-                    with T.block("init"):
-                        local_top_k[0] = T.min_value(dtype)
-                        local_top_k[1] = T.min_value(dtype)
-                        local_top_k[2] = T.min_value(dtype)
-                        local_top_k[3] = T.min_value(dtype)
-                        local_top_k_index[0] = 0
-                        local_top_k_index[1] = 1
-                        local_top_k_index[2] = 2
-                        local_top_k_index[3] = 3
-                    for k in range(num_local_experts):
-                        with T.block("update"):
-                            vk = T.axis.remap("S", [k])
-                            # N.B. This snippet is specialized for k = 4
-                            if x[vi, vk] > local_top_k[0]:
-                                local_top_k[3] = local_top_k[2]
-                                local_top_k_index[3] = local_top_k_index[2]
-                                local_top_k[2] = local_top_k[1]
-                                local_top_k_index[2] = local_top_k_index[1]
-                                local_top_k[1] = local_top_k[0]
-                                local_top_k_index[1] = local_top_k_index[0]
-                                local_top_k[0] = x[vi, vk]
-                                local_top_k_index[0] = vk
-                            elif x[vi, vk] > local_top_k[1]:
-                                local_top_k[3] = local_top_k[2]
-                                local_top_k_index[3] = local_top_k_index[2]
-                                local_top_k[2] = local_top_k[1]
-                                local_top_k_index[2] = local_top_k_index[1]
-                                local_top_k[1] = x[vi, vk]
-                                local_top_k_index[1] = vk
-                            elif x[vi, vk] > local_top_k[2]:
-                                local_top_k[3] = local_top_k[2]
-                                local_top_k_index[3] = local_top_k_index[2]
-                                local_top_k[2] = x[vi, vk]
-                                local_top_k_index[2] = vk
-                            elif x[vi, vk] > local_top_k[3]:
-                                local_top_k[3] = x[vi, vk]
-                                local_top_k_index[3] = vk
-                    for j in T.unroll(SCAN_LEN_4):
-                        with T.block("output"):
-                            vj = T.axis.remap("S", [j])
-                            out[vi, vj] = local_top_k[vj]
-                            out_index[vi, vj] = local_top_k_index[vj]
-
-    # fast path for Mixtral
-    if k == 2 and norm_topk_prob:
+    def _get_topk_softmax_norm_func(k_val: int):
+        def _init_local_top_k(local_top_k, local_top_k_index):
+            for t in range(k_val):
+                T.buffer_store(local_top_k, T.min_value(dtype), indices=[t])
+            for t in range(k_val):
+                T.buffer_store(local_top_k_index, t, indices=[t])
+
+        def _process_value(x, local_top_k, local_top_k_index, vi, vk):
+            if_frames = [T.If(x[vi, vk] > local_top_k[i]) for i in range(k_val)]
+            then_frames = [T.Then() for _ in range(k_val)]
+            else_frames = [T.Else() for _ in range(k_val - 1)]
+            for i in range(k_val):
+                if_frames[i].__enter__()  # pylint: disable=unnecessary-dunder-call
+                with then_frames[i]:
+                    for j in range(k_val - 1, i, -1):
+                        T.buffer_store(local_top_k, local_top_k[j - 1], indices=[j])
+                        T.buffer_store(local_top_k_index, local_top_k_index[j - 1], indices=[j])
+                    T.buffer_store(local_top_k, x[vi, vk], indices=[i])
+                    T.buffer_store(local_top_k_index, vk, indices=[i])
+                if i != k_val - 1:
+                    else_frames[i].__enter__()  # pylint: disable=unnecessary-dunder-call
+
+            for i in range(k_val - 1, -1, -1):
+                if i != k_val - 1:
+                    else_frames[i].__exit__(None, None, None)
+                if_frames[i].__exit__(None, None, None)
+
+        @T.prim_func(private=True)
+        def topk_softmax_norm_func(
+            var_x: T.handle,
+            var_out: T.handle,
+            var_out_index: T.handle,
+        ) -> None:
+            T.func_attr({"tir.noalias": True, "tir.is_scheduled": True})
+            batch_size = T.int64()
+            x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)
+            out = T.match_buffer(var_out, (batch_size, k_val), dtype)
+            out_index = T.match_buffer(var_out_index, (batch_size, k_val), index_dtype)
+            local_top_k = T.alloc_buffer((k_val,), dtype=dtype, scope="local")
+            local_top_k_index = T.alloc_buffer((k_val,), dtype=index_dtype, scope="local")
+            for io in T.thread_binding(0, T.ceildiv(batch_size, TX), "blockIdx.x"):
+                for ii in T.thread_binding(0, TX, "threadIdx.x"):
+                    with T.block("top_k"):
+                        vi = T.axis.spatial(batch_size, io * TX + ii)
+                        T.where(io * TX + ii < batch_size)
+                        with T.block("init"):
+                            _init_local_top_k(local_top_k, local_top_k_index)
+                        for k in range(num_local_experts):
+                            with T.block("update"):
+                                vk = T.axis.remap("S", [k])
+                                _process_value(x, local_top_k, local_top_k_index, vi, vk)
+                        for j in T.unroll(k_val):
+                            with T.block("output"):
+                                vj = T.axis.remap("S", [j])
+                                out[vi, vj] = local_top_k[vj]
+                                out_index[vi, vj] = local_top_k_index[vj]
+
+        return topk_softmax_norm_func
+
+    if norm_topk_prob:
         return op.tensor_ir_op(
-            top2_softmax_norm_func,
-            "top2_softmax",
+            _get_topk_softmax_norm_func(k),
+            f"top{k}_softmax",
             args=[x],
             out=(
-                Tensor.placeholder([batch_size, 2], dtype),
-                Tensor.placeholder([batch_size, 2], index_dtype),
-            ),
-        )
-    if k == 4 and not norm_topk_prob:
-        expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
-        return op.tensor_ir_op(
-            top4_softmax_norm_func,
-            "top4_softmax",
-            args=[expert_score],
-            out=(
-                Tensor.placeholder([batch_size, 4], dtype),
-                Tensor.placeholder([batch_size, 4], index_dtype),
+                Tensor.placeholder([batch_size, k], dtype),
+                Tensor.placeholder([batch_size, k], index_dtype),
             ),
         )
-    if norm_topk_prob:
-        # Compute topk first and then softmax to avoid extra re-normalize
-        expert_score, expert_indices = op.topk(
-            x, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype
-        )
-        expert_score = op.softmax(expert_score.astype("float32"), axis=-1).astype(dtype)
-    else:
-        expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
-        expert_score, expert_indices = op.topk(
-            expert_score, k, axis=-1, ret_type="both", largest=True, dtype=index_dtype
-        )
-    return expert_score, expert_indices
+
+    expert_score = op.softmax(x.astype("float32"), axis=-1).astype(dtype)
+    return op.tensor_ir_op(
+        _get_topk_softmax_norm_func(k),
+        f"top{k}_softmax",
+        args=[expert_score],
+        out=(
+            Tensor.placeholder([batch_size, k], dtype),
+            Tensor.placeholder([batch_size, k], index_dtype),
+        ),
+    )
 
 
 def moe_cumsum(expert_indices: Tensor, num_local_experts: int) -> Tensor:
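
Note on the refactor above: `_process_value` uses TVMScript `T.If` / `T.Then` / `T.Else` frames to emit, for a fixed `k_val`, the same nested if/elif insertion chain that the deleted `top2_softmax_norm_func` and `top4_softmax_norm_func` kernels spelled out by hand: compare the incoming expert score against each slot of a descending-sorted length-`k_val` buffer and, on the first hit, shift the lower slots down before inserting. A minimal plain-Python sketch of that insertion step follows; the helper name `insert_topk` and the driver loop are illustrative assumptions, not code from this PR.

```python
# Illustrative sketch only: a plain-Python analogue of the if/elif chain that
# _process_value generates for a fixed k_val. Names here are hypothetical.
def insert_topk(top_val: list, top_idx: list, value: float, index: int) -> None:
    """Insert (value, index) into descending-sorted top-k buffers, in place."""
    k_val = len(top_val)
    for i in range(k_val):  # corresponds to the chain of T.If frames
        if value > top_val[i]:
            # Shift lower-ranked entries down one slot (j runs k_val-1 .. i+1),
            # mirroring the inner reverse loop of T.buffer_store calls.
            for j in range(k_val - 1, i, -1):
                top_val[j] = top_val[j - 1]
                top_idx[j] = top_idx[j - 1]
            top_val[i] = value
            top_idx[i] = index
            break  # the nested T.Else frames ensure at most one branch fires


# Example: scanning six expert scores with k_val = 4, as the "update" block does.
top_val, top_idx = [float("-inf")] * 4, [0, 1, 2, 3]
for expert, score in enumerate([0.1, 0.9, 0.4, 0.7, 0.2, 0.8]):
    insert_topk(top_val, top_idx, score, expert)
assert top_idx == [1, 5, 3, 2]  # the four highest-scoring experts, in order
assert top_val == [0.9, 0.8, 0.7, 0.4]
```

As the diff shows, both call sites now build the kernel through `_get_topk_softmax_norm_func(k)`; when `norm_topk_prob` is false, the gating scores are first passed through `op.softmax` over all experts and the generated kernel then selects the top-k entries from the softmaxed scores.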