1 parent 4e505b2 commit 1050261
benchmarks/backward_equivalence.py
@@ -191,7 +191,7 @@ def dynamic_mask_attention_python(
     value_states = repeat_kv(value_states, num_queries_per_kv)
     attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
     attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)
-
+
     attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))
     attn_weights = attn_weights * scaling + attn_bias  # Apply scaling and zoh
     attn_weights = F.softmax(attn_weights, dim=-1)  # Softmax normalization
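The hunk above is a whitespace-only change, but for context, here is a minimal self-contained sketch of the eager attention path it sits in: repeat K/V (and the additive bias) across the query heads that share them, compute scaled dot-product scores, add the bias, and normalize with softmax. The function name `attention_reference`, the tensor shapes, the simplified `repeat_kv`, and the omission of `attn_mask` handling are assumptions for illustration; they are not the benchmark's actual `dynamic_mask_attention_python` implementation.

```python
# Sketch only: names, shapes, and the simplified repeat_kv are assumptions,
# not the benchmark's actual code.
import torch
import torch.nn.functional as F


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head n_rep times so K/V (and bias) line up with the query heads (GQA)."""
    batch, num_kv_heads, seq_len, last_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, last_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, last_dim)


def attention_reference(query_states, key_states, value_states, attn_bias,
                        scaling, num_queries_per_kv):
    """Eager attention with an additive bias, mirroring the hunk above (mask handling omitted)."""
    key_states = repeat_kv(key_states, num_queries_per_kv)
    value_states = repeat_kv(value_states, num_queries_per_kv)
    attn_bias = repeat_kv(attn_bias, num_queries_per_kv)

    # Scaled dot-product scores plus the additive bias, then softmax over the key dimension.
    attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1))
    attn_weights = attn_weights * scaling + attn_bias
    attn_weights = F.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value_states)


if __name__ == "__main__":
    b, h_q, h_kv, q_len, k_len, d = 1, 8, 2, 4, 4, 16
    q = torch.randn(b, h_q, q_len, d)
    k = torch.randn(b, h_kv, k_len, d)
    v = torch.randn(b, h_kv, k_len, d)
    bias = torch.zeros(b, h_kv, q_len, k_len)
    out = attention_reference(q, k, v, bias, scaling=d ** -0.5,
                              num_queries_per_kv=h_q // h_kv)
    print(out.shape)  # torch.Size([1, 8, 4, 16])
```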