Commit 1050261

Fix formatting in dynamic_mask_attention_python function
1 parent 4e505b2 commit 1050261

File tree

1 file changed: +1 −1 lines changed


benchmarks/backward_equivalence.py

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ def dynamic_mask_attention_python(
     value_states = repeat_kv(value_states, num_queries_per_kv)
     attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
     attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
-
+
     attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))
     attn_weights = attn_weights * scaling + attn_bias  # Apply scaling and zoh
     attn_weights = F.softmax(attn_weights, dim=-1)  # Softmax normalization
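The changed line itself is only a formatting fix; the surrounding context computes scaled attention weights with an additive bias after repeating key/value heads. Below is a minimal, runnable sketch of that computation. The tensor shapes, the `scaling` value, and the GQA-style `repeat_kv` helper are illustrative assumptions, not the repository's exact implementation.

```python
import torch
import torch.nn.functional as F


def repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat KV heads n_rep times along the head dimension (assumed
    GQA-style behavior; the repo's helper may differ in detail)."""
    batch, num_kv_heads, seq_len, last_dim = hidden.shape
    if n_rep == 1:
        return hidden
    hidden = hidden[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, last_dim
    )
    return hidden.reshape(batch, num_kv_heads * n_rep, seq_len, last_dim)


# Illustrative shapes (not from the repository).
batch, num_heads, num_kv_heads, seq_len, head_dim = 1, 8, 2, 4, 16
num_queries_per_kv = num_heads // num_kv_heads
scaling = head_dim ** -0.5  # typical 1/sqrt(head_dim) scaling, assumed

query_states = torch.randn(batch, num_heads, seq_len, head_dim)
key_states = repeat_kv(
    torch.randn(batch, num_kv_heads, seq_len, head_dim), num_queries_per_kv
)
attn_bias = repeat_kv(
    torch.randn(batch, num_kv_heads, seq_len, seq_len), num_queries_per_kv
)

# The steps shown in the diff context: matmul, scale + bias, softmax.
attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))
attn_weights = attn_weights * scaling + attn_bias
attn_weights = F.softmax(attn_weights, dim=-1)

print(attn_weights.shape)  # each softmax row sums to 1
```

Because the bias is added before the softmax, it shifts the logits rather than masking them outright, which keeps every row a valid probability distribution.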

0 commit comments