Commit cecdee8

Update initial_program.py

1 parent: f9a8f0f

File tree

1 file changed (+5 −5 lines)


examples/mlx_metal_kernel_opt/initial_program.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -54,8 +54,8 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
         mask_tensor = mask.astype(mx.bool_)
         use_mask = True
     else:
-        # Fallback for unsupported mask types
-        return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
+        # Raise error for unsupported mask types - no fallback
+        raise ValueError(f"Unsupported mask type: {type(mask)}. Custom kernel requires None, 'causal', or mx.array mask.")
 
     # Expand mask to match batch and head dimensions if needed
     if mask_tensor.ndim == 2:
```
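For context, after this hunk the mask handling follows a validate-or-raise pattern instead of silently falling back to SDPA. A minimal self-contained sketch of that pattern (the `_prepare_mask` helper, its `L` parameter, and the causal branch are illustrative assumptions, not the file's exact code):

```python
import mlx.core as mx

def _prepare_mask(mask, L):
    """Hypothetical helper mirroring the hunk: accept None, 'causal',
    or an mx.array; anything else raises instead of falling back."""
    if mask is None:
        return None, False
    if isinstance(mask, str) and mask == "causal":
        # Lower-triangular boolean mask for causal attention
        causal = mx.tril(mx.ones((L, L), dtype=mx.bool_))
        return causal, True
    if isinstance(mask, mx.array):
        return mask.astype(mx.bool_), True
    # Raise error for unsupported mask types - no fallback
    raise ValueError(
        f"Unsupported mask type: {type(mask)}. "
        "Custom kernel requires None, 'causal', or mx.array mask."
    )
```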
```diff
@@ -231,9 +231,9 @@ def qwen3_custom_gqa_attention(queries, keys, values, scale=1.0, mask=None):
         return outputs[0]
 
     except Exception as e:
-        # Fallback to standard MLX implementation if custom kernel fails
-        print(f"⚠️ Custom GQA kernel failed: {e}, falling back to MLX SPDA")
-        return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
+        # No fallback - let the custom kernel failure propagate for proper scoring
+        print(f" Custom GQA kernel failed: {e}")
+        raise RuntimeError(f"Custom Metal kernel execution failed: {e}") from e
 
 
 class CustomGQAAttention(nn.Module):
```
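Because the `except` branch now re-raises instead of returning the standard SDPA result, a caller can distinguish a working custom kernel from a broken one. A hypothetical scoring loop built on that behavior (`score_kernel` and the placeholder score values are assumptions for illustration, not part of this repo):

```python
import mlx.core as mx

def score_kernel(attention_fn, queries, keys, values, scale):
    """Hypothetical evaluator: a kernel that raises is disqualified
    instead of being silently graded on a fallback's output."""
    try:
        out = attention_fn(queries, keys, values, scale=scale, mask="causal")
        mx.eval(out)  # force lazy evaluation so kernel errors surface here
        return 1.0    # placeholder: a real harness would measure accuracy/speed
    except (ValueError, RuntimeError) as e:
        print(f"Kernel disqualified: {e}")
        return 0.0
```

Catching only `ValueError` and `RuntimeError` keeps unrelated harness bugs visible while still turning the kernel's own failures into a zero score.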
