Commit 5c18bec
RMSNorm backward kernel implementation (#709)
* the first version of rmsnorm bwd kernel implementation
* rename variables to be more representative
* fix dg shared across batch dimension
* add dg reduction kernel and restructure rmsnorm forward kernel and backward kernel
* change to wrapper approach
* replace wrapper with torch.autograd
* fix fwd benchmark
* add benchmark for rmsnorm backward
* use Michalel's new Integration script to test
* revert back to original, as it won't run
* change grad_sum init and remove epsilon from backward argument list
* add num_warps to backward kernel
* updates to do block reduction
* change grid_sum init
* fix rmsnorm tests for fp32
* dx and dg have the same types as x and g; fix unit test for backward kernel
* make sure norm_factor is tl.float32
* switch off fp32 test
1 parent eb7e015 commit 5c18bec
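For readers following the commit message, here is a minimal reference sketch of the math an RMSNorm backward pass has to reproduce. It is not the Triton kernel from this commit. It is plain Python, and it assumes the usual definition y = x * norm_factor * g with norm_factor = 1 / sqrt(mean(x^2) + eps). The function names `rmsnorm_forward` and `rmsnorm_backward` are hypothetical. The sketch reuses the per-row `norm_factor` saved by the forward pass, which is one plausible reason the backward pass would not need epsilon, and it accumulates dg over all rows, since g is shared across the batch dimension.

```python
import math


def rmsnorm_forward(x, g, eps=1e-6):
    """Reference RMSNorm forward (hypothetical helper, not the commit's kernel).

    x: list of rows (batch x features), g: per-feature gain.
    Returns the output rows and the per-row norm factors.
    """
    y, norm_factors = [], []
    for row in x:
        # norm_factor = 1 / sqrt(mean(x^2) + eps), one value per row
        r = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        norm_factors.append(r)
        y.append([v * r * gj for v, gj in zip(row, g)])
    return y, norm_factors


def rmsnorm_backward(dy, x, g, norm_factors):
    """Reference RMSNorm backward using the saved norm factors (no eps needed).

    Returns dx (same shape as x) and dg (same shape as g).
    """
    n_cols = len(g)
    dx = []
    dg = [0.0] * n_cols  # reduced over every row of the batch
    for dy_row, x_row, r in zip(dy, x, norm_factors):
        x_hat = [v * r for v in x_row]               # normalized input
        dyg = [d * gj for d, gj in zip(dy_row, g)]   # upstream grad times gain
        mean_term = sum(a * b for a, b in zip(dyg, x_hat)) / n_cols
        dx.append([r * (a - b * mean_term) for a, b in zip(dyg, x_hat)])
        for j in range(n_cols):
            dg[j] += dy_row[j] * x_hat[j]
    return dx, dg
```

A sketch like this can serve as a slow oracle when unit-testing a fused kernel: compare the kernel's dx and dg against these values, or against finite differences of the forward pass, within a tolerance suited to the dtype.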
File tree
2 files changed, +312 -33 lines changed
- .github/workflows
- python/perf-kernels