Skip to content

Conversation

@knwng
Copy link
Contributor

@knwng knwng commented Jan 13, 2026

Motivation

To add an example of GEMM + ReduceScatter by workgroup specialization. Resolve #178

Technical Details

It's an one-shot GEMM + ReduceScatter kernel, using atomic_add to do reduce in-place.

Test Plan

As discussed, it's been tested locally.

Test Result

image image image

Submission Checklist

Copy link
Collaborator

@mawad-amd mawad-amd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR, Kyle! I know it is a draft but I left a couple of comments.

@knwng
Copy link
Contributor Author

knwng commented Jan 20, 2026

Hi @mawad-amd , as you mentioned in #169, do I need to add a test for this like https://github.com/ROCm/iris/blob/main/tests/examples/test_all_load_bench.py?

@knwng knwng marked this pull request as ready for review January 20, 2026 17:59
@knwng knwng requested review from BKP and neoblizz as code owners January 20, 2026 17:59
Copilot AI review requested due to automatic review settings January 20, 2026 17:59
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR introduces a GEMM + ReduceScatter example that uses workgroup specialization to overlap computation and communication on AMD GPUs. The implementation divides SMs into GEMM workgroups for matrix multiplication and communication workgroups for scatter operations.

Changes:

  • Added validation function for reduce-scatter operations
  • Implemented persistent GEMM kernel with integrated ReduceScatter using workgroup specialization
  • Created benchmark infrastructure with timing, validation, and tracing capabilities

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
examples/common/validation.py Added validate_reduce_scatter function to verify reduce-scatter correctness
examples/22_gemm_one_shot_reduce_scatter_wg_specialization/gemm_reduce_scatter.py Core kernel implementing GEMM + ReduceScatter with SM specialization
examples/22_gemm_one_shot_reduce_scatter_wg_specialization/matmul_wrapper.py PyTorch autograd wrapper for the GEMM kernel
examples/22_gemm_one_shot_reduce_scatter_wg_specialization/benchmark.py Benchmark script with validation, timing, and distributed setup

@knwng knwng requested a review from mawad-amd January 20, 2026 18:04
Copy link
Collaborator

@mawad-amd mawad-amd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Thanks, Kyle!

@mawad-amd mawad-amd merged commit 9352d1a into ROCm:main Jan 20, 2026
16 of 18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Fused GEMM + ReduceScatter with workgroup specialization

2 participants