tools/llm/torchtrt_ext/register_sdpa.py (+4 -1: 4 additions, 1 deletion)
@@ -1,6 +1,7 @@
 import copy
 import logging
 import operator
+from re import I
 from typing import Callable, Sequence, Tuple

 import torch
@@ -89,7 +90,9 @@ def replace_variants_of_sdpa(
             logger.warning(
                 f"This current version of SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True configuration. This could cause issues with accuracy for models with different configurations."
+            # TODO: lan to figure out why is_causal is always False in google/gemma-3-1b-it, as in the config file it should be every 5 sliding window layer followed by a full attention layer
+            # also to figure out why the attn_mask passed in from transformers is not working
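For context, the warning added in this hunk describes the single SDPA configuration the converter currently handles. A minimal sketch of what that call shape looks like in plain PyTorch follows; the tensor shapes are made up for illustration and this is not code from the PR:

    import torch
    import torch.nn.functional as F

    # Arbitrary illustrative shapes: (batch, heads, seq_len, head_dim).
    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    # The configuration the converter warning says is supported:
    # no attention mask, no dropout, causal masking enabled.
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
    )

Calls that deviate from this configuration (e.g. an explicit attn_mask from transformers, or is_causal=False as the TODO notes for google/gemma-3-1b-it) may convert but produce inaccurate results, which is what the warning flags.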