benchmarks/backward_equivalence.py: 247 changes (239 additions & 8 deletions)
@@ -283,11 +283,9 @@ def dynamic_mask_attention_triton(
    attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)

    # Ensure correct data types and memory layout for Triton function
-    query_states = query_states.transpose(1, 2).contiguous()  # [batch, query_len, num_heads, head_dim]
-    key_states = key_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    value_states = value_states.transpose(1, 2).contiguous()  # [batch, key_len, num_heads, head_dim]
-    attn_mask = attn_mask.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
-    attn_bias = attn_bias.contiguous()  # [batch, num_heads, seqlen_q, seqlen_k]
+    query_states = query_states.transpose(1, 2)  # [batch, query_len, num_heads, head_dim]
+    key_states = key_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]
+    value_states = value_states.transpose(1, 2)  # [batch, key_len, num_heads, head_dim]

    # Call the Triton implementation
    attn_outputs = triton_dmattn_func(
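The hunk above drops the explicit .contiguous() calls, which only works if the Triton entry point tolerates strided inputs. A minimal sketch (illustration only, not part of the diff) of the layout this relies on:

    import torch

    x = torch.randn(1, 2, 128, 32)   # [batch, num_heads, query_len, head_dim]
    y = x.transpose(1, 2)            # view with shape [batch, query_len, num_heads, head_dim]
    print(y.is_contiguous())         # False: transpose returns a strided view, no copy is made
    print(y.contiguous().data_ptr() == y.data_ptr())  # False: .contiguous() would allocate a new buffer

The assumption is that triton_dmattn_func reads these tensors through their strides (or handles contiguity internally), so the host-side copies are unnecessary.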
@@ -729,6 +727,239 @@ def test_cuda_backward_equivalence(accuracy_threshold=0.95):
    return all_passed


def test_triton_backward_equivalence(accuracy_threshold=0.95):
    """Test backward pass equivalence between Python prototype and Triton implementation."""
    print("\n" + "🚀" + "=" * 76 + "🚀")
    print("🔬 Testing Backward Pass Equivalence: Python Prototype vs Triton Implementation")
    print("🚀" + "=" * 76 + "🚀")

    # Check if Triton implementation is available
    if triton_dmattn_func is None:
        print("❌ Triton implementation not available, skipping test.")
        return False

    # Set random seed for reproducibility
    torch.manual_seed(0)

    # Test different parameter configurations.
    # If you encounter NaN issues when running multiple configurations, try running a single configuration.
    # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
    test_configs = [
        # Head dim 32
        (1, 2, 1, 128, 128, 32, False),
        (1, 2, 1, 128, 128, 32, True),
        (1, 2, 1, 256, 256, 32, False),
        (1, 2, 1, 256, 256, 32, True),
        (1, 2, 1, 512, 512, 32, False),
        (1, 2, 1, 512, 512, 32, True),
        (1, 2, 1, 1024, 1024, 32, False),
        (1, 2, 1, 1024, 1024, 32, True),
        (1, 2, 1, 2048, 2048, 32, False),
        (1, 2, 1, 2048, 2048, 32, True),
        (1, 2, 1, 4096, 4096, 32, False),
        (1, 2, 1, 4096, 4096, 32, True),

        # Head dim 64
        (1, 2, 1, 128, 128, 64, False),
        (1, 2, 1, 128, 128, 64, True),
        (1, 2, 1, 256, 256, 64, False),
        (1, 2, 1, 256, 256, 64, True),
        (1, 2, 1, 512, 512, 64, False),
        (1, 2, 1, 512, 512, 64, True),
        (1, 2, 1, 1024, 1024, 64, False),
        (1, 2, 1, 1024, 1024, 64, True),
        (1, 2, 1, 2048, 2048, 64, False),
        (1, 2, 1, 2048, 2048, 64, True),
        (1, 2, 1, 4096, 4096, 64, False),
        (1, 2, 1, 4096, 4096, 64, True),

        # Head dim 96
        (1, 2, 1, 128, 128, 96, False),
        (1, 2, 1, 128, 128, 96, True),
        (1, 2, 1, 256, 256, 96, False),
        (1, 2, 1, 256, 256, 96, True),
        (1, 2, 1, 512, 512, 96, False),
        (1, 2, 1, 512, 512, 96, True),
        (1, 2, 1, 1024, 1024, 96, False),
        (1, 2, 1, 1024, 1024, 96, True),
        (1, 2, 1, 2048, 2048, 96, False),
        (1, 2, 1, 2048, 2048, 96, True),
        (1, 2, 1, 4096, 4096, 96, False),
        (1, 2, 1, 4096, 4096, 96, True),

        # Head dim 128
        (1, 2, 1, 128, 128, 128, False),
        (1, 2, 1, 128, 128, 128, True),
        (1, 2, 1, 256, 256, 128, False),
        (1, 2, 1, 256, 256, 128, True),
        (1, 2, 1, 512, 512, 128, False),
        (1, 2, 1, 512, 512, 128, True),
        (1, 2, 1, 1024, 1024, 128, False),
        (1, 2, 1, 1024, 1024, 128, True),
        (1, 2, 1, 2048, 2048, 128, False),
        (1, 2, 1, 2048, 2048, 128, True),
        (1, 2, 1, 4096, 4096, 128, False),
        (1, 2, 1, 4096, 4096, 128, True),

        # Triton currently supports up to head dim 128
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16
    device_icon = "🔥" if device.type == "cuda" else "💻"
    print(f"{device_icon} Using device: {device}")

    all_passed = True

    for i, config in enumerate(test_configs):
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.synchronize()

        batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal = config

        # Progress indicator
        progress_filled = "█" * (i + 1)
        progress_empty = "░" * (len(test_configs) - i - 1)
        progress_bar = f"[{progress_filled}{progress_empty}]"

        print(f"\n🧪 Test configuration {i+1}/{len(test_configs)} {progress_bar}")
        print(f" 📊 batch_size={batch_size}, num_heads={num_heads}, num_kv_heads={num_kv_heads}")
        print(f" 📏 query_len={query_len}, key_len={key_len}, head_dim={head_dim}")
        print(f" 🔒 is_causal={is_causal}")
        print(f" 🎯 Accuracy threshold: {accuracy_threshold*100:.1f}%")

        # Create random input data
        query_states = torch.randn(
            batch_size, num_heads, query_len, head_dim,
            device=device, dtype=dtype, requires_grad=True
        )
        key_states = torch.randn(
            batch_size, num_kv_heads, key_len, head_dim,
            device=device, dtype=dtype, requires_grad=True
        )
        value_states = torch.randn(
            batch_size, num_kv_heads, key_len, head_dim,
            device=device, dtype=dtype, requires_grad=True
        )
        attn_bias = torch.randn(
            batch_size, num_kv_heads, query_len, key_len,
            device=device, dtype=torch.bfloat16
        )
        cache_position = torch.arange(key_len - query_len, key_len, device=device)
        causal_mask = torch.arange(key_len, device=device) <= cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

        # Set scaling factor and keep window size
        scaling = head_dim ** -0.5
        window_size = 10240

        # Clone inputs for Python implementation
        query_python = query_states.clone().detach().requires_grad_(True)
        key_python = key_states.clone().detach().requires_grad_(True)
        value_python = value_states.clone().detach().requires_grad_(True)
        attn_bias_python = attn_bias.clone().detach().requires_grad_(True)
        causal_mask_python = causal_mask.clone().detach()

        # Run Python implementation
        start_time = time.time()
        attn_outputs_python, dq_python, dk_python, dv_python, dbias_python = dynamic_mask_attention_python(
            query_python, key_python, value_python,
            attn_bias_python, causal_mask_python,
            scaling, window_size, is_causal
        )
        torch.cuda.synchronize()
        py_time = time.time() - start_time

        # Clone inputs for Triton implementation
        query_triton = query_states.clone().detach().requires_grad_(True)
        key_triton = key_states.clone().detach().requires_grad_(True)
        value_triton = value_states.clone().detach().requires_grad_(True)
        attn_bias_triton = attn_bias.clone().detach().requires_grad_(True)
        causal_mask_triton = causal_mask.clone().detach()

        # Run Triton implementation
        start_time = time.time()
        attn_outputs_triton, dq_triton, dk_triton, dv_triton, dbias_triton = dynamic_mask_attention_triton(
            query_triton, key_triton, value_triton,
            attn_bias_triton, causal_mask_triton,
            scaling, window_size, is_causal
        )
        torch.cuda.synchronize()
        triton_time = time.time() - start_time

        # Analyze outputs
        print(f"\n🔍 Analyzing differences between Python and Triton outputs:")
        is_attn_output_close, max_attn_output_diff, mean_attn_output_diff = analyze_differences(
            attn_outputs_python, attn_outputs_triton, accuracy_threshold
        )

        # Analyze dQ gradients
        print(f"\n🔍 Analyzing dQ gradients:")
        is_dq_close, max_dq_diff, mean_dq_diff = analyze_differences(
            dq_python, dq_triton, accuracy_threshold
        )

        # Analyze dK gradients
        print(f"\n🔍 Analyzing dK gradients:")
        is_dk_close, max_dk_diff, mean_dk_diff = analyze_differences(
            dk_python, dk_triton, accuracy_threshold
        )

        # Analyze dV gradients
        print(f"\n🔍 Analyzing dV gradients:")
        is_dv_close, max_dv_diff, mean_dv_diff = analyze_differences(
            dv_python, dv_triton, accuracy_threshold
        )

        # Analyze dBias gradients
        print(f"\n🔍 Analyzing dBias gradients:")
        is_dbias_close, max_dbias_diff, mean_dbias_diff = analyze_differences(
            dbias_python, dbias_triton, accuracy_threshold
        )

        # Report performance difference
        speedup = py_time / triton_time if triton_time > 0 else float('inf')
        print(f"\n⚡ Performance comparison:")
        print(f" 🐍 Python implementation: {py_time*1000:.2f} ms")
        print(f" 🚀 Triton implementation: {triton_time*1000:.2f} ms")
        print(f" 📈 Speedup: {speedup:.2f}x")

        # Check if all gradients pass
        is_close = (is_attn_output_close and is_dq_close and is_dk_close and is_dv_close and is_dbias_close)
        test_result = "Passed" if is_close else "Failed"
        result_icon = "✅" if is_close else "❌"
        all_passed = all_passed and is_close
        print(f"\n{result_icon} Test result: {test_result}")

        # If the test fails with a large difference, stop subsequent tests early
        if not is_close and max(
            max_attn_output_diff, max_dq_diff, max_dk_diff, max_dv_diff, max_dbias_diff
        ) > 1e-2:
            print(" ⚠️ Difference too large, stopping subsequent tests.")
            break
        del query_states, key_states, value_states, attn_bias, causal_mask, cache_position, dq_python, dk_python, dv_python, dbias_python, dq_triton, dk_triton, dv_triton, dbias_triton
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.synchronize()

    print("\n" + "🏁" + "=" * 76 + "🏁")
    summary_icon = "🎉" if all_passed else "😞"
    print(f"{summary_icon} Backward Equivalence Test Summary: {'All Passed' if all_passed else 'Some Tests Failed'}")
    print("🏁" + "=" * 76 + "🏁")

    return all_passed

def main():
    """
    Test backward pass equivalence between Python prototype and various implementations
@@ -782,9 +1013,9 @@ def main():
print("\n" + "📍" + " Starting Python vs CUDA Backward Tests " + "📍")
test_results['cuda'] = test_cuda_backward_equivalence(args.accuracy_threshold)

# if args.test_type in ['all', 'triton']:
# print("\n" + "🔥" + " Starting Python vs Triton Backward Tests " + "🔥")
# test_results['triton'] = test_triton_backward_equivalence(args.accuracy_threshold)
if args.test_type in ['all', 'triton']:
print("\n" + "🔥" + " Starting Python vs Triton Backward Tests " + "🔥")
test_results['triton'] = test_triton_backward_equivalence(args.accuracy_threshold)

# if args.test_type in ['all', 'flex']:
# print("\n" + "🌟" + " Starting Python vs Flex Attention Backward Tests " + "🌟")
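With the Triton branch re-enabled in main(), the new check can also be driven directly. A minimal usage sketch, assuming benchmarks/backward_equivalence.py is importable from the working directory; the module and function names are taken from the diff above, and the script's own CLI flags (defined in its argparse setup) are not shown here:

    # Sketch only: run just the Python-vs-Triton backward equivalence check.
    from backward_equivalence import test_triton_backward_equivalence

    if __name__ == "__main__":
        passed = test_triton_backward_equivalence(accuracy_threshold=0.95)
        raise SystemExit(0 if passed else 1)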