Commit d978e80 (authored by Xiaowei Ren)
Fix attention mask type for Flash Attention + CP + THD (NVIDIA#1354)
* always have padding mask type for both flash and fused attentions
Signed-off-by: Xiaowei Ren <[email protected]>
* remove a redundant assert
Signed-off-by: Xiaowei Ren <[email protected]>
---------
Signed-off-by: Xiaowei Ren <[email protected]>

1 parent: 8c00424
File tree (2 files changed: +4 −12 lines)
- tests/pytorch/fused_attn
- transformer_engine/pytorch
Diff summary (hunk contents were not captured in this page):
- tests/pytorch/fused_attn: 1 line replaced around line 45.
- transformer_engine/pytorch: 8 lines removed around lines 4312-4319, 3 lines added around lines 7873-7875, 3 lines removed around lines 7907-7909.
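Since the diff bodies are missing, below is a minimal, hypothetical Python sketch of the behavior the commit message describes: with the THD (packed, variable-length) layout under context parallelism, the attention mask type is coerced to a padding variant for both the flash and fused attention backends. The function `resolve_attn_mask_type` and its arguments are illustrative only, not TransformerEngine's actual API; only the mask-type strings ("causal", "padding", "padding_causal", "no_mask") follow TransformerEngine's naming.

```python
# Hypothetical sketch of the mask-type coercion described by this commit.
# The real change lives in transformer_engine/pytorch and may differ in
# names, placement, and exact logic.

def resolve_attn_mask_type(attn_mask_type: str,
                           qkv_format: str,
                           context_parallel: bool) -> str:
    """Force a padding mask type for THD sequences under context parallelism.

    With the THD layout, every rank in a context-parallel group holds
    variable-length sequence fragments, so both the flash and fused
    attention backends need padding information regardless of the mask
    type the caller requested.
    """
    if qkv_format == "thd" and context_parallel:
        if "padding" not in attn_mask_type:
            # e.g. "causal" -> "padding_causal", "no_mask" -> "padding"
            attn_mask_type = (
                "padding_causal" if "causal" in attn_mask_type else "padding"
            )
    return attn_mask_type


# Example: a caller that asked for a plain causal mask gets a padding
# variant once THD + context parallelism is in play.
assert resolve_attn_mask_type("causal", "thd", True) == "padding_causal"
assert resolve_attn_mask_type("no_mask", "thd", True) == "padding"
assert resolve_attn_mask_type("causal", "bshd", True) == "causal"
```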