[COMMS] Fused allreduce, residual add, rms, quant, gemm #2238

micmelesse wants to merge 1 commit into main
Conversation
ROCm/iris#443 needs to merge before this can merge.
Pull request overview
Adds a new ROCm/Triton communication+compute path that fuses tensor-parallel AllReduce + residual add + RMSNorm + per-row FP8 quantization into a single Iris-backed kernel launch, followed by a torch._scaled_mm (hipBLASLt) GEMM. It also updates Triton comms packaging behavior around optional Iris availability and includes a new multi-GPU correctness test and benchmark script.
Changes:
- Add `persistent_fused_allreduce_rmsnorm_2d_quant_two_shot` Triton kernel + Python manager API (`fused_allreduce_add_rms_quant_gemm`) using the Iris symmetric heap.
- Add a multi-GPU correctness test (eager + CUDA graph replay) and a `torchrun` benchmark driver.
- Update the Iris dependency and adjust `aiter.ops.triton.comms` import behavior / availability flagging.
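The fused kernel collapses a chain of otherwise separate ops. As a host-side sketch of the unfused reference semantics, here is a numpy version in which the allreduce is simulated as a sum over per-rank partials; the function name, `fp8_max`, and the rounding details are illustrative assumptions, not the kernel's exact implementation:

```python
import numpy as np

def reference_path(partials, residual, weight, eps=1e-6, fp8_max=448.0):
    """Unfused reference for the ops the kernel fuses (numpy stand-in)."""
    # 1. AllReduce: elementwise sum across tensor-parallel ranks
    x = np.sum(partials, axis=0)
    # 2. Residual add
    x = x + residual
    # 3. RMSNorm: x / sqrt(mean(x^2) + eps) * weight
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    y = x / rms * weight
    # 4. Per-row FP8-style quantization: scale each row so that the
    #    row's max |value| maps onto the representable range
    row_max = np.max(np.abs(y), axis=-1, keepdims=True)
    scale = row_max / fp8_max
    q = np.clip(np.round(y / scale), -fp8_max, fp8_max)
    return q, scale

# tiny smoke check: two ranks, each contributing ones
parts = [np.ones((2, 4)), np.ones((2, 4))]
q, s = reference_path(parts, residual=np.zeros((2, 4)), weight=np.ones(4))
```

The fused kernel performs these four steps in one launch over the Iris symmetric heap; the quantized output and per-row scales then feed the `torch._scaled_mm` GEMM.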
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| requirements-triton-comms.txt | Switches Iris dependency to a branch ref for the fused comms work. |
| aiter/ops/triton/comms/fused_allreduce_add_rms_quant_gemm.py | New fused AllReduce+RMSNorm+FP8-quant kernel + manager + NCCL reference path. |
| aiter/ops/triton/comms/__init__.py | Makes comms importable without Iris via IRIS_COMM_AVAILABLE flag. |
| aiter/ops/triton/__init__.py | Imports comms unconditionally and re-exports IRIS_COMM_AVAILABLE. |
| op_tests/multigpu_tests/triton_test/test_fused_allreduce_add_rms_quant_gemm.py | New pytest correctness test covering eager + CUDA graph replay. |
| op_tests/multigpu_tests/triton_test/bench_fused_allreduce_add_rms_quant_gemm.py | New torchrun benchmark comparing fused vs unfused baseline. |
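The `IRIS_COMM_AVAILABLE` flag implies a guarded optional-import pattern in the comms `__init__.py`. A minimal sketch of that pattern (names follow the PR description; the real module layout and error message may differ):

```python
# Optional-import pattern: the package stays importable without Iris,
# and callers can check the flag before using the fused comms path.
try:
    import iris  # GPU-initiated communication backend
    IRIS_COMM_AVAILABLE = True
except ImportError:
    iris = None
    IRIS_COMM_AVAILABLE = False

def fused_allreduce_add_rms_quant_gemm(*args, **kwargs):
    # Fail loudly at call time, not import time, when Iris is absent.
    if not IRIS_COMM_AVAILABLE:
        raise RuntimeError(
            "Iris is not installed; the fused allreduce path is unavailable"
        )
    ...
```

This keeps `import aiter.ops.triton` working on installs without Iris while still surfacing a clear error if the fused path is actually invoked.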
```diff
 # Iris library for GPU-initiated communication primitives
-# Pinned to commit 905ec1c (Nov 18, 2024) for reproducibility and API stability
-iris @ git+https://github.com/ROCm/iris.git@905ec1cea8f350211a70c7d0b2bc11a09a6f6429
+iris @ git+https://github.com/ROCm/iris.git@micmelesse/fusion
```
The Iris dependency is no longer pinned to a commit/tag and instead tracks a branch (@micmelesse/fusion). That makes installs non-reproducible and can break CI if the branch is force-pushed or deleted. Please pin to an immutable ref (commit SHA or release tag), and if you need a moving target, consider documenting it separately from the requirements file used by CI.
Suggested change:
```diff
-iris @ git+https://github.com/ROCm/iris.git@micmelesse/fusion
+iris @ git+https://github.com/ROCm/iris.git@v0.9.0
```
This is a placeholder for CI right now. I will update it when the Iris PR merges.
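Once the Iris PR lands, repinning to an immutable ref as the review suggests would look like the following requirements fragment (the SHA shown is the previous pin from this file, reused only to illustrate the shape; the final SHA will differ):

```
# Iris library for GPU-initiated communication primitives
# Pin to an immutable commit SHA for reproducible installs
iris @ git+https://github.com/ROCm/iris.git@905ec1cea8f350211a70c7d0b2bc11a09a6f6429
```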
Motivation
This PR adds a fused Triton kernel that combines allreduce + residual add + RMSNorm + FP8 per-row quantization into a single kernel launch, followed by a `torch._scaled_mm` call. It uses the Iris symmetric heap for GPU-initiated communication, and includes a standalone correctness test and benchmark script. This PR requires the `device_barrier` PR in Iris. Future work will look at fusing the GEMM into the kernel.

Technical Details
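The `two_shot` in the kernel name suggests the classic two-shot allreduce: shot one reduce-scatters (each rank reduces one chunk across all ranks), shot two all-gathers the reduced chunks. A host-side numpy sketch of the data movement (illustrative only, not the Iris symmetric-heap implementation):

```python
import numpy as np

def two_shot_allreduce(rank_buffers):
    """Two-shot allreduce over host arrays.

    rank_buffers: one equal-length 1-D array per rank.
    Shot 1 (reduce-scatter): rank i reduces chunk i across all ranks.
    Shot 2 (all-gather): the reduced chunks are reassembled on every rank.
    Returns the fully reduced buffer each rank ends up with.
    """
    world = len(rank_buffers)
    chunks = [np.array_split(buf, world) for buf in rank_buffers]
    # Shot 1: chunk i is summed across ranks (owned by rank i)
    reduced = [sum(chunks[r][i] for r in range(world)) for i in range(world)]
    # Shot 2: gather all reduced chunks back into one buffer
    return np.concatenate(reduced)

out = two_shot_allreduce([np.arange(4.0), np.arange(4.0)])
```

Each rank touches only its owned chunk during the reduction, which is what makes the scheme amenable to GPU-initiated, persistent-kernel communication.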
Test Plan
Test Result
Submission Checklist