
[COMMS] Fused allreduce, residual add, rms, quant, gemm #2238

Open
micmelesse wants to merge 1 commit into main from micmelesse/fused_allreduce

Conversation

@micmelesse
Contributor

@micmelesse micmelesse commented Mar 10, 2026

Motivation

This PR adds a fused Triton kernel that combines AllReduce + residual add + RMSNorm + FP8 per-row quantization into a single kernel launch, followed by a torch._scaled_mm call. It uses the Iris symmetric heap for GPU-initiated communication, and includes a standalone correctness test and a benchmark script. This PR requires the device_barrier PR in Iris. Future work will look at fusing the GEMM into the kernel.
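For context, the elementwise math being fused alongside the AllReduce (which Iris performs on-device) is: residual add, then RMSNorm, then per-row dynamic FP8 scaling, with the resulting scale handed to the downstream scaled GEMM. Below is a minimal NumPy sketch of that reference math only; the FP8-E4M3 max of 448 and the eps value are assumptions, and communication, dtype casting, and the GEMM are deliberately omitted:

```python
import numpy as np

FP8_MAX = 448.0  # finite max of FP8 E4M3 (assumed quantization target)

def add_rmsnorm_quant(x, residual, weight, eps=1e-6):
    """Residual add -> RMSNorm -> per-row dynamic FP8 scaling (reference math only)."""
    h = x + residual                                          # residual add
    rms = np.sqrt((h * h).mean(axis=-1, keepdims=True) + eps)
    normed = (h / rms) * weight                               # RMSNorm with learned weight
    scale = np.abs(normed).max(axis=-1, keepdims=True) / FP8_MAX
    q = normed / scale                                        # scaled so each row fits the FP8 range
    return q, scale                                           # scale would feed the scaled GEMM

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
res = rng.standard_normal((4, 8)).astype(np.float32)
w = np.ones(8, dtype=np.float32)
q, scale = add_rmsnorm_quant(x, res, w)
print(q.shape, scale.shape)  # (4, 8) (4, 1)
```

In the actual kernel these steps run in one launch over data already reduced through the Iris symmetric heap; the sketch only pins down the numerics a correctness test would compare against.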

Technical Details

Test Plan

Test Result

Submission Checklist

@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label          Tests
ci:sglang      SGLang integration tests
ci:atom        ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:multi-gpu   Multi-GPU op tests (8 GPUs)
ci:vllm        vLLM benchmark
ci:all         All of the above

Add labels via the sidebar or gh pr edit 2238 --add-label <label>

@micmelesse
Contributor Author

ROCm/iris#443 needs to merge before this PR can merge.

@micmelesse micmelesse force-pushed the micmelesse/fused_allreduce branch 2 times, most recently from 9f7b803 to 5f5382a Compare March 10, 2026 15:30
@micmelesse micmelesse changed the title [COMMS] fused_allreduce_add_rms_quant_gemm [COMMS] Fused allreduce, residual add, rms, quant, gemm Mar 10, 2026
@micmelesse micmelesse marked this pull request as ready for review March 10, 2026 18:39
@micmelesse micmelesse requested review from a team and Copilot March 10, 2026 18:39
@micmelesse micmelesse marked this pull request as draft March 10, 2026 18:40
Contributor

Copilot AI left a comment

Pull request overview

Adds a new ROCm/Triton communication+compute path that fuses tensor-parallel AllReduce + residual add + RMSNorm + per-row FP8 quantization into a single Iris-backed kernel launch, followed by a torch._scaled_mm (hipBLASLt) GEMM. It also updates Triton comms packaging behavior around optional Iris availability and includes a new multi-GPU correctness test and benchmark script.

Changes:

  • Add persistent_fused_allreduce_rmsnorm_2d_quant_two_shot Triton kernel + Python manager API (fused_allreduce_add_rms_quant_gemm) using Iris symmetric heap.
  • Add a multi-GPU correctness test (eager + CUDA graph replay) and a torchrun benchmark driver.
  • Update Iris dependency and adjust aiter.ops.triton.comms import behavior / availability flagging.
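The availability-flag change listed above (making aiter.ops.triton.comms importable without Iris) typically uses a guarded import. A minimal sketch of that pattern follows; the try/except structure is an assumption, and only the IRIS_COMM_AVAILABLE name comes from the PR:

```python
# Guarded optional import, sketching the IRIS_COMM_AVAILABLE behavior described above.
# The try/except pattern is assumed; only the flag name appears in the PR.
try:
    import iris  # GPU-initiated communication backend (optional dependency)
    IRIS_COMM_AVAILABLE = True
except ImportError:
    iris = None
    IRIS_COMM_AVAILABLE = False

# Callers can then gate the fused path on the flag:
if IRIS_COMM_AVAILABLE:
    print("iris available: fused allreduce path enabled")
else:
    print("iris missing: falling back to unfused path")
```

This lets a parent package (per the changes above, aiter.ops.triton) import comms unconditionally and re-export the flag, so downstream code can check one boolean instead of catching ImportError itself.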

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.

Summary per file:

  • requirements-triton-comms.txt — Switches the Iris dependency to a branch ref for the fused comms work.
  • aiter/ops/triton/comms/fused_allreduce_add_rms_quant_gemm.py — New fused AllReduce+RMSNorm+FP8-quant kernel + manager + NCCL reference path.
  • aiter/ops/triton/comms/__init__.py — Makes comms importable without Iris via the IRIS_COMM_AVAILABLE flag.
  • aiter/ops/triton/__init__.py — Imports comms unconditionally and re-exports IRIS_COMM_AVAILABLE.
  • op_tests/multigpu_tests/triton_test/test_fused_allreduce_add_rms_quant_gemm.py — New pytest correctness test covering eager + CUDA graph replay.
  • op_tests/multigpu_tests/triton_test/bench_fused_allreduce_add_rms_quant_gemm.py — New torchrun benchmark comparing fused vs. unfused baseline.


requirements-triton-comms.txt (diff under review):

  # Iris library for GPU-initiated communication primitives
- # Pinned to commit 905ec1c (Nov 18, 2024) for reproducibility and API stability
- iris @ git+https://github.com/ROCm/iris.git@905ec1cea8f350211a70c7d0b2bc11a09a6f6429
+ iris @ git+https://github.com/ROCm/iris.git@micmelesse/fusion
Copy link

Copilot AI Mar 10, 2026

The Iris dependency is no longer pinned to a commit/tag and instead tracks a branch (@micmelesse/fusion). That makes installs non-reproducible and can break CI if the branch is force-pushed or deleted. Please pin to an immutable ref (commit SHA or release tag), and if you need a moving target, consider documenting it separately from the requirements file used by CI.

Suggested change:
- iris @ git+https://github.com/ROCm/iris.git@micmelesse/fusion
+ iris @ git+https://github.com/ROCm/iris.git@v0.9.0

@micmelesse
Contributor Author

This is a placeholder for CI right now. I will update it when the Iris PR merges.

@micmelesse micmelesse force-pushed the micmelesse/fused_allreduce branch 3 times, most recently from 9feb5d8 to a7e8484 Compare March 10, 2026 20:23
@micmelesse micmelesse marked this pull request as ready for review March 10, 2026 20:51
@micmelesse micmelesse force-pushed the micmelesse/fused_allreduce branch 3 times, most recently from 852ace2 to fc0bd99 Compare March 12, 2026 16:47
@micmelesse micmelesse force-pushed the micmelesse/fused_allreduce branch from fc0bd99 to 17b7a86 Compare March 12, 2026 16:55