@@ -236,7 +236,7 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
     return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
 
 
-class CustomMetalGQAAttention(nn.Module):
+class CustomGQAAttention(nn.Module):
     """
     Qwen3 attention module with custom Metal kernel optimization.
 
@@ -332,7 +332,7 @@ def apply_optimization_hook():
     original_attention = qwen3_module.Attention
 
     # Replace with Metal optimized implementation
-    qwen3_module.Attention = CustomMetalGQAAttention
+    qwen3_module.Attention = CustomGQAAttention
 
     print("✅ Applied Custom Metal GQA Attention hook")
     return original_attention
@@ -384,7 +384,7 @@ class MockArgs:
     print("=" * 70)
 
     # Initialize Metal optimized attention
-    metal_attn = CustomMetalGQAAttention(args)
+    metal_attn = CustomGQAAttention(args)
 
     for config_name, batch_size, seq_len, hidden_size in test_configs:
         print(f"\nTesting {config_name}: B={batch_size}, L={seq_len}")
@@ -443,7 +443,7 @@ class MockArgs:
     mask = "causal"
 
     # Test Metal optimized implementation
-    metal_attn = CustomMetalGQAAttention(args)
+    metal_attn = CustomGQAAttention(args)
     output = metal_attn(x, mask=mask)
 
     print(f"✅ Metal GQA output shape: {output.shape}")