@@ -80,9 +80,9 @@ def optimized_attention_kernel(
     if value.dtype != mx.float32:
         value = value.astype(mx.float32)
 
-    # Determine scale factor
+    # Determine scale factor - make sure it matches reference implementation
     if scale_strategy == "sqrt_dk":
-        scale = 1.0 / math.sqrt(d_model)
+        scale = 1.0 / math.sqrt(d_model)  # This should match reference
     elif scale_strategy == "learned":
         # Slightly different scale as a heuristic
         scale = 0.9 / math.sqrt(d_model)
@@ -92,24 +92,20 @@ def optimized_attention_kernel(
     # For now, implement basic attention to ensure correctness
     # More complex optimizations will be evolved
 
-    # Compute attention scores
-    scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1)))
-
-    # Apply scaling
-    scores = scores * scale
+    # Compute attention scores - match reference implementation exactly
+    if scale_strategy == "sqrt_dk":
+        # Match reference exactly: scores = matmul(...) / sqrt(d_k)
+        scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1))) / math.sqrt(d_model)
+    else:
+        # For other strategies, compute separately
+        scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1)))
+        scores = scores * scale
 
-    # Apply mask if provided
+    # Apply mask if provided - match reference implementation
     if mask is not None:
-        # Ensure mask has the right shape and dtype
-        if mask.shape != scores.shape:
-            # Handle different mask shapes - broadcast if needed
-            if len(mask.shape) == 2:  # [seq_len, seq_len]
-                mask = mx.broadcast_to(mask[None, :, :], scores.shape)
-            elif len(mask.shape) == 3 and mask.shape[0] == 1:  # [1, seq_len, seq_len]
-                mask = mx.broadcast_to(mask, scores.shape)
-
-        mask_value = -1e9 if compute_dtype == mx.float32 else -1e4
-        scores = scores + mask * mask_value
+        # Reference implementation does: scores = scores + mask
+        # So mask should already contain the large negative values
+        scores = scores + mask
 
     # Compute attention weights (always use high precision initially)
     attention_weights = mx.softmax(scores, axis=-1)
@@ -138,18 +134,13 @@ def _chunked_attention(
     """
     # For now, fall back to standard attention to ensure correctness
     # Evolution will implement proper chunking
-    scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1)))
-    scores = scores * scale
+    d_model = query.shape[-1]
+
+    # Match reference implementation exactly
+    scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1))) / math.sqrt(d_model)
 
     if mask is not None:
-        if mask.shape != scores.shape:
-            if len(mask.shape) == 2:  # [seq_len, seq_len]
-                mask = mx.broadcast_to(mask[None, :, :], scores.shape)
-            elif len(mask.shape) == 3 and mask.shape[0] == 1:  # [1, seq_len, seq_len]
-                mask = mx.broadcast_to(mask, scores.shape)
-
-        mask_value = -1e9 if scores.dtype == mx.float32 else -1e4
-        scores = scores + mask * mask_value
+        scores = scores + mask
 
     attention_weights = mx.softmax(scores, axis=-1)
     output = mx.matmul(attention_weights, value)
@@ -230,7 +221,7 @@ def benchmark_attention(
 
     # Create causal mask for decoder attention
     mask = mx.triu(mx.ones((seq_len, seq_len)), k=1) * -1e9
-    mask = mx.broadcast_to(mask[None, None, :, :], (batch_size, 1, seq_len, seq_len))
+    mask = mx.broadcast_to(mask[None, :, :], (batch_size, seq_len, seq_len))
 
     # Warmup
     _ = optimized_attention_kernel(query, key, value, mask, **config)
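
The diff repeatedly aligns with a "reference implementation" that is not shown here. For context, below is a minimal sketch of the standard scaled dot-product attention it appears to match (scores divided by sqrt(d_k), additive mask that already carries large negative values). The function name `reference_attention` and the sample shapes are illustrative assumptions, not part of the PR.

import math
import mlx.core as mx

def reference_attention(query, key, value, mask=None):
    # Sketch only: scores = QK^T / sqrt(d_k); the mask is additive and
    # already holds large negative values (e.g. -1e9) at masked positions.
    d_k = query.shape[-1]
    scores = mx.matmul(query, mx.transpose(key, axes=(0, 2, 1))) / math.sqrt(d_k)
    if mask is not None:
        scores = scores + mask
    weights = mx.softmax(scores, axis=-1)
    return mx.matmul(weights, value)

# Usage mirroring the benchmark setup above: an additive causal mask
# broadcast to [batch_size, seq_len, seq_len], matching the fixed broadcast.
batch_size, seq_len, d_model = 2, 8, 16
query = mx.random.normal((batch_size, seq_len, d_model))
key = mx.random.normal((batch_size, seq_len, d_model))
value = mx.random.normal((batch_size, seq_len, d_model))
mask = mx.triu(mx.ones((seq_len, seq_len)), k=1) * -1e9
mask = mx.broadcast_to(mask[None, :, :], (batch_size, seq_len, seq_len))
out = reference_attention(query, key, value, mask)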