
Commit 4f8f767

add eager attention pattern that does not cast attn weights to fp32 (#121)
Signed-off-by: Frida Hou <[email protected]>
1 parent b47abb4 commit 4f8f767
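
For context: many eager attention implementations (for example the Hugging Face Llama eager path) upcast the attention weights to fp32 for the softmax via F.softmax(..., dtype=torch.float32).to(query.dtype). The pattern added here targets graphs whose eager path skips that upcast and keeps the softmax in the compute dtype. A minimal sketch of the two variants; the upcasting variant is background for comparison, not taken from this diff:

import torch
import torch.nn.functional as F

w = torch.randn(2, 4, 8, 8, dtype=torch.bfloat16)  # raw attention scores

# Upcasting variant (common in HF eager attention): softmax in fp32,
# then cast back to the compute dtype.
probs_fp32 = F.softmax(w, dim=-1, dtype=torch.float32).to(w.dtype)

# Variant matched by the new _sfdp_pattern_8: softmax stays in bf16/fp16.
probs_low = F.softmax(w, dim=-1)

# The two can differ in the low-order bits, so a graph that skips the
# upcast needs its own pattern to be recognized and fused.
print((probs_fp32.float() - probs_low.float()).abs().max())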

File tree

1 file changed: +23 −0 lines

  • tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py

tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py

Lines changed: 23 additions & 0 deletions
@@ -297,6 +297,28 @@ def _sfdp_replacement_7(query, key, value, attention_mask, scaling, dropout):
     )
 
 
+# with causal_mask, no division, does not cast to fp32 for softmax
+def _sfdp_pattern_8(query, key, value, attention_mask, scaling, dropout):
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    attn_weights = attn_weights + attention_mask
+    attn_weights = F.softmax(attn_weights, dim=-1)
+    attn_weights = F.dropout(attn_weights, p=dropout, training=False)
+    attn_output = torch.matmul(attn_weights, value)
+    return attn_output
+
+
+def _sfdp_replacement_8(query, key, value, attention_mask, scaling, dropout):
+    return torch.ops.auto_deploy.torch_attention_sdpa.default(
+        query,
+        key,
+        value,
+        attn_mask=None,
+        dropout_p=dropout,
+        is_causal=True,
+        scale=scaling,
+    )
+
+
 def _get_sfdp_patterns() -> List[Dict[str, Any]]:
     bs, seq_len, n_heads, hidden_size = 8, 16, 8, 512
     head_dim = hidden_size // n_heads
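
torch.ops.auto_deploy.torch_attention_sdpa is a custom op whose definition is not part of this diff; assuming it mirrors F.scaled_dot_product_attention (whose keyword signature it follows), here is a minimal standalone check, not from the commit, that the matched eager computation and the causal SDPA replacement agree when the additive mask is causal:

import torch
import torch.nn.functional as F

def eager_pattern(q, k, v, mask, scaling, dropout):
    # Same computation as _sfdp_pattern_8 above.
    w = torch.matmul(q, k.transpose(2, 3)) * scaling  # stays in input dtype
    w = w + mask
    w = F.softmax(w, dim=-1)                          # no fp32 upcast
    w = F.dropout(w, p=dropout, training=False)       # no-op at inference
    return torch.matmul(w, v)

bs, n_heads, seq_len, head_dim = 2, 4, 16, 64
q, k, v = (torch.randn(bs, n_heads, seq_len, head_dim) for _ in range(3))
# Additive causal mask: 0 on and below the diagonal, -inf above it.
causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
scale = head_dim ** -0.5

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
out = eager_pattern(q, k, v, causal, scale, dropout=0.0)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)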
@@ -315,6 +337,7 @@ def causal_mask():
         (_sfdp_pattern_5, _sfdp_replacement_5, False, 0.874321, 0.89734),
         (_sfdp_pattern_6, _sfdp_replacement_6, True, 0.634743, 0.6849734),
         (_sfdp_pattern_7, _sfdp_replacement_7, True, 0.34743, 0.849734),
+        (_sfdp_pattern_8, _sfdp_replacement_8, True, 0.2234743, 0.95849734),
     ]
 
     patterns = []
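
Each registration tuple carries deliberately distinctive scalar constants (here 0.2234743 and 0.95849734 for scaling and dropout), plausibly so that the traced pattern keeps the two scalar placeholders distinguishable during matching, and the boolean appears to select the causal-mask variant of the dummy inputs. A minimal sketch, an assumed harness rather than part of the commit, exercising the new pattern with the shapes visible in _get_sfdp_patterns:

import torch

from tensorrt_llm._torch.auto_deploy.transformations.library.attention import (
    _sfdp_pattern_8,
)

bs, seq_len, n_heads, hidden_size = 8, 16, 8, 512
head_dim = hidden_size // n_heads  # 64

q = torch.randn(bs, n_heads, seq_len, head_dim)
k = torch.randn_like(q)
v = torch.randn_like(q)
# Additive causal mask: 0 on and below the diagonal, -inf above it.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# The dropout value is inert here: the pattern calls F.dropout with
# training=False, so it only serves as a traceable placeholder.
out = _sfdp_pattern_8(q, k, v, mask, scaling=0.2234743, dropout=0.95849734)
print(out.shape)  # torch.Size([8, 8, 16, 64])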
