114 changes: 57 additions & 57 deletions benchmarks/backward_equivalence.py
@@ -146,7 +146,6 @@ def dynamic_mask_attention_python(
A: torch.Tensor,
scaling: float,
cache_position: torch.Tensor,
dout: torch.Tensor,
keep_window_size=2048,
is_causal=True,
):
@@ -161,7 +160,6 @@ def dynamic_mask_attention_python(
A: [num_kv_heads]
scaling: Attention scaling factor
cache_position: Cache position for causal masking
dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking

@@ -201,7 +199,7 @@ def dynamic_mask_attention_python(
attn_outputs = attn_outputs.transpose(1, 2).contiguous() # Transpose to [batch, query_len, num_heads, head_dim]

# Backward pass
attn_outputs.backward(dout)
attn_outputs.sum().backward()

return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad
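
Note on the change above: dropping the explicit dout argument and calling attn_outputs.sum().backward() seeds backward with an implicit all-ones upstream gradient, so every implementation is compared against the same deterministic gradient instead of a freshly sampled dout. A minimal, self-contained sketch of that equivalence (the tensors here are placeholders, not the benchmark's inputs):

import torch

x = torch.randn(4, 8, requires_grad=True)
out = x * 2.0  # stand-in for any differentiable attention output

# Old pattern: backward seeded with an explicit upstream gradient.
(grad_explicit,) = torch.autograd.grad(
    out, x, grad_outputs=torch.ones_like(out), retain_graph=True
)

# New pattern: reduce to a scalar; autograd then uses an implicit gradient of 1.
(grad_sum,) = torch.autograd.grad(out.sum(), x)

assert torch.allclose(grad_explicit, grad_sum)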

@@ -214,7 +212,6 @@ def dynamic_mask_attention_cuda(
A: torch.Tensor,
scaling: float,
cache_position: torch.Tensor,
dout: torch.Tensor,
keep_window_size=2048,
is_causal=True,
):
@@ -229,7 +226,6 @@ def dynamic_mask_attention_cuda(
A: [num_kv_heads]
scaling: Attention scaling factor
cache_position: Cache position for causal masking
dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking

@@ -263,7 +259,7 @@ def dynamic_mask_attention_cuda(
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]

# Call the flash_dmattn_func interface
attn_outputs, softmax_lse, S_dmask = flash_dmattn_func(
attn_outputs = flash_dmattn_func(
query=query_states, # q: [batch, query_len, num_heads, head_dim]
key=key_states, # k: [batch, key_len, num_kv_heads, head_dim]
value=value_states, # v: [batch, key_len, num_kv_heads, head_dim]
@@ -272,12 +268,12 @@ def dynamic_mask_attention_cuda(
is_causal=is_causal, # causal masking
scale=scaling, # scaling factor
softcap=0.0,
deterministic=True,
return_attn_probs=True
deterministic=False,
return_attn_probs=False
)

# Backward pass
attn_outputs.backward(dout)
attn_outputs.sum().backward()

return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad
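
With return_attn_probs=False the call above returns only the attention output, so the previous (attn_outputs, softmax_lse, S_dmask) unpacking is gone. A hedged helper sketch, not part of the diff, that tolerates both return conventions in case the auxiliary outputs are re-enabled later:

def attention_output_only(func, **kwargs):
    # When auxiliary outputs (e.g. softmax LSE) are requested, the kernel is
    # assumed to return a tuple whose first element is the attention output.
    result = func(**kwargs)
    return result[0] if isinstance(result, tuple) else result

# Usage sketch:
# attn_outputs = attention_output_only(flash_dmattn_func, query=query_states, ...)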

@@ -290,7 +286,6 @@ def dynamic_mask_attention_triton(
A: torch.Tensor,
scaling: float,
cache_position: torch.Tensor,
dout: torch.Tensor,
keep_window_size=2048,
is_causal=True,
):
@@ -305,7 +300,6 @@ def dynamic_mask_attention_triton(
A: [num_kv_heads]
scaling: Attention scaling factor
cache_position: Cache position for causal masking
dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking

@@ -361,7 +355,7 @@ def dynamic_mask_attention_triton(
)

# Backward pass
attn_outputs.backward(dout)
attn_outputs.sum().backward()

return attn_outputs, query_states_leaf.grad, key_states_leaf.grad, value_states_leaf.grad, attn_bias_leaf.grad

@@ -374,7 +368,6 @@ def dynamic_mask_attention_flex(
A: torch.Tensor,
scaling: float,
cache_position: torch.Tensor,
dout: torch.Tensor,
keep_window_size=2048,
is_causal=True,
):
@@ -389,7 +382,6 @@ def dynamic_mask_attention_flex(
A: [num_kv_heads]
scaling: Attention scaling factor
cache_position: Cache position for causal masking
dout: [batch_size, query_len, num_heads, head_dim] - gradient w.r.t. output
keep_window_size: Number of tokens to keep in attention window
is_causal: Whether to apply causal masking

@@ -436,7 +428,7 @@ def dynamic_mask_attention_flex(
)

# Backward pass
attn_outputs.backward(dout)
attn_outputs.sum().backward()

return attn_outputs, query_states.grad, key_states.grad, value_states.grad, attn_bias.grad

@@ -552,76 +544,90 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
test_configs = [
# Head dim 32
(1, 2, 1, 128, 128, 32, True),
(1, 2, 1, 128, 128, 32, False),
(1, 2, 1, 256, 256, 32, True),
(1, 2, 1, 128, 128, 32, True),
(1, 2, 1, 256, 256, 32, False),
(1, 2, 1, 512, 512, 32, True),
(1, 2, 1, 256, 256, 32, True),
(1, 2, 1, 512, 512, 32, False),
(1, 2, 1, 1024, 1024, 32, True),
(1, 2, 1, 512, 512, 32, True),
(1, 2, 1, 1024, 1024, 32, False),
(1, 2, 1, 2048, 2048, 32, True),
(1, 2, 1, 1024, 1024, 32, True),
(1, 2, 1, 2048, 2048, 32, False),
(1, 2, 1, 4096, 4096, 32, True), # some INF in dbias, Idk why
(1, 2, 1, 2048, 2048, 32, True),
(1, 2, 1, 4096, 4096, 32, False),
(1, 2, 1, 4096, 4096, 32, True),

# Head dim 64
(1, 2, 1, 128, 128, 64, True),
(1, 2, 1, 128, 128, 64, False),
(1, 2, 1, 256, 256, 64, True), # some INF in dbias, Idk why
(1, 2, 1, 128, 128, 64, True),
(1, 2, 1, 256, 256, 64, False),
(1, 2, 1, 512, 512, 64, True),
(1, 2, 1, 256, 256, 64, True),
(1, 2, 1, 512, 512, 64, False),
(1, 2, 1, 1024, 1024, 64, True), # some INF in dbias, Idk why
(1, 2, 1, 512, 512, 64, True),
(1, 2, 1, 1024, 1024, 64, False),
(1, 2, 1, 2048, 2048, 64, True),
(1, 2, 1, 1024, 1024, 64, True),
(1, 2, 1, 2048, 2048, 64, False),
(1, 2, 1, 4096, 4096, 64, True),
(1, 2, 1, 2048, 2048, 64, True),
(1, 2, 1, 4096, 4096, 64, False),
(1, 2, 1, 4096, 4096, 64, True),

# Head dim 96
(1, 2, 1, 128, 128, 96, True),
(1, 2, 1, 128, 128, 96, False),
(1, 2, 1, 256, 256, 96, True),
(1, 2, 1, 128, 128, 96, True),
(1, 2, 1, 256, 256, 96, False),
(1, 2, 1, 512, 512, 96, True),
(1, 2, 1, 256, 256, 96, True),
(1, 2, 1, 512, 512, 96, False),
(1, 2, 1, 1024, 1024, 96, True), # some INF in dbias, Idk why
(1, 2, 1, 512, 512, 96, True),
(1, 2, 1, 1024, 1024, 96, False),
(1, 2, 1, 2048, 2048, 96, True),
(1, 2, 1, 1024, 1024, 96, True),
(1, 2, 1, 2048, 2048, 96, False),
(1, 2, 1, 4096, 4096, 96, True),
(1, 2, 1, 2048, 2048, 96, True),
(1, 2, 1, 4096, 4096, 96, False),
(1, 2, 1, 4096, 4096, 96, True),

# Head dim 128
(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 128, 128, 128, False),
(1, 2, 1, 256, 256, 128, True),
(1, 2, 1, 128, 128, 128, True),
(1, 2, 1, 256, 256, 128, False),
(1, 2, 1, 512, 512, 128, True),
(1, 2, 1, 256, 256, 128, True),
(1, 2, 1, 512, 512, 128, False),
(1, 2, 1, 1024, 1024, 128, True),
(1, 2, 1, 512, 512, 128, True),
(1, 2, 1, 1024, 1024, 128, False),
(1, 2, 1, 2048, 2048, 128, True),
(1, 2, 1, 1024, 1024, 128, True),
(1, 2, 1, 2048, 2048, 128, False),
(1, 2, 1, 4096, 4096, 128, True),
(1, 2, 1, 2048, 2048, 128, True),
(1, 2, 1, 4096, 4096, 128, False),
(1, 2, 1, 4096, 4096, 128, True),

# # Head dim 192
# Not enough shared memory for head_dim=192 in bwd yet
# (1, 2, 1, 128, 128, 192, False),
# (1, 2, 1, 128, 128, 192, True),
# (1, 2, 1, 256, 256, 192, False),
# (1, 2, 1, 256, 256, 192, True),
# (1, 2, 1, 512, 512, 192, False),
# (1, 2, 1, 512, 512, 192, True),
# (1, 2, 1, 1024, 1024, 192, False),
# (1, 2, 1, 1024, 1024, 192, True),
# (1, 2, 1, 2048, 2048, 192, False),
# (1, 2, 1, 2048, 2048, 192, True),
# (1, 2, 1, 4096, 4096, 192, False),
# (1, 2, 1, 4096, 4096, 192, True),

# Head dim 256
# Because fwd uses splitkv branch, this branch does not support head_dim=256 for now
# For head_dim=256, besides the reason of splitkv branch, bwd itself does not support it, not enough shared memory
# (1, 2, 1, 128, 128, 256, True),
# Not enough shared memory for head_dim=256 in bwd yet
# (1, 2, 1, 128, 128, 256, False),
# (1, 2, 1, 256, 256, 256, True),
# (1, 2, 1, 128, 128, 256, True),
# (1, 2, 1, 256, 256, 256, False),
# (1, 2, 1, 512, 512, 256, True),
# (1, 2, 1, 256, 256, 256, True),
# (1, 2, 1, 512, 512, 256, False),
# (1, 2, 1, 1024, 1024, 256, True),
# (1, 2, 1, 512, 512, 256, True),
# (1, 2, 1, 1024, 1024, 256, False),
# (1, 2, 1, 2048, 2048, 256, True),
# (1, 2, 1, 1024, 1024, 256, True),
# (1, 2, 1, 2048, 2048, 256, False),
# (1, 2, 1, 4096, 4096, 256, True),
# (1, 2, 1, 2048, 2048, 256, True),
# (1, 2, 1, 4096, 4096, 256, False),
# (1, 2, 1, 4096, 4096, 256, True),
]
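
The tuples above follow the field order given in the comment at the top of the list. A hypothetical sketch of how such a list is consumed (the real loop sits outside this hunk):

for batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal in test_configs:
    scaling = head_dim ** -0.5  # same scaling the test sets further down
    # ... build inputs of shape [batch_size, query_len, num_heads, head_dim] and run both paths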

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -673,13 +679,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):

# Set scaling factor and keep window size
scaling = head_dim ** -0.5
keep_window_size = 64

# Create gradient for output
dout = torch.randn(
batch_size, query_len, num_heads, head_dim,
device=device, dtype=dtype
)
keep_window_size = 1024

# Clone inputs for Python implementation
query_python = query_states.clone().detach().requires_grad_(True)
@@ -692,7 +692,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
start_time = time.time()
attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python(
query_python, key_python, value_python, dt_proj_python, A_python,
scaling, cache_position, dout.clone(), keep_window_size, is_causal
scaling, cache_position, keep_window_size, is_causal
)
torch.cuda.synchronize()
py_time = time.time() - start_time
@@ -709,7 +709,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
start_time = time.time()
attn_outputs_cuda, dq_cuda, dk_cuda, dv_cuda, dbias_cuda = dynamic_mask_attention_cuda(
query_cuda, key_cuda, value_cuda, dt_proj_cuda, A_cuda,
scaling, cache_position, dout.clone(), keep_window_size, is_causal
scaling, cache_position, keep_window_size, is_causal
)
torch.cuda.synchronize()
cuda_time = time.time() - start_time
@@ -774,7 +774,7 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
if not is_close and max_dbias_diff > 1e-2:
print(" ⚠️ Difference too large, stopping subsequent tests.")
break
del query_states, key_states, value_states, dt_proj, A, cache_position, dout, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
del query_states, key_states, value_states, dt_proj, A, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_cuda, dk_cuda, dv_cuda, dbias_cuda
torch.cuda.empty_cache()
gc.collect()
torch.cuda.synchronize()
38 changes: 5 additions & 33 deletions benchmarks/backward_performance.py
@@ -177,13 +177,6 @@ def scaled_dot_product_attention_backward(
value_states = value_states.contiguous()

try:
# Create gradient for output
batch_size, num_heads, query_len, head_dim = query_states.shape
dout = torch.randn(
batch_size, query_len, num_heads, head_dim,
device=query_states.device, dtype=query_states.dtype
)

# Forward pass - SDPA expects q, k, v in [batch, num_heads, seq_len, head_dim] format
attn_outputs = F.scaled_dot_product_attention(
query_states, # [batch, num_heads, query_len, head_dim]
@@ -201,8 +194,8 @@ def scaled_dot_product_attention_backward(
start_time = time.time()

# Backward pass
attn_outputs.backward(dout)
attn_outputs.sum().backward()

torch.cuda.synchronize()
end_time = time.time()
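
The timing pattern above measures only the backward pass and then calls torch.cuda.synchronize() so the measurement waits for the backward kernels to finish. A small sketch of the same pattern as a reusable helper (assumed name; the extra synchronize before the clock starts is an addition here, not something shown in the diff):

import time
import torch

def time_backward_ms(attn_outputs):
    # Ensure the forward kernels have finished before starting the clock.
    torch.cuda.synchronize()
    start = time.time()
    attn_outputs.sum().backward()  # same scalar reduction the benchmarks use
    # Wait for the backward kernels to complete before stopping the clock.
    torch.cuda.synchronize()
    return (time.time() - start) * 1e3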

@@ -261,13 +254,6 @@ def dynamic_mask_attention_backward_cuda(
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_kv_heads, head_dim]

try:
# Create gradient for output
batch_size, query_len, num_heads, head_dim = query_states.shape
dout = torch.randn(
batch_size, query_len, num_heads, head_dim,
device=query_states.device, dtype=query_states.dtype
)

# Call the flash_dmattn_func interface
attn_outputs = flash_dmattn_func(
query=query_states, # q: [batch, query_len, num_heads, head_dim]
@@ -286,7 +272,7 @@ def dynamic_mask_attention_backward_cuda(
start_time = time.time()

# Backward pass
attn_outputs.backward(dout)
attn_outputs.sum().backward()

torch.cuda.synchronize()
end_time = time.time()
@@ -356,13 +342,6 @@ def dynamic_mask_attention_backward_triton(
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim]
attn_mask = attn_mask.contiguous() # [batch, num_heads, seqlen_q, seqlen_k]
attn_bias = attn_bias.contiguous() # [batch, num_heads, seqlen_q, seqlen_k]

# Create gradient for output
batch_size, query_len, num_heads, head_dim = query_states.shape
dout = torch.randn(
batch_size, query_len, num_heads, head_dim,
device=query_states.device, dtype=query_states.dtype
)

# Call the Triton implementation
attn_outputs = triton_dmattn_func(
@@ -379,7 +358,7 @@ def dynamic_mask_attention_backward_triton(
start_time = time.time()

# Backward pass
attn_outputs.backward(dout)
attn_outputs.sum().backward()

torch.cuda.synchronize()
end_time = time.time()
@@ -445,13 +424,6 @@ def dynamic_mask_attention_backward_flex(

# Flex attention expects: q, k, v in [batch, num_heads, seqlen, head_dim] format
# But attention_mask and attention_bias in [batch, num_heads, query_len, key_len] format

# Create gradient for output
batch_size, query_len, head_dim = query_states.shape[0], query_states.shape[2], query_states.shape[3]
dout = torch.randn(
batch_size, query_len, num_heads, head_dim,
device=query_states.device, dtype=query_states.dtype
)

# Call the Flex Attention implementation
attn_outputs = flex_dmattn_func(
@@ -468,7 +440,7 @@ def dynamic_mask_attention_backward_flex(
start_time = time.time()

# Backward pass
attn_outputs.backward(dout)
attn_outputs.sum().backward()

torch.cuda.synchronize()
end_time = time.time()