Commit 0e1d9fa
authored
[JAX] Bug fix for distributed normalization (NVIDIA#1366)
* fix ctx.aval_out indexing for workspace
* add cudnn init to prepare phase of norm custom calls
* add thread_local for norm registry instance
---------
Signed-off-by: Phuong Nguyen <[email protected]>1 parent e4c99b0 commit 0e1d9fa
File tree
3 files changed
+25
-14
lines changed- transformer_engine
- common/normalization
- jax
- cpp_extensions
- csrc/extensions
3 files changed
+25
-14
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
287 | 287 | | |
288 | 288 | | |
289 | 289 | | |
290 | | - | |
291 | 290 | | |
292 | | - | |
| 291 | + | |
293 | 292 | | |
294 | 293 | | |
295 | 294 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
147 | 147 | | |
148 | 148 | | |
149 | 149 | | |
150 | | - | |
| 150 | + | |
151 | 151 | | |
152 | 152 | | |
153 | 153 | | |
| |||
441 | 441 | | |
442 | 442 | | |
443 | 443 | | |
444 | | - | |
| 444 | + | |
445 | 445 | | |
446 | 446 | | |
447 | 447 | | |
| |||
650 | 650 | | |
651 | 651 | | |
652 | 652 | | |
653 | | - | |
| 653 | + | |
654 | 654 | | |
655 | 655 | | |
656 | 656 | | |
| |||
841 | 841 | | |
842 | 842 | | |
843 | 843 | | |
844 | | - | |
| 844 | + | |
845 | 845 | | |
846 | 846 | | |
847 | 847 | | |
| |||
1088 | 1088 | | |
1089 | 1089 | | |
1090 | 1090 | | |
1091 | | - | |
| 1091 | + | |
1092 | 1092 | | |
1093 | 1093 | | |
1094 | 1094 | | |
| |||
1394 | 1394 | | |
1395 | 1395 | | |
1396 | 1396 | | |
1397 | | - | |
| 1397 | + | |
1398 | 1398 | | |
1399 | 1399 | | |
1400 | 1400 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
83 | 83 | | |
84 | 84 | | |
85 | 85 | | |
86 | | - | |
87 | | - | |
88 | | - | |
89 | | - | |
90 | | - | |
91 | | - | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
92 | 104 | | |
93 | 105 | | |
94 | 106 | | |
| |||
0 commit comments