
Commit 5c18bec

RMSNorm backward kernel implementation (#709)
* the first version of rmsnorm bwd kernel implementation
* rename variables to be more representative
* fix dg shared across batch dimension
* add dg reduction kernel and restructure rmsnorm forward kernel and backward kernel
* change to wrapper approach
* replace wrapper with torch.autograd
* fix fwd benchmark
* add benchmark for rmsnorm backward
* use Michael's new integration script to test
* revert to original, as it won't run
* change grad_sum init and remove epsilon from backward argument list
* add num_warps to backward kernel
* updates to do block reduction
* change grid_sum init
* fix rmsnorm tests for fp32
* dx and dg have the same types as x and g; fix unit test for backward kernel
* make sure norm_factor is tl.float32
* switch off fp32 test
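For orientation, the quantities named in the list above (dx, dg, the dg reduction over the batch dimension, and a backward pass that takes no epsilon) can be sketched in plain Python. This is a reference sketch of the standard RMSNorm gradient, not the Triton kernel from this commit; the function names `rmsnorm_forward` and `rmsnorm_backward`, the saved `rstd` value, and the default `eps` are assumptions made for illustration.

```python
import math

def rmsnorm_forward(x, g, eps=1e-6):
    """Reference RMSNorm: y[i][j] = x[i][j] * rstd[i] * g[j].

    x is a list of rows, g holds one weight per column.
    rstd[i] = 1 / sqrt(mean(x[i]^2) + eps) is returned so that the
    backward pass can reuse it (hypothetical design, for illustration).
    """
    y, rstd = [], []
    for row in x:
        r = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        rstd.append(r)
        y.append([v * r * w for v, w in zip(row, g)])
    return y, rstd

def rmsnorm_backward(dy, x, g, rstd):
    """Return (dx, dg) given the upstream gradient dy.

    eps is not needed here because rstd was saved by the forward pass.
    dg is summed over the rows (batch dimension), since g is shared by
    every row.
    """
    n_cols = len(g)
    dx = []
    dg = [0.0] * n_cols
    for dy_row, x_row, r in zip(dy, x, rstd):
        # s = sum_j dy_j * g_j * x_j for this row
        s = sum(d * w * v for d, w, v in zip(dy_row, g, x_row))
        dx.append([r * (d * w - v * r * r * s / n_cols)
                   for d, w, v in zip(dy_row, g, x_row)])
        # Row-wise contribution to the shared weight gradient
        for j in range(n_cols):
            dg[j] += dy_row[j] * x_row[j] * r
    return dx, dg
```

A GPU implementation would typically compute per-row (or per-block) partial sums for dg and combine them in a separate reduction step; the loop above performs that accumulation serially.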
1 parent eb7e015 commit 5c18bec

2 files changed: +312 -33 lines changed


.github/workflows/amd_perf_kernel_Integration_tests.yml

Lines changed: 2 additions & 1 deletion
@@ -126,6 +126,7 @@ jobs:
 run: |
   python ./python/perf-kernels/flash-attention.py
   python ./python/perf-kernels/softmax.py
-  python ./python/perf-kernels/rmsnorm.py
+  python ./python/perf-kernels/rmsnorm.py --mode fwd
+  python ./python/perf-kernels/rmsnorm.py --mode bwd
   python ./python/perf-kernels/layernorm.py
   python ./python/perf-kernels/multreduce_matmul_kernel.py bench
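
The workflow change above calls rmsnorm.py once per mode. How the script parses that flag is not shown in this diff; a minimal sketch of one way to accept `--mode fwd` and `--mode bwd` with the standard `argparse` module might look like the following. The function name `parse_args`, the help text, and the default of `fwd` are assumptions, not taken from the commit.

```python
import argparse

def parse_args(argv=None):
    """Parse a --mode flag limited to 'fwd' or 'bwd' (illustrative only)."""
    parser = argparse.ArgumentParser(description="RMSNorm benchmark (sketch)")
    parser.add_argument(
        "--mode",
        choices=["fwd", "bwd"],
        default="fwd",  # assumed default; the real script may differ
        help="which pass to benchmark: forward or backward",
    )
    return parser.parse_args(argv)

if __name__ == "__main__":
    args = parse_args()
    print(f"selected mode: {args.mode}")
```

Restricting the flag with `choices` makes a mistyped mode fail immediately with a usage error, which is convenient in a CI job.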
