@@ -81,12 +81,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
8181 b_h4 = tl .zeros ([64 , BV ], dtype = tl .float32 )
8282
8383 # calculate offset
84- h += (boh * H + i_h ) * K * V
85- v += (bos * H + i_h ) * V
86- k += (bos * H + i_h ) * K
87- w += (bos * H + i_h ) * K
84+ h += (( boh * H + i_h ) * K * V ). to ( tl . int64 )
85+ v += (( bos * H + i_h ) * V ). to ( tl . int64 )
86+ k += (( bos * H + i_h ) * K ). to ( tl . int64 )
87+ w += (( bos * H + i_h ) * K ). to ( tl . int64 )
8888 if SAVE_NEW_VALUE :
89- v_new += (bos * H + i_h ) * V
89+ v_new += (( bos * H + i_h ) * V ). to ( tl . int64 )
9090 stride_v = H * V
9191 stride_h = H * K * V
9292 stride_k = H * K
@@ -181,30 +181,18 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
181181
182182 p_k = tl .make_block_ptr (k , (K , T ), (1 , stride_k ), (0 , i_t * BT ), (64 , BT ), (0 , 1 ))
183183 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
184- if USE_GK :
185- p_g = tl .make_block_ptr (gk + (bos * H + i_h ) * K , (K , T ), (1 , H * K ), (0 , i_t * BT ), (64 , BT ), (0 , 1 ))
186- b_k = (b_k * exp (b_gk_last1 [:, None ] - tl .load (p_g , boundary_check = (0 , 1 )))).to (b_k .dtype )
187184 b_h1 += tl .dot (b_k , b_v )
188185 if K > 64 :
189186 p_k = tl .make_block_ptr (k , (K , T ), (1 , stride_k ), (64 , i_t * BT ), (64 , BT ), (0 , 1 ))
190187 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
191- if USE_GK :
192- p_g = tl .make_block_ptr (gk + (bos * H + i_h ) * K , (K , T ), (1 , H * K ), (64 , i_t * BT ), (64 , BT ), (0 , 1 ))
193- b_k = (b_k * exp (b_gk_last2 [:, None ] - tl .load (p_g , boundary_check = (0 , 1 )))).to (b_k .dtype )
194188 b_h2 += tl .dot (b_k , b_v )
195189 if K > 128 :
196190 p_k = tl .make_block_ptr (k , (K , T ), (1 , stride_k ), (128 , i_t * BT ), (64 , BT ), (0 , 1 ))
197191 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
198- if USE_GK :
199- p_g = tl .make_block_ptr (gk + (bos * H + i_h ) * K , (K , T ), (1 , H * K ), (128 , i_t * BT ), (64 , BT ), (0 , 1 ))
200- b_k = (b_k * exp (b_gk_last3 [:, None ] - tl .load (p_g , boundary_check = (0 , 1 )))).to (b_k .dtype )
201192 b_h3 += tl .dot (b_k , b_v )
202193 if K > 192 :
203194 p_k = tl .make_block_ptr (k , (K , T ), (1 , stride_k ), (192 , i_t * BT ), (64 , BT ), (0 , 1 ))
204195 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
205- if USE_GK :
206- p_g = tl .make_block_ptr (gk + (bos * H + i_h ) * K , (K , T ), (1 , H * K ), (192 , i_t * BT ), (64 , BT ), (0 , 1 ))
207- b_k = (b_k * exp (b_gk_last4 [:, None ] - tl .load (p_g , boundary_check = (0 , 1 )))).to (b_k .dtype )
208196 b_h4 += tl .dot (b_k , b_v )
209197 # epilogue
210198 if STORE_FINAL_STATE :
@@ -223,6 +211,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
223211
224212@triton .heuristics ({
225213 'USE_G' : lambda args : args ['g' ] is not None ,
214+ 'USE_GK' : lambda args : args ['gk' ] is not None ,
226215 'USE_INITIAL_STATE' : lambda args : args ['dh0' ] is not None ,
227216 'USE_FINAL_STATE_GRADIENT' : lambda args : args ['dht' ] is not None ,
228217 'IS_VARLEN' : lambda args : args ['cu_seqlens' ] is not None ,
@@ -244,6 +233,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
244233 k ,
245234 w ,
246235 g ,
236+ gk ,
247237 dht ,
248238 dh0 ,
249239 do ,
@@ -260,6 +250,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
260250 BT : tl .constexpr ,
261251 BV : tl .constexpr ,
262252 USE_G : tl .constexpr ,
253+ USE_GK : tl .constexpr ,
263254 USE_INITIAL_STATE : tl .constexpr ,
264255 USE_FINAL_STATE_GRADIENT : tl .constexpr ,
265256 IS_VARLEN : tl .constexpr
@@ -286,13 +277,16 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
286277 b_dh4 = tl .zeros ([64 , BV ], dtype = tl .float32 )
287278
288279 # calculate offset
289- dh += (boh * H + i_h ) * K * V
290- dv += (bos * H + i_h ) * V
291- dv2 += (bos * H + i_h ) * V
292- q += (bos * H + i_h ) * K
293- k += (bos * H + i_h ) * K
294- w += (bos * H + i_h ) * K
295- do += (bos * H + i_h ) * V
280+ q += ((bos * H + i_h ) * K ).to (tl .int64 )
281+ k += ((bos * H + i_h ) * K ).to (tl .int64 )
282+ w += ((bos * H + i_h ) * K ).to (tl .int64 )
283+ do += ((bos * H + i_h ) * V ).to (tl .int64 )
284+ dv += ((bos * H + i_h ) * V ).to (tl .int64 )
285+ dv2 += ((bos * H + i_h ) * V ).to (tl .int64 )
286+ dh += ((boh * H + i_h ) * K * V ).to (tl .int64 )
287+ if USE_GK :
288+ gk += ((bos * H + i_h ) * K ).to (tl .int64 )
289+
296290 stride_v = H * V
297291 stride_h = H * K * V
298292 stride_k = H * K
@@ -327,44 +321,50 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
327321 p_dh4 = tl .make_block_ptr (dh + i_t * stride_h , (K , V ), (V , 1 ), (192 , i_v * BV ), (64 , BV ), (1 , 0 ))
328322 tl .store (p_dh4 , b_dh4 .to (p_dh4 .dtype .element_ty ), boundary_check = (0 , 1 ))
329323
324+ last_idx = min ((i_t + 1 ) * BT , T ) - 1
330325 if USE_G :
331- last_idx = min ((i_t + 1 ) * BT , T ) - 1
332326 bg_last = tl .load (g + (bos + last_idx ) * H + i_h )
333327 bg_last_exp = exp (bg_last )
334328 p_g = tl .make_block_ptr (g + bos * H + i_h , (T ,), (H ,), (i_t * BT ,), (BT ,), (0 ,))
335329 b_g = tl .load (p_g , boundary_check = (0 ,))
336330 b_g_exp = exp (b_g )
337- else :
338- bg_last = None
339- last_idx = None
340- b_g = None
341- b_g_exp = None
342331
343332 p_dv = tl .make_block_ptr (dv , (T , V ), (stride_v , 1 ), (i_t * BT , i_v * BV ), (BT , BV ), (1 , 0 ))
344- p_do = tl .make_block_ptr (do , (T , V ), (stride_v , 1 ), (i_t * BT , i_v * BV ), (BT , BV ), (1 , 0 ))
345333 p_dv2 = tl .make_block_ptr (dv2 , (T , V ), (stride_v , 1 ), (i_t * BT , i_v * BV ), (BT , BV ), (1 , 0 ))
334+ p_do = tl .make_block_ptr (do , (T , V ), (stride_v , 1 ), (i_t * BT , i_v * BV ), (BT , BV ), (1 , 0 ))
346335
347336 b_do = tl .load (p_do , boundary_check = (0 , 1 ))
348- b_dv = tl .zeros ([BT , BV ], dtype = tl .float32 )
349337
350338 # Update dv
351339 p_k = tl .make_block_ptr (k , (T , K ), (stride_k , 1 ), (i_t * BT , 0 ), (BT , 64 ), (1 , 0 ))
352340 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
353- b_dv += tl .dot (b_k , b_dh1 .to (b_k .dtype ))
341+ if USE_GK :
342+ o_k1 = tl .arange (0 , 64 )
343+ b_gk_last1 = tl .load (gk + last_idx * H * K + o_k1 , mask = (o_k1 < K ), other = 0. )
344+ b_dv = tl .dot (b_k , b_dh1 .to (b_k .dtype ))
354345
355346 if K > 64 :
356347 p_k = tl .make_block_ptr (k , (T , K ), (stride_k , 1 ), (i_t * BT , 64 ), (BT , 64 ), (1 , 0 ))
357348 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
349+ if USE_GK :
350+ o_k2 = 64 + o_k1
351+ b_gk_last2 = tl .load (gk + last_idx * H * K + o_k2 , mask = (o_k2 < K ), other = 0. )
358352 b_dv += tl .dot (b_k , b_dh2 .to (b_k .dtype ))
359353
360354 if K > 128 :
361355 p_k = tl .make_block_ptr (k , (T , K ), (stride_k , 1 ), (i_t * BT , 128 ), (BT , 64 ), (1 , 0 ))
362356 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
357+ if USE_GK :
358+ o_k3 = 128 + o_k1
359+ b_gk_last3 = tl .load (gk + last_idx * H * K + o_k3 , mask = (o_k3 < K ), other = 0. )
363360 b_dv += tl .dot (b_k , b_dh3 .to (b_k .dtype ))
364361
365362 if K > 192 :
366363 p_k = tl .make_block_ptr (k , (T , K ), (stride_k , 1 ), (i_t * BT , 192 ), (BT , 64 ), (1 , 0 ))
367364 b_k = tl .load (p_k , boundary_check = (0 , 1 ))
365+ if USE_GK :
366+ o_k4 = 192 + o_k1
367+ b_gk_last4 = tl .load (gk + last_idx * H * K + o_k4 , mask = (o_k4 < K ), other = 0. )
368368 b_dv += tl .dot (b_k , b_dh4 .to (b_k .dtype ))
369369
370370 if USE_G :
@@ -381,8 +381,9 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
381381 if USE_G :
382382 b_dh1 *= bg_last_exp
383383 b_q = b_q * b_g_exp [None , :]
384- b_q = (b_q * scale ).to (b_q .dtype )
385- b_dh1 += tl .dot (b_q , b_do .to (b_q .dtype ))- tl .dot (b_w , b_dv .to (b_w .dtype ))
384+ if USE_GK :
385+ b_dh1 *= exp (b_gk_last1 [:, None ])
386+ b_dh1 += tl .dot (b_q .to (b_q .dtype ), b_do .to (b_q .dtype )) * scale - tl .dot (b_w , b_dv .to (b_w .dtype ))
386387 if K > 64 :
387388 p_q = tl .make_block_ptr (q , (K , T ), (1 , stride_k ), (64 , i_t * BT ), (64 , BT ), (0 , 1 ))
388389 p_w = tl .make_block_ptr (w , (K , T ), (1 , stride_k ), (64 , i_t * BT ), (64 , BT ), (0 , 1 ))
@@ -391,8 +392,9 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
391392 if USE_G :
392393 b_dh2 *= bg_last_exp
393394 b_q = b_q * b_g_exp [None , :]
394- b_q = (b_q * scale ).to (b_q .dtype )
395- b_dh2 += tl .dot (b_q , b_do .to (b_q .dtype ))- tl .dot (b_w , b_dv .to (b_w .dtype ))
395+ if USE_GK :
396+ b_dh2 *= exp (b_gk_last2 [:, None ])
397+ b_dh2 += tl .dot (b_q .to (b_q .dtype ), b_do .to (b_q .dtype )) * scale - tl .dot (b_w , b_dv .to (b_w .dtype ))
396398 if K > 128 :
397399 p_q = tl .make_block_ptr (q , (K , T ), (1 , stride_k ), (128 , i_t * BT ), (64 , BT ), (0 , 1 ))
398400 p_w = tl .make_block_ptr (w , (K , T ), (1 , stride_k ), (128 , i_t * BT ), (64 , BT ), (0 , 1 ))
@@ -401,8 +403,9 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
401403 if USE_G :
402404 b_dh3 *= bg_last_exp
403405 b_q = b_q * b_g_exp [None , :]
404- b_q = (b_q * scale ).to (b_q .dtype )
405- b_dh3 += tl .dot (b_q , b_do .to (b_q .dtype ))- tl .dot (b_w , b_dv .to (b_w .dtype ))
406+ if USE_GK :
407+ b_dh3 *= exp (b_gk_last3 [:, None ])
408+ b_dh3 += tl .dot (b_q .to (b_q .dtype ), b_do .to (b_q .dtype )) * scale - tl .dot (b_w , b_dv .to (b_w .dtype ))
406409 if K > 192 :
407410 p_q = tl .make_block_ptr (q , (K , T ), (1 , stride_k ), (192 , i_t * BT ), (64 , BT ), (0 , 1 ))
408411 p_w = tl .make_block_ptr (w , (K , T ), (1 , stride_k ), (192 , i_t * BT ), (64 , BT ), (0 , 1 ))
@@ -411,8 +414,9 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
411414 if USE_G :
412415 b_dh4 *= bg_last_exp
413416 b_q = b_q * b_g_exp [None , :]
414- b_q = (b_q * scale ).to (b_q .dtype )
415- b_dh4 += tl .dot (b_q , b_do .to (b_q .dtype ))- tl .dot (b_w , b_dv .to (b_w .dtype ))
417+ if USE_GK :
418+ b_dh4 *= exp (b_gk_last4 [:, None ])
419+ b_dh4 += tl .dot (b_q .to (b_q .dtype ), b_do .to (b_q .dtype )) * scale - tl .dot (b_w , b_dv .to (b_w .dtype ))
416420
417421 if USE_INITIAL_STATE :
418422 p_dh0 = tl .make_block_ptr (dh0 , (K , V ), (V , 1 ), (0 , i_v * BV ), (64 , BV ), (1 , 0 ))
@@ -481,12 +485,13 @@ def chunk_gated_delta_rule_bwd_dhu(
481485 q : torch .Tensor ,
482486 k : torch .Tensor ,
483487 w : torch .Tensor ,
484- g : torch .Tensor ,
485- h0 : torch .Tensor ,
486- dht : Optional [torch .Tensor ],
487488 do : torch .Tensor ,
488489 dv : torch .Tensor ,
489- scale : float ,
490+ g : Optional [torch .Tensor ] = None ,
491+ gk : Optional [torch .Tensor ] = None ,
492+ h0 : Optional [torch .Tensor ] = None ,
493+ dht : Optional [torch .Tensor ] = None ,
494+ scale : Optional [float ] = None ,
490495 cu_seqlens : Optional [torch .LongTensor ] = None ,
491496 chunk_size : int = 64 , # SY: remove this argument and force chunk size 64?
492497) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
@@ -511,6 +516,7 @@ def grid(meta): return (triton.cdiv(V, meta['BV']), N*H)
511516 k = k ,
512517 w = w ,
513518 g = g ,
519+ gk = gk ,
514520 dht = dht ,
515521 dh0 = dh0 ,
516522 do = do ,
0 commit comments