Skip to content

Commit f9de81c

Browse files
committed
Update initial_program.py
1 parent 0a9f073 commit f9de81c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/mlx_metal_kernel_opt/initial_program.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
236236
return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
237237

238238

239-
class CustomMetalGQAAttention(nn.Module):
239+
class CustomGQAAttention(nn.Module):
240240
"""
241241
Qwen3 attention module with custom Metal kernel optimization.
242242
@@ -332,7 +332,7 @@ def apply_optimization_hook():
332332
original_attention = qwen3_module.Attention
333333

334334
# Replace with Metal optimized implementation
335-
qwen3_module.Attention = CustomMetalGQAAttention
335+
qwen3_module.Attention = CustomGQAAttention
336336

337337
print("✅ Applied Custom Metal GQA Attention hook")
338338
return original_attention
@@ -384,7 +384,7 @@ class MockArgs:
384384
print("=" * 70)
385385

386386
# Initialize Metal optimized attention
387-
metal_attn = CustomMetalGQAAttention(args)
387+
metal_attn = CustomGQAAttention(args)
388388

389389
for config_name, batch_size, seq_len, hidden_size in test_configs:
390390
print(f"\nTesting {config_name}: B={batch_size}, L={seq_len}")
@@ -443,7 +443,7 @@ class MockArgs:
443443
mask = "causal"
444444

445445
# Test Metal optimized implementation
446-
metal_attn = CustomMetalGQAAttention(args)
446+
metal_attn = CustomGQAAttention(args)
447447
output = metal_attn(x, mask=mask)
448448

449449
print(f"✅ Metal GQA output shape: {output.shape}")

0 commit comments

Comments (0)