An example where gemm and all-scatter are independent #232
Merged
Contributor
Pull Request Overview
This PR introduces a new example demonstrating independent GEMM and all-scatter operations in a distributed setting. The implementation provides two algorithmic approaches: bulk synchronous all-scatter and ring-based all-reduce, with comprehensive benchmarking and validation capabilities for multi-GPU matrix multiplication scenarios.
- Adds distributed GEMM implementations with two communication strategies (bulk synchronous and ring-based)
- Provides a comprehensive benchmarking framework with timing, validation, and trace collection
- Integrates with Iris library for multi-GPU memory management and communication primitives
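For intuition, the two communication strategies can be modeled in plain Python. The sketch below is an illustration only, not the PR's Triton kernels: lists stand in for per-GPU tiles, and the function names `bulk_all_scatter` and `ring_all_reduce` are hypothetical.

```python
def bulk_all_scatter(local_results):
    """Bulk synchronous model: every rank publishes its full local result
    to all peers in one phase; after a barrier, each rank holds every
    rank's tile, ordered by source rank."""
    world_size = len(local_results)
    return [[local_results[src] for src in range(world_size)]
            for _dst in range(world_size)]


def ring_all_reduce(local_results):
    """Ring model: each rank repeatedly forwards its partial sum to the
    next rank and adds its own contribution; after world_size - 1 steps
    every rank holds the total."""
    world_size = len(local_results)
    acc = list(local_results)  # running partial sum held by each rank
    for _step in range(world_size - 1):
        # Each rank receives the neighbor's partial sum and adds its own tile.
        acc = [acc[(rank - 1) % world_size] + local_results[rank]
               for rank in range(world_size)]
    return acc
```

The bulk synchronous variant trades a single heavy communication phase for simple synchronization; the ring variant pipelines many small neighbor exchanges, which is why the PR describes its synchronization as more complex.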
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `matmul_wrapper.py` | PyTorch autograd wrapper for distributed GEMM kernel execution with debugging and timing support |
| `gemm_all_scatter_bulk_synchronous.py` | Triton kernels for persistent GEMM and bulk synchronous all-scatter communication |
| `gemm_all_reduce_ring_based.py` | Alternative implementation using ring-based all-reduce with more complex synchronization |
| `benchmark.py` | Comprehensive benchmarking script with distributed execution, validation, and performance measurement |
examples/20_gemm_all_scatter_independent/gemm_all_reduce_ring_based.py
…e 20 (#234)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: neoblizz <9790745+neoblizz@users.noreply.github.com>
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
examples/20_gemm_all_scatter_independent/gemm_all_scatter_bulk_synchronous.py
mawad-amd approved these changes on Oct 14, 2025.
This pull request introduces two new files to the `examples/20_gemm_all_scatter_independent` directory, providing benchmarking and kernel implementations for distributed GEMM (General Matrix Multiply) using Triton and Iris. The changes add a full benchmarking script and a ring-based all-reduce GEMM kernel, enabling efficient multi-GPU matrix multiplication and communication. These additions support flexible configuration, validation, and performance measurement for distributed compute scenarios.

New benchmarking and execution script

`benchmark.py` supports distributed execution with PyTorch and command-line configuration for matrix dimensions, datatypes, block sizes, and benchmarking/validation modes. The script manages process spawning, distributed setup, memory allocation, kernel timing, validation, and performance logging.

New Triton kernel implementation for distributed GEMM

`gemm_all_reduce_ring_based.py` implements two Triton kernels: `persistent_gemm` for local matrix multiplication and `persistent_all_reduce` for ring-based distributed reduction and scatter of results. These kernels use advanced synchronization and communication primitives for efficient multi-GPU execution.

Integration with Iris and validation utilities
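A "persistent" kernel, as used here, launches a fixed pool of programs that each loop over many output tiles, rather than launching one program per tile. A rough pure-Python model of that tile-scheduling idea (the function name and striding scheme are illustrative, not necessarily the kernels' actual schedule):

```python
def persistent_tile_schedule(num_tiles, num_programs):
    """Assign each persistent program the list of output tiles it will
    process, striding through the tile space by the program count so the
    work is spread evenly across the fixed pool of programs."""
    return [list(range(pid, num_tiles, num_programs))
            for pid in range(num_programs)]

# e.g. 10 output tiles distributed across 4 persistent programs
schedule = persistent_tile_schedule(10, 4)
```

This style keeps occupancy stable and lets each program overlap compute on one tile with communication for another, which is what makes the independent GEMM and all-scatter phases in this example possible.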
Performance measurement and output
Support for flexible configuration and debugging
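The flexible command-line configuration described above (matrix dimensions, datatype, block sizes, validation/benchmark modes) might look roughly like the stdlib-only sketch below. All flag names and defaults here are assumptions for illustration; the actual `benchmark.py` interface may differ.

```python
import argparse


def parse_args(argv=None):
    """Parse benchmark configuration: GEMM shape, datatype, tile sizes,
    and which modes (validation, timing) to run."""
    p = argparse.ArgumentParser(description="Distributed GEMM benchmark (sketch)")
    p.add_argument("-m", type=int, default=4096, help="rows of A")
    p.add_argument("-n", type=int, default=4096, help="cols of B")
    p.add_argument("-k", type=int, default=4096, help="inner (reduction) dimension")
    p.add_argument("--datatype", choices=["fp16", "bf16", "fp32"], default="fp16",
                   help="element type for the GEMM operands")
    p.add_argument("--blk_m", type=int, default=128, help="tile rows per block")
    p.add_argument("--blk_n", type=int, default=128, help="tile cols per block")
    p.add_argument("--validate", action="store_true",
                   help="check results against a reference matmul")
    p.add_argument("--benchmark", action="store_true",
                   help="time the kernels and report throughput")
    return p.parse_args(argv)
```

A script structured this way can be driven per-rank after process spawning, with each spawned worker receiving the same parsed configuration.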