
Commit 03b24a3

Refactor backward pass in dynamic mask attention functions to remove dout parameter and use sum() for gradient computation
1 parent 931bcdc commit 03b24a3
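
For context on the change: each implementation previously took an explicit upstream gradient dout and called attn_outputs.backward(dout); after this commit they call attn_outputs.sum().backward() instead, which is equivalent to backpropagating an all-ones upstream gradient. A minimal standalone sketch of the two patterns (generic tensors, not the benchmark's attention outputs), assuming nothing beyond stock PyTorch autograd:

    import torch

    # Standalone sketch, not benchmark code: two ways to seed a backward pass.
    x = torch.randn(4, 8, requires_grad=True)
    out = (2.0 * x).sin()

    # Old pattern: pass an explicit upstream gradient with the same shape as the output.
    out.backward(torch.ones_like(out), retain_graph=True)
    grad_explicit = x.grad.clone()

    # New pattern: reduce to a scalar first; backward() then needs no gradient argument.
    x.grad = None
    out.sum().backward()
    grad_via_sum = x.grad.clone()

    # sum().backward() implicitly uses an all-ones upstream gradient, so the two agree.
    torch.testing.assert_close(grad_explicit, grad_via_sum)

Note that the removed benchmark code drew dout from torch.randn, so the compared gradients now correspond to a uniform rather than a random upstream gradient; for the equivalence check the only requirement is that every implementation sees the same seed gradient.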

File tree: 1 file changed (+57 / -57 lines)

benchmarks/backward_equivalence.py

Lines changed: 57 additions & 57 deletions
@@ -146,7 +146,6 @@ def dynamic_mask_attention_python(
     A: torch.Tensor,
     scaling: float,
     cache_position: torch.Tensor,
-    dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
 ):
@@ -161,7 +160,6 @@ def dynamic_mask_attention_python(
         A: [num_kv_heads]
         scaling: Attention scaling factor
         cache_position: Cache position for causal masking
-        dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
 
@@ -201,7 +199,7 @@ def dynamic_mask_attention_python(
     attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim]
 
     # Backward pass
-    attn_outputs.backward(dout)
+    attn_outputs.sum().backward()
 
     return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad
 
@@ -214,7 +212,6 @@ def dynamic_mask_attention_cuda(
     A: torch.Tensor,
     scaling: float,
     cache_position: torch.Tensor,
-    dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
 ):
@@ -229,7 +226,6 @@ def dynamic_mask_attention_cuda(
         A: [num_kv_heads]
         scaling: Attention scaling factor
         cache_position: Cache position for causal masking
-        dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
 
@@ -263,7 +259,7 @@ def dynamic_mask_attention_cuda(
     value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]
 
     # Call the flash_dmattn_func interface
-    attn_outputs, softmax_lse, S_dmask = flash_dmattn_func(
+    attn_outputs = flash_dmattn_func(
         query=query_states, # q: [batch, query_len, num_heads, head_dim]
         key=key_states, # k: [batch, key_len, num_kv_heads, head_dim]
         value=value_states, # v: [batch, key_len, num_kv_heads, head_dim]
@@ -272,12 +268,12 @@ def dynamic_mask_attention_cuda(
         is_causal=is_causal, # causal masking
         scale=scaling, # scaling factor
         softcap=0.0,
-        deterministic=True,
-        return_attn_probs=True
+        deterministic=False,
+        return_attn_probs=False
     )
 
     # Backward pass
-    attn_outputs.backward(dout)
+    attn_outputs.sum().backward()
 
     return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad
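
The call-site changes in the two hunks above go together: with return_attn_probs=True the result was unpacked as (attn_outputs, softmax_lse, S_dmask), while with return_attn_probs=False a single output tensor is bound. A hypothetical helper (not part of the benchmark; the tuple layout is only what the old unpacking implies) that tolerates either return shape:

    from typing import Any, Tuple, Union

    import torch

    def attention_output(result: Union[torch.Tensor, Tuple[Any, ...]]) -> torch.Tensor:
        """Return just the attention output, whether the call returned a bare tensor
        (return_attn_probs=False) or an (out, softmax_lse, S_dmask)-style tuple
        (return_attn_probs=True), as seen at the two call sites in this diff."""
        return result[0] if isinstance(result, tuple) else result

Only the output tensor is needed for the .sum().backward() gradient check, which is presumably why the auxiliary tensors are no longer requested and deterministic mode is no longer forced.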

@@ -290,7 +286,6 @@ def dynamic_mask_attention_triton(
     A: torch.Tensor,
     scaling: float,
     cache_position: torch.Tensor,
-    dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
 ):
@@ -305,7 +300,6 @@ def dynamic_mask_attention_triton(
         A: [num_kv_heads]
         scaling: Attention scaling factor
         cache_position: Cache position for causal masking
-        dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
 
@@ -361,7 +355,7 @@ def dynamic_mask_attention_triton(
     )
 
     # Backward pass
-    attn_outputs.backward(dout)
+    attn_outputs.sum().backward()
 
     return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad
 
@@ -374,7 +368,6 @@ def dynamic_mask_attention_flex(
     A: torch.Tensor,
     scaling: float,
     cache_position: torch.Tensor,
-    dout: torch.Tensor,
     keep_window_size=2048,
     is_causal=True,
 ):
@@ -389,7 +382,6 @@ def dynamic_mask_attention_flex(
         A: [num_kv_heads]
         scaling: Attention scaling factor
         cache_position: Cache position for causal masking
-        dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
         keep_window_size: Number of tokens to keep in attention window
         is_causal: Whether to apply causal masking
 
@@ -436,7 +428,7 @@ def dynamic_mask_attention_flex(
     )
 
     # Backward pass
-    attn_outputs.backward(dout)
+    attn_outputs.sum().backward()
 
     return attn_outputs, query_states.grad, key_states.grad, value_states.grad, attn_bias.grad
 
@@ -552,76 +544,90 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
     # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
     test_configs = [
         # Head dim 32
-        (1, 2, 1, 128, 128, 32, True),
         (1, 2, 1, 128, 128, 32, False),
-        (1, 2, 1, 256, 256, 32, True),
+        (1, 2, 1, 128, 128, 32, True),
         (1, 2, 1, 256, 256, 32, False),
-        (1, 2, 1, 512, 512, 32, True),
+        (1, 2, 1, 256, 256, 32, True),
         (1, 2, 1, 512, 512, 32, False),
-        (1, 2, 1, 1024, 1024, 32, True),
+        (1, 2, 1, 512, 512, 32, True),
         (1, 2, 1, 1024, 1024, 32, False),
-        (1, 2, 1, 2048, 2048, 32, True),
+        (1, 2, 1, 1024, 1024, 32, True),
         (1, 2, 1, 2048, 2048, 32, False),
-        (1, 2, 1, 4096, 4096, 32, True), # some INF in dbias, Idk why
+        (1, 2, 1, 2048, 2048, 32, True),
         (1, 2, 1, 4096, 4096, 32, False),
+        (1, 2, 1, 4096, 4096, 32, True),
 
         # Head dim 64
-        (1, 2, 1, 128, 128, 64, True),
         (1, 2, 1, 128, 128, 64, False),
-        (1, 2, 1, 256, 256, 64, True), # some INF in dbias, Idk why
+        (1, 2, 1, 128, 128, 64, True),
         (1, 2, 1, 256, 256, 64, False),
-        (1, 2, 1, 512, 512, 64, True),
+        (1, 2, 1, 256, 256, 64, True),
         (1, 2, 1, 512, 512, 64, False),
-        (1, 2, 1, 1024, 1024, 64, True), # some INF in dbias, Idk why
+        (1, 2, 1, 512, 512, 64, True),
         (1, 2, 1, 1024, 1024, 64, False),
-        (1, 2, 1, 2048, 2048, 64, True),
+        (1, 2, 1, 1024, 1024, 64, True),
         (1, 2, 1, 2048, 2048, 64, False),
-        (1, 2, 1, 4096, 4096, 64, True),
+        (1, 2, 1, 2048, 2048, 64, True),
         (1, 2, 1, 4096, 4096, 64, False),
+        (1, 2, 1, 4096, 4096, 64, True),
 
         # Head dim 96
-        (1, 2, 1, 128, 128, 96, True),
         (1, 2, 1, 128, 128, 96, False),
-        (1, 2, 1, 256, 256, 96, True),
+        (1, 2, 1, 128, 128, 96, True),
         (1, 2, 1, 256, 256, 96, False),
-        (1, 2, 1, 512, 512, 96, True),
+        (1, 2, 1, 256, 256, 96, True),
         (1, 2, 1, 512, 512, 96, False),
-        (1, 2, 1, 1024, 1024, 96, True), # some INF in dbias, Idk why
+        (1, 2, 1, 512, 512, 96, True),
         (1, 2, 1, 1024, 1024, 96, False),
-        (1, 2, 1, 2048, 2048, 96, True),
+        (1, 2, 1, 1024, 1024, 96, True),
         (1, 2, 1, 2048, 2048, 96, False),
-        (1, 2, 1, 4096, 4096, 96, True),
+        (1, 2, 1, 2048, 2048, 96, True),
         (1, 2, 1, 4096, 4096, 96, False),
+        (1, 2, 1, 4096, 4096, 96, True),
 
         # Head dim 128
-        (1, 2, 1, 128, 128, 128, True),
         (1, 2, 1, 128, 128, 128, False),
-        (1, 2, 1, 256, 256, 128, True),
+        (1, 2, 1, 128, 128, 128, True),
         (1, 2, 1, 256, 256, 128, False),
-        (1, 2, 1, 512, 512, 128, True),
+        (1, 2, 1, 256, 256, 128, True),
         (1, 2, 1, 512, 512, 128, False),
-        (1, 2, 1, 1024, 1024, 128, True),
+        (1, 2, 1, 512, 512, 128, True),
         (1, 2, 1, 1024, 1024, 128, False),
-        (1, 2, 1, 2048, 2048, 128, True),
+        (1, 2, 1, 1024, 1024, 128, True),
         (1, 2, 1, 2048, 2048, 128, False),
-        (1, 2, 1, 4096, 4096, 128, True),
+        (1, 2, 1, 2048, 2048, 128, True),
         (1, 2, 1, 4096, 4096, 128, False),
+        (1, 2, 1, 4096, 4096, 128, True),
+
+        # # Head dim 192
+        # Not enough shared memory for head_dim=192 in bwd yet
+        # (1, 2, 1, 128, 128, 192, False),
+        # (1, 2, 1, 128, 128, 192, True),
+        # (1, 2, 1, 256, 256, 192, False),
+        # (1, 2, 1, 256, 256, 192, True),
+        # (1, 2, 1, 512, 512, 192, False),
+        # (1, 2, 1, 512, 512, 192, True),
+        # (1, 2, 1, 1024, 1024, 192, False),
+        # (1, 2, 1, 1024, 1024, 192, True),
+        # (1, 2, 1, 2048, 2048, 192, False),
+        # (1, 2, 1, 2048, 2048, 192, True),
+        # (1, 2, 1, 4096, 4096, 192, False),
+        # (1, 2, 1, 4096, 4096, 192, True),
 
         # Head dim 256
-        # Because fwd uses splitkv branch, this branch does not support head_dim=256 for now
-        # For head_dim=256, besides the reason of splitkv branch, bwd itself does not support it, not enough shared memory
-        # (1, 2, 1, 128, 128, 256, True),
+        # Not enough shared memory for head_dim=256 in bwd yet
         # (1, 2, 1, 128, 128, 256, False),
-        # (1, 2, 1, 256, 256, 256, True),
+        # (1, 2, 1, 128, 128, 256, True),
         # (1, 2, 1, 256, 256, 256, False),
-        # (1, 2, 1, 512, 512, 256, True),
+        # (1, 2, 1, 256, 256, 256, True),
         # (1, 2, 1, 512, 512, 256, False),
-        # (1, 2, 1, 1024, 1024, 256, True),
+        # (1, 2, 1, 512, 512, 256, True),
         # (1, 2, 1, 1024, 1024, 256, False),
-        # (1, 2, 1, 2048, 2048, 256, True),
+        # (1, 2, 1, 1024, 1024, 256, True),
         # (1, 2, 1, 2048, 2048, 256, False),
-        # (1, 2, 1, 4096, 4096, 256, True),
+        # (1, 2, 1, 2048, 2048, 256, True),
         # (1, 2, 1, 4096, 4096, 256, False),
+        # (1, 2, 1, 4096, 4096, 256, True),
     ]
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
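
The head-dim 192 and 256 blocks stay commented out because, per the notes in the list, the backward kernel runs out of shared memory at those sizes. A hypothetical alternative (not in the benchmark) is to keep every tuple and filter at runtime; MAX_BWD_HEAD_DIM below is an assumed cap inferred from those notes, and the tuple layout follows the (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) comment above the list:

    # Hypothetical sketch, not benchmark code: skip configurations the backward kernel
    # cannot handle instead of commenting them out by hand.
    MAX_BWD_HEAD_DIM = 128  # assumed cap, based on the "not enough shared memory" notes

    test_configs = [
        (1, 2, 1, 128, 128, 32, True),
        (1, 2, 1, 128, 128, 128, True),
        (1, 2, 1, 128, 128, 192, True),  # would be skipped under the assumed cap
    ]

    runnable = [cfg for cfg in test_configs if cfg[5] <= MAX_BWD_HEAD_DIM]
    print(f"Running {len(runnable)} of {len(test_configs)} configurations")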
@@ -673,13 +679,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
 
         # Set scaling factor and keep window size
         scaling = head_dim ** -0.5
-        keep_window_size = 64
-
-        # Create gradient for output
-        dout = torch.randn(
-            batch_size, query_len, num_heads, head_dim,
-            device=device, dtype=dtype
-        )
+        keep_window_size = 1024
 
         # Clone inputs for Python implementation
         query_python = query_states.clone().detach().requires_grad_(True)
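
Each implementation receives its own clone().detach().requires_grad_(True) copy of the shared inputs, so every backward pass writes to a separate leaf tensor's .grad instead of accumulating into a shared one. A minimal sketch of why that matters, with plain tensors standing in for the attention inputs:

    import torch

    base = torch.randn(2, 16, 4, 32)  # illustrative shape only

    # Independent leaf tensors: each implementation's backward pass gets its own .grad.
    q_python = base.clone().detach().requires_grad_(True)
    q_cuda = base.clone().detach().requires_grad_(True)

    (2.0 * q_python).sum().backward()  # stand-in for the reference implementation
    (3.0 * q_cuda).sum().backward()    # stand-in for the CUDA kernel

    assert q_python.grad is not None and q_cuda.grad is not None
    assert not torch.allclose(q_python.grad, q_cuda.grad)  # gradients stayed separate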
@@ -692,7 +692,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         start_time = time.time()
         attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python(
             query_python, key_python, value_python, dt_proj_python, A_python,
-            scaling, cache_position, dout.clone(), keep_window_size, is_causal
+            scaling, cache_position, keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         py_time = time.time() - start_time
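
The timing wrapper in this and the next hunk brackets each call with torch.cuda.synchronize() because CUDA kernels launch asynchronously; without the sync, time.time() would be read before the kernel has finished. A standalone sketch of the pattern (the matmul is just a stand-in for the attention call and its backward pass):

    import time

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1024, 1024, device=device)

    start_time = time.time()
    y = x @ x                     # stand-in for the call being benchmarked
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the queued kernels to actually finish
    elapsed = time.time() - start_time
    print(f"elapsed: {elapsed * 1e3:.2f} ms")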
@@ -709,7 +709,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         start_time = time.time()
         attn_outputs_cuda, dq_cuda, dk_cuda, dv_cuda, dbias_cuda = dynamic_mask_attention_cuda(
             query_cuda, key_cuda, value_cuda, dt_proj_cuda, A_cuda,
-            scaling, cache_position, dout.clone(), keep_window_size, is_causal
+            scaling, cache_position, keep_window_size, is_causal
         )
         torch.cuda.synchronize()
         cuda_time = time.time() - start_time
@@ -774,7 +774,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
         if not is_close and max_dbias_diff > 1e-2:
             print(" ⚠️ Difference too large, stopping subsequent tests.")
             break
-        del query_states, key_states, value_states, dt_proj, A, cache_position, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
+        del query_states, key_states, value_states, dt_proj, A, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
         torch.cuda.empty_cache()
         gc.collect()
         torch.cuda.synchronize()
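
At the end of each configuration the benchmark drops its large tensors and flushes the CUDA caching allocator before moving on; dout simply disappears from the del list because it no longer exists. A standalone sketch of that cleanup pattern (generic tensors; the CPU fallback is only so the snippet runs anywhere):

    import gc

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(1024, 1024, device=device)
    b = torch.randn(1024, 1024, device=device)

    del a, b                   # drop the Python references so the memory can be reused
    torch.cuda.empty_cache()   # return cached blocks to the driver (no-op without CUDA)
    gc.collect()               # collect any reference cycles still holding tensors
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for queued work before starting the next case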
