@@ -146,7 +146,6 @@ def dynamic_mask_attention_python(
     A: torch.Tensor,
     scaling: float,
     cache_position: torch.Tensor,
-    dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
 ):
@@ -161,7 +160,6 @@ def dynamic_mask_attention_python(
         A: [num_kv_heads]
         scaling: Attention scaling factor
         cache_position: Cache position for causal masking
-        dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking

@@ -201,7 +199,7 @@ def dynamic_mask_attention_python(
     attn_outputs = attn_outputs.transpose(1, 2).contiguous()  # Transpose to [batch, query_len, num_heads, head_dim]

     # Backward pass
-    attn_outputs.backward(dout)
+    attn_outputs.sum().backward()

     return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad

@@ -214,7 +212,6 @@ def dynamic_mask_attention_cuda(
     A: torch.Tensor,
     scaling: float,
     cache_position: torch.Tensor,
-    dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
 ):
@@ -229,7 +226,6 @@ def dynamic_mask_attention_cuda(
         A: [num_kv_heads]
         scaling: Attention scaling factor
         cache_position: Cache position for causal masking
-        dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking

@@ -263,7 +259,7 @@ def dynamic_mask_attention_cuda(
     value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_kv_heads, head_dim]

     # Call the flash_dmattn_func interface
-    attn_outputs, softmax_lse, S_dmask = flash_dmattn_func(
+    attn_outputs = flash_dmattn_func(
         query=query_states,  # q: [batch, query_len, num_heads, head_dim]
         key=key_states,  # k: [batch, key_len, num_kv_heads, head_dim]
         value=value_states,  # v: [batch, key_len, num_kv_heads, head_dim]
@@ -272,12 +268,12 @@ def dynamic_mask_attention_cuda(
         is_causal=is_causal,  # causal masking
         scale=scaling,  # scaling factor
         softcap=0.0,
-        deterministic=True,
-        return_attn_probs=True
+        deterministic=False,
+        return_attn_probs=False
     )

     # Backward pass
-    attn_outputs.backward(dout)
+    attn_outputs.sum().backward()

     return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad

@@ -290,7 +286,6 @@ def dynamic_mask_attention_triton(
     A: torch.Tensor,
     scaling: float,
     cache_position: torch.Tensor,
-    dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
 ):
@@ -305,7 +300,6 @@ def dynamic_mask_attention_triton(
         A: [num_kv_heads]
         scaling: Attention scaling factor
         cache_position: Cache position for causal masking
-        dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking

@@ -361,7 +355,7 @@ def dynamic_mask_attention_triton(
     )

     # Backward pass
-    attn_outputs.backward(dout)
+    attn_outputs.sum().backward()

     return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad

@@ -374,7 +368,6 @@ def dynamic_mask_attention_flex(
     A: torch.Tensor,
     scaling: float,
     cache_position: torch.Tensor,
-    dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
 ):
@@ -389,7 +382,6 @@ def dynamic_mask_attention_flex(
         A: [num_kv_heads]
         scaling: Attention scaling factor
         cache_position: Cache position for causal masking
-        dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking

@@ -436,7 +428,7 @@ def dynamic_mask_attention_flex(
     )

     # Backward pass
-    attn_outputs.backward(dout)
+    attn_outputs.sum().backward()

     return attn_outputs, query_states.grad, key_states.grad, value_states.grad, attn_bias.grad

@@ -552,76 +544,90 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
     # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
     test_configs = [
         # Head dim 32
-        (1, 2, 1, 128, 128, 32, True),
         (1, 2, 1, 128, 128, 32, False),
-        (1, 2, 1, 256, 256, 32, True),
+        (1, 2, 1, 128, 128, 32, True),
         (1, 2, 1, 256, 256, 32, False),
-        (1, 2, 1, 512, 512, 32, True),
+        (1, 2, 1, 256, 256, 32, True),
         (1, 2, 1, 512, 512, 32, False),
-        (1, 2, 1, 1024, 1024, 32, True),
+        (1, 2, 1, 512, 512, 32, True),
         (1, 2, 1, 1024, 1024, 32, False),
-        (1, 2, 1, 2048, 2048, 32, True),
+        (1, 2, 1, 1024, 1024, 32, True),
         (1, 2, 1, 2048, 2048, 32, False),
-        (1, 2, 1, 4096, 4096, 32, True),  # some INF in dbias, Idk why
+        (1, 2, 1, 2048, 2048, 32, True),
         (1, 2, 1, 4096, 4096, 32, False),
+        (1, 2, 1, 4096, 4096, 32, True),

         # Head dim 64
-        (1, 2, 1, 128, 128, 64, True),
         (1, 2, 1, 128, 128, 64, False),
-        (1, 2, 1, 256, 256, 64, True),  # some INF in dbias, Idk why
+        (1, 2, 1, 128, 128, 64, True),
         (1, 2, 1, 256, 256, 64, False),
-        (1, 2, 1, 512, 512, 64, True),
+        (1, 2, 1, 256, 256, 64, True),
         (1, 2, 1, 512, 512, 64, False),
-        (1, 2, 1, 1024, 1024, 64, True),  # some INF in dbias, Idk why
+        (1, 2, 1, 512, 512, 64, True),
         (1, 2, 1, 1024, 1024, 64, False),
-        (1, 2, 1, 2048, 2048, 64, True),
+        (1, 2, 1, 1024, 1024, 64, True),
         (1, 2, 1, 2048, 2048, 64, False),
-        (1, 2, 1, 4096, 4096, 64, True),
+        (1, 2, 1, 2048, 2048, 64, True),
         (1, 2, 1, 4096, 4096, 64, False),
+        (1, 2, 1, 4096, 4096, 64, True),

         # Head dim 96
-        (1, 2, 1, 128, 128, 96, True),
         (1, 2, 1, 128, 128, 96, False),
-        (1, 2, 1, 256, 256, 96, True),
+        (1, 2, 1, 128, 128, 96, True),
         (1, 2, 1, 256, 256, 96, False),
-        (1, 2, 1, 512, 512, 96, True),
+        (1, 2, 1, 256, 256, 96, True),
         (1, 2, 1, 512, 512, 96, False),
-        (1, 2, 1, 1024, 1024, 96, True),  # some INF in dbias, Idk why
+        (1, 2, 1, 512, 512, 96, True),
         (1, 2, 1, 1024, 1024, 96, False),
-        (1, 2, 1, 2048, 2048, 96, True),
+        (1, 2, 1, 1024, 1024, 96, True),
         (1, 2, 1, 2048, 2048, 96, False),
-        (1, 2, 1, 4096, 4096, 96, True),
+        (1, 2, 1, 2048, 2048, 96, True),
         (1, 2, 1, 4096, 4096, 96, False),
+        (1, 2, 1, 4096, 4096, 96, True),

         # Head dim 128
-        (1, 2, 1, 128, 128, 128, True),
         (1, 2, 1, 128, 128, 128, False),
-        (1, 2, 1, 256, 256, 128, True),
+        (1, 2, 1, 128, 128, 128, True),
         (1, 2, 1, 256, 256, 128, False),
-        (1, 2, 1, 512, 512, 128, True),
+        (1, 2, 1, 256, 256, 128, True),
         (1, 2, 1, 512, 512, 128, False),
-        (1, 2, 1, 1024, 1024, 128, True),
+        (1, 2, 1, 512, 512, 128, True),
         (1, 2, 1, 1024, 1024, 128, False),
-        (1, 2, 1, 2048, 2048, 128, True),
+        (1, 2, 1, 1024, 1024, 128, True),
         (1, 2, 1, 2048, 2048, 128, False),
-        (1, 2, 1, 4096, 4096, 128, True),
+        (1, 2, 1, 2048, 2048, 128, True),
         (1, 2, 1, 4096, 4096, 128, False),
+        (1, 2, 1, 4096, 4096, 128, True),
+
+        # # Head dim 192
+        # Not enough shared memory for head_dim=192 in bwd yet
+        # (1, 2, 1, 128, 128, 192, False),
+        # (1, 2, 1, 128, 128, 192, True),
+        # (1, 2, 1, 256, 256, 192, False),
+        # (1, 2, 1, 256, 256, 192, True),
+        # (1, 2, 1, 512, 512, 192, False),
+        # (1, 2, 1, 512, 512, 192, True),
+        # (1, 2, 1, 1024, 1024, 192, False),
+        # (1, 2, 1, 1024, 1024, 192, True),
+        # (1, 2, 1, 2048, 2048, 192, False),
+        # (1, 2, 1, 2048, 2048, 192, True),
+        # (1, 2, 1, 4096, 4096, 192, False),
+        # (1, 2, 1, 4096, 4096, 192, True),

         # Head dim 256
-        # Because fwd uses splitkv branch, this branch does not support head_dim=256 for now
-        # For head_dim=256, besides the reason of splitkv branch, bwd itself does not support it, not enough shared memory
-        # (1, 2, 1, 128, 128, 256, True),
+        # Not enough shared memory for head_dim=256 in bwd yet
         # (1, 2, 1, 128, 128, 256, False),
-        # (1, 2, 1, 256, 256 , 256, True),
+        # (1, 2, 1, 128, 128 , 256, True),
         # (1, 2, 1, 256, 256, 256, False),
-        # (1, 2, 1, 512, 512 , 256, True),
+        # (1, 2, 1, 256, 256 , 256, True),
         # (1, 2, 1, 512, 512, 256, False),
-        # (1, 2, 1, 1024, 1024 , 256, True),
+        # (1, 2, 1, 512, 512 , 256, True),
         # (1, 2, 1, 1024, 1024, 256, False),
-        # (1, 2, 1, 2048, 2048 , 256, True),
+        # (1, 2, 1, 1024, 1024 , 256, True),
         # (1, 2, 1, 2048, 2048, 256, False),
-        # (1, 2, 1, 4096, 4096 , 256, True),
+        # (1, 2, 1, 2048, 2048 , 256, True),
         # (1, 2, 1, 4096, 4096, 256, False),
+        # (1, 2, 1, 4096, 4096, 256, True),
     ]

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -673,13 +679,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):

         # Set scaling factor and keep window size
         scaling = head_dim ** -0.5
-        keep_window_size = 64
-
-        # Create gradient for output
-        dout = torch.randn(
-            batch_size, query_len, num_heads, head_dim,
-            device=device, dtype=dtype
-        )
+        keep_window_size = 1024

         # Clone inputs for Python implementation
         query_python = query_states.clone().detach().requires_grad_(True)
@@ -692,7 +692,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         start_time = time.time()
         attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python(
             query_python, key_python, value_python, dt_proj_python, A_python,
-            scaling, cache_position, dout.clone(), keep_window_size, is_causal
+            scaling, cache_position, keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         py_time = time.time() - start_time
@@ -709,7 +709,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         start_time = time.time()
         attn_outputs_cuda, dq_cuda, dk_cuda, dv_cuda, dbias_cuda = dynamic_mask_attention_cuda(
             query_cuda, key_cuda, value_cuda, dt_proj_cuda, A_cuda,
-            scaling, cache_position, dout.clone(), keep_window_size, is_causal
+            scaling, cache_position, keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         cuda_time = time.time() - start_time
@@ -774,7 +774,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         if not is_close and max_dbias_diff > 1e-2:
             print(" ⚠️ Difference too large, stopping subsequent tests.")
             break
-        del query_states, key_states, value_states, dt_proj, A, cache_position, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
+        del query_states, key_states, value_states, dt_proj, A, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
         torch.cuda.empty_cache()
         gc.collect()
         torch.cuda.synchronize()
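
Note on the backward-pass change: the diff replaces an explicit upstream gradient (attn_outputs.backward(dout)) with a scalar reduction (attn_outputs.sum().backward()), which backpropagates an implicit all-ones dout instead of a random one. A minimal, self-contained sketch of that equivalence with plain PyTorch tensors follows; the toy shapes and the stand-in op here are illustrative only, not taken from the repository.

    import torch

    # Toy stand-in for an attention output: [batch, query_len, num_heads, head_dim]
    x = torch.randn(1, 4, 2, 8, requires_grad=True)
    out = x * 2.0  # any differentiable op in place of the attention kernel

    # Old convention: supply an explicit upstream gradient dout.
    dout = torch.ones_like(out)
    out.backward(dout, retain_graph=True)
    grad_explicit = x.grad.clone()
    x.grad = None

    # New convention: reduce to a scalar, which implies an all-ones dout.
    out.sum().backward()
    assert torch.allclose(grad_explicit, x.grad)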