@@ -300,7 +300,7 @@ def parallel_deltaformer_bwd_kernel_u(
300300 block_shape = (BLOCK_C , D ),
301301 order = (1 , 0 ),
302302 )
303- q = tl .load (q_blk_ptr )
303+ q = tl .load (q_blk_ptr , boundary_check = ( 0 ,) )
304304
305305 for kv_i in range (0 , T , BLOCK_T ):
306306 k_blk_ptr = tl .make_block_ptr (
@@ -311,7 +311,7 @@ def parallel_deltaformer_bwd_kernel_u(
311311 block_shape = (D , BLOCK_T ),
312312 order = (0 , 1 ),
313313 )
314- k = tl .load (k_blk_ptr )
314+ k = tl .load (k_blk_ptr , boundary_check = ( 1 ,) )
315315 qk = tl .dot (q , k ) * fa_scale
316316
317317 lse_blk_ptr = tl .make_block_ptr (
@@ -322,7 +322,7 @@ def parallel_deltaformer_bwd_kernel_u(
322322 block_shape = (BLOCK_T ,),
323323 order = (0 ,),
324324 )
325- lse = tl .load (lse_blk_ptr )
325+ lse = tl .load (lse_blk_ptr , boundary_check = ( 0 ,) )
326326 beta_blk_ptr = tl .make_block_ptr (
327327 base = beta_ptr + pid_h ,
328328 shape = (T ,),
@@ -331,7 +331,7 @@ def parallel_deltaformer_bwd_kernel_u(
331331 block_shape = (BLOCK_T ,),
332332 order = (0 ,),
333333 )
334- beta = tl .load (beta_blk_ptr )
334+ beta = tl .load (beta_blk_ptr , boundary_check = ( 0 ,) )
335335
336336 p = tl .math .exp2 (qk - lse [None , :]) * beta [None , :]
337337
@@ -343,7 +343,7 @@ def parallel_deltaformer_bwd_kernel_u(
343343 block_shape = (BLOCK_T , D ),
344344 order = (1 , 0 ),
345345 )
346- v = tl .load (v_blk_ptr )
346+ v = tl .load (v_blk_ptr , boundary_check = ( 0 ,) )
347347 acc = tl .dot (p .to (v_ptr .dtype .element_ty ), v , acc )
348348
349349 o_blk_ptr = tl .make_block_ptr (
@@ -354,7 +354,7 @@ def parallel_deltaformer_bwd_kernel_u(
354354 block_shape = (BLOCK_C , D ),
355355 order = (1 , 0 ),
356356 )
357- tl .store (o_blk_ptr , acc .to (o_ptr .dtype .element_ty ))
357+ tl .store (o_blk_ptr , acc .to (o_ptr .dtype .element_ty ), boundary_check = ( 0 ,) )
358358
359359
360360@triton .autotune (configs = _config_deltaformer (), key = ['T' , 'D' ])
@@ -389,7 +389,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
389389 block_shape = (BLOCK_C , D ),
390390 order = (1 , 0 ),
391391 )
392- k_row = tl .load (k_row_blk_ptr )
392+ k_row = tl .load (k_row_blk_ptr , boundary_check = ( 0 ,) )
393393 lse_blk_ptr = tl .make_block_ptr (
394394 base = lse_ptr + pid_h ,
395395 shape = (T ,),
@@ -398,7 +398,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
398398 block_shape = (BLOCK_C ,),
399399 order = (0 ,),
400400 )
401- lse = tl .load (lse_blk_ptr )
401+ lse = tl .load (lse_blk_ptr , boundary_check = ( 0 ,) )
402402 grad_v_blk_ptr = tl .make_block_ptr (
403403 base = grad_v_ptr + pid_h * D ,
404404 shape = (T , D ),
@@ -407,7 +407,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
407407 block_shape = (BLOCK_C , D ),
408408 order = (1 , 0 ),
409409 )
410- grad_v_row = - tl .load (grad_v_blk_ptr )
410+ grad_v_row = - tl .load (grad_v_blk_ptr , boundary_check = ( 0 ,) )
411411
412412 for kv_i in range (0 , (pid_c + 1 ) * BLOCK_C , BLOCK_T ):
413413 k_blk_ptr = tl .make_block_ptr (
@@ -418,7 +418,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
418418 block_shape = (D , BLOCK_T ),
419419 order = (0 , 1 ),
420420 )
421- k = tl .load (k_blk_ptr )
421+ k = tl .load (k_blk_ptr , boundary_check = ( 1 ,) )
422422 qk = tl .dot (k_row , k ) * fa_scale
423423 p = tl .math .exp2 (qk - lse [:, None ])
424424
@@ -430,7 +430,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
430430 block_shape = (D , BLOCK_T ),
431431 order = (0 , 1 ),
432432 )
433- ut = tl .load (u_blk_ptr )
433+ ut = tl .load (u_blk_ptr , boundary_check = ( 1 ,) )
434434 dp = tl .dot (grad_v_row , ut )
435435 if kv_i + BLOCK_T >= pid_c * BLOCK_C :
436436 mask = (rowid_block [:, None ] <= colid_block [None , :] + kv_i )
@@ -445,7 +445,7 @@ def parallel_deltaformer_bwd_kernel_row_sum(
445445 block_shape = (BLOCK_C ,),
446446 order = (0 ,),
447447 )
448- tl .store (row_dot_block_ptr , acc )
448+ tl .store (row_dot_block_ptr , acc , boundary_check = ( 0 ,) )
449449
450450
451451@triton .autotune (configs = [triton .Config ({'BLOCK_C' : BC }, num_stages = ns , num_warps = nw )
@@ -484,7 +484,7 @@ def parallel_deltaformer_bwd_kernel_qk(
484484 block_shape = (BLOCK_C , D ),
485485 order = (1 , 0 ),
486486 )
487- k_row = tl .load (k_row_blk_ptr )
487+ k_row = tl .load (k_row_blk_ptr , boundary_check = ( 0 ,) )
488488 lse_blk_ptr = tl .make_block_ptr (
489489 base = lse_ptr + pid_h ,
490490 shape = (T ,),
@@ -493,7 +493,7 @@ def parallel_deltaformer_bwd_kernel_qk(
493493 block_shape = (BLOCK_C ,),
494494 order = (0 ,),
495495 )
496- lse = tl .load (lse_blk_ptr )
496+ lse = tl .load (lse_blk_ptr , boundary_check = ( 0 ,) )
497497 beta_blk_ptr = tl .make_block_ptr (
498498 base = beta_ptr + pid_h ,
499499 shape = (T ,),
@@ -502,7 +502,7 @@ def parallel_deltaformer_bwd_kernel_qk(
502502 block_shape = (BLOCK_C ,),
503503 order = (0 ,),
504504 )
505- beta = tl .load (beta_blk_ptr )
505+ beta = tl .load (beta_blk_ptr , boundary_check = ( 0 ,) )
506506 grad_v_blk_ptr = tl .make_block_ptr (
507507 base = grad_v_ptr + pid_h * D ,
508508 shape = (T , D ),
@@ -511,7 +511,7 @@ def parallel_deltaformer_bwd_kernel_qk(
511511 block_shape = (BLOCK_C , D ),
512512 order = (1 , 0 ),
513513 )
514- grad_v_row = - tl .load (grad_v_blk_ptr )
514+ grad_v_row = - tl .load (grad_v_blk_ptr , boundary_check = ( 0 ,) )
515515 row_dot_blk_ptr = tl .make_block_ptr (
516516 base = row_dot_ptr + pid_h ,
517517 shape = (T ,),
@@ -520,7 +520,7 @@ def parallel_deltaformer_bwd_kernel_qk(
520520 block_shape = (BLOCK_C ,),
521521 order = (0 ,),
522522 )
523- row_dot_row = tl .load (row_dot_blk_ptr ).to (k_ptr .dtype .element_ty )
523+ row_dot_row = tl .load (row_dot_blk_ptr , boundary_check = ( 0 ,) ).to (k_ptr .dtype .element_ty )
524524
525525 for kv_i in range (0 , pid_c * BLOCK_C , BLOCK_C ):
526526 k_blk_ptr = tl .make_block_ptr (
@@ -531,7 +531,7 @@ def parallel_deltaformer_bwd_kernel_qk(
531531 block_shape = (D , BLOCK_C ),
532532 order = (0 , 1 ),
533533 )
534- kt = tl .load (k_blk_ptr )
534+ kt = tl .load (k_blk_ptr , boundary_check = ( 1 ,) )
535535 qk = tl .dot (k_row , kt ) * fa_scale
536536 p = tl .math .exp2 (qk - lse [:, None ]) * beta [:, None ]
537537
@@ -557,7 +557,7 @@ def parallel_deltaformer_bwd_kernel_qk(
557557 block_shape = (BLOCK_C , D ),
558558 order = (1 , 0 ),
559559 )
560- k_row_true = tl .load (k_row_blk_ptr )
560+ k_row_true = tl .load (k_row_blk_ptr , boundary_check = ( 0 ,) )
561561 qk = tl .dot (k_row , tl .trans (k_row_true , 1 , 0 )) * fa_scale
562562 p = tl .math .exp2 (qk - lse [:, None ]) * beta [:, None ]
563563 u_blk_ptr = tl .make_block_ptr (
@@ -587,7 +587,7 @@ def parallel_deltaformer_bwd_kernel_qk(
587587 order = (1 , 0 ),
588588 )
589589 acc = acc * qk_scale
590- tl .store (grad_q_blk_ptr , acc .to (grad_q_ptr .dtype .element_ty ))
590+ tl .store (grad_q_blk_ptr , acc .to (grad_q_ptr .dtype .element_ty ), boundary_check = ( 0 ,) )
591591
592592 daat = tl .trans (da , 1 , 0 )
593593 acc = tl .dot (daat .to (k_row .dtype ), k_row )
@@ -602,7 +602,7 @@ def parallel_deltaformer_bwd_kernel_qk(
602602 block_shape = (D , BLOCK_C ),
603603 order = (0 , 1 ),
604604 )
605- kt = tl .load (k_blk_ptr )
605+ kt = tl .load (k_blk_ptr , boundary_check = ( 1 ,) )
606606 lse_blk_ptr = tl .make_block_ptr (
607607 base = lse_ptr + pid_h ,
608608 shape = (T ,),
@@ -611,7 +611,7 @@ def parallel_deltaformer_bwd_kernel_qk(
611611 block_shape = (BLOCK_C ,),
612612 order = (0 ,),
613613 )
614- lse = tl .load (lse_blk_ptr )
614+ lse = tl .load (lse_blk_ptr , boundary_check = ( 0 ,) )
615615 beta_blk_ptr = tl .make_block_ptr (
616616 base = beta_ptr + pid_h ,
617617 shape = (T ,),
@@ -620,7 +620,7 @@ def parallel_deltaformer_bwd_kernel_qk(
620620 block_shape = (BLOCK_C ,),
621621 order = (0 ,),
622622 )
623- beta = tl .load (beta_blk_ptr )
623+ beta = tl .load (beta_blk_ptr , boundary_check = ( 0 ,) )
624624 qk = tl .dot (k_row , kt ) * fa_scale
625625 p = tl .math .exp2 (qk - lse [None , :]) * beta [None , :]
626626
@@ -632,7 +632,7 @@ def parallel_deltaformer_bwd_kernel_qk(
632632 block_shape = (D , BLOCK_C ),
633633 order = (0 , 1 ),
634634 )
635- grad_vt = tl .load (grad_vt_blk_ptr )
635+ grad_vt = tl .load (grad_vt_blk_ptr , boundary_check = ( 1 ,) )
636636 row_dot_blk_ptr = tl .make_block_ptr (
637637 base = row_dot_ptr + pid_h ,
638638 shape = (T ,),
@@ -641,7 +641,7 @@ def parallel_deltaformer_bwd_kernel_qk(
641641 block_shape = (BLOCK_C ,),
642642 order = (0 ,),
643643 )
644- row_dot = tl .load (row_dot_blk_ptr ).to (k_ptr .dtype .element_ty )
644+ row_dot = tl .load (row_dot_blk_ptr , boundary_check = ( 0 ,) ).to (k_ptr .dtype .element_ty )
645645 dp = tl .dot (nu , grad_vt )
646646 da = p * (dp - row_dot [None , :])
647647 k = tl .trans (kt , 1 , 0 )
@@ -656,7 +656,7 @@ def parallel_deltaformer_bwd_kernel_qk(
656656 order = (1 , 0 ),
657657 )
658658 acc = acc * qk_scale
659- tl .store (grad_k_blk_ptr , acc .to (grad_k_ptr .dtype .element_ty ))
659+ tl .store (grad_k_blk_ptr , acc .to (grad_k_ptr .dtype .element_ty ), boundary_check = ( 0 ,) )
660660
661661
662662class ParallelDeltaformerFunction (torch .autograd .Function ):
@@ -872,6 +872,7 @@ def _forward_impl(
872872 w_t = w .transpose (0 , 1 ).contiguous ()
873873 u_chunk_view_t = u_chunk_view .transpose (0 , 1 ).contiguous ()
874874 invcum .forward_inplace (u_chunk_view_t , w_t )
875+ u_chunk_view .copy_ (u_chunk_view_t .transpose (0 , 1 ))
875876
876877 chunk_base += (T_max + C - 1 ) // C
877878
@@ -932,6 +933,7 @@ def _forward_impl(
932933 w_t = w .transpose (0 , 1 ).contiguous ()
933934 u_chunk_view_t = u_chunk_view .transpose (0 , 1 ).contiguous ()
934935 invcum .forward_inplace (u_chunk_view_t , w_t )
936+ u_chunk_view .copy_ (u_chunk_view_t .transpose (0 , 1 ))
935937
936938 chunk_base += (L + C - 1 ) // C
937939
@@ -953,7 +955,7 @@ def deltaformer_attn(
953955 B , T , H , D = k .shape
954956 C = min (C , T )
955957
956- u = ParallelDeltaformerFunction .apply (k , k , v , beta , C , cu_seqlens )
958+ u = ParallelDeltaformerFunction .apply (q , k , v , beta , C , cu_seqlens )
957959
958960 if attention_mask is not None :
959961 q_padded , (k_padded , u_padded ), indices_q , cu_seqlens_lens , max_seq_lens = unpad_input (q , (k , u ), attention_mask , T )
0 commit comments