Conversation
Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
| (1024, 2048, 1024), | ||
| ], | ||
| ) | ||
| def test_matmul_all_scatter(dtype, atol, rtol, M, N, K): |
There was a problem hiding this comment.
@copilot did you run all the tests? Please do and report back. Report back a roofline as well.
There was a problem hiding this comment.
Ran all tests on 8× MI300X (gfx942) with 2, 4, and 8 ranks. All pass.
Test results (correctness)
| Ranks | Result |
|---|---|
| 2 | ✅ 8 passed |
| 4 | ✅ 6 passed, 2 skipped (M=64 too small for 4-way N split) |
| 8 | ✅ 6 passed, 2 skipped (same reason) |
Ran the full ops test suite (all_gather_matmul, matmul_all_gather, matmul_all_reduce, matmul_all_scatter, matmul_reduce_scatter) with 4 ranks — all pass, no regressions.
Roofline (fp16, iris.do_bench-style mean over 20 iterations)
4 ranks, N_local = N/4:
| Shape (M×N×K) | ms | TFLOPS | % of 1307 peak |
|---|---|---|---|
| 1024 × 4096 × 4096 | 0.33 | 105.7 | 8.1% |
| 4096 × 4096 × 4096 | 3.37 | 40.8 | 3.1% |
| 8192 × 8192 × 8192 | 3.78 | 290.8 | 22.3% |
8 ranks, N_local = N/8:
| Shape (M×N×K) | ms | TFLOPS | % of 1307 peak |
|---|---|---|---|
| 1024 × 4096 × 4096 | 0.34 | 101.0 | 7.7% |
| 4096 × 4096 × 4096 | 8.82 | 15.6 | 1.2% |
| 8192 × 8192 × 8192 | 11.15 | 98.7 | 7.5% |
The 4K×4K×4K case at 8 ranks is memory-bandwidth bound (N_local=512 → very thin slabs per rank). The 8K² case benefits from the compute-to-communication ratio being favorable. These numbers are in line with comparable ops (matmul_all_gather) in the benchmark suite.
iris/ops/matmul_all_scatter.py
Outdated
| # Broadcast this rank's tile to all ranks using iris.x.all_gather with dim=1. | ||
| # dim=1 places the tile at the current rank's column offset in the global output, | ||
| # so every rank receives each rank's column-shard (all-scatter along N dimension). | ||
| iris.x.all_gather(tile_obj, dst_view, dim=1, ctx=ctx) |
There was a problem hiding this comment.
The operation produces the requested semantics: each rank has a column shard B_local (K, N_local), computes C_local = A @ B_local, and every rank ends up with the full C (M, N) — column shards broadcasted to all ranks.
The iris.x API has no all_scatter primitive. The available device-level collective for "every rank broadcasts its shard so all ranks get the complete tensor" is iris.x.all_gather, which its own docstring describes as "scatter pre-computed data mode — scatters a pre-computed tile to all ranks at correct offsets". That is exactly the pattern here: each rank writes its computed column tile to every other rank's output buffer at the correct N offset.
The distinction from matmul_all_gather (which also uses iris.x.all_gather) is the sharding axis:
matmul_all_gather: A is row-sharded → gathers along M (dim=0)matmul_all_scatter: B is column-sharded → gathers along N (dim=1)
Both result in all ranks holding the full output. The name matmul_all_scatter follows the original issue title. If @aamarnat intended a different collective (e.g., one where each rank receives only a subset of the output rather than the full tensor), the semantics would need clarification and a different implementation would be needed.
There was a problem hiding this comment.
@copilot code review[agent] This is a weird implementation. I wanted a replacement for https://github.com/ROCm/iris/tree/main/examples/07_gemm_all_scatter
There was a problem hiding this comment.
@copilot can u introduce an x for all scatter and use that? also plese review ur variable naming.
There was a problem hiding this comment.
Done in commit 0ac27a4.
- Added
iris/x/all_scatter.py— newiris.x.all_scatter(tile, dst_view, ctx)primitive that scatters each rank's pre-computed tile to all ranks at its column-stripe offset. - The
matmul_all_scatterkernel now callsiris.x.all_scatterinstead of the rawiris.storeloop. - Variable naming cleaned up throughout:
B_local→B_shard,N_local→N_shard,C_gathered→C,stride_cm_gathered→stride_cm/stride_cn,tensorA/tensorB→view_A/view_B.
There was a problem hiding this comment.
Pull request overview
Adds a new fused tensor-parallel primitive matmul_all_scatter to iris.ops, intended to compute local GEMM on column-sharded B_local and distribute the resulting column tiles so each rank materializes the full C.
Changes:
- Introduces
iris/ops/matmul_all_scatter.pywith a Triton persistent-kernel implementation usingiris.x.all_gather(dim=1). - Exposes the op via
iris/ops/__init__.py(OpsNamespace.matmul_all_scatter, exports, docstring list). - Adds distributed correctness tests in
tests/ops/test_matmul_all_scatter.py.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
iris/ops/matmul_all_scatter.py |
New fused GEMM + tile distribution kernel and host API. |
iris/ops/__init__.py |
Wires the new op into the public iris.ops namespace. |
tests/ops/test_matmul_all_scatter.py |
Adds reference-based distributed tests for correctness/semantics. |
You can also share your feedback on Copilot code review. Take the survey.
| device = A.device | ||
| num_sms = config.num_sms | ||
| if num_sms is None: | ||
| props = torch.cuda.get_device_properties(device) | ||
| num_sms = props.multi_processor_count | ||
|
|
||
| even_k = K % config.block_size_k == 0 | ||
|
|
||
| # Launch single fused kernel | ||
| grid = (num_sms,) | ||
| _fused_matmul_all_scatter_kernel[grid]( |
There was a problem hiding this comment.
The kernel performs rank-to-rank communication (iris.x.all_gather(...)) inside a persistent-tile loop, which relies on all ranks executing the same sequence/partitioning of tiles. Choosing grid=(num_sms,) from local device properties can diverge across ranks (heterogeneous GPUs, MIG/partitioning, differing visibility), causing mismatched collective participation and potential deadlock. Recommendation: require config.num_sms to be explicitly set and identical across ranks for this op, or derive a consistent cross-rank value (e.g., compute min/agree via a host-side distributed reduction/broadcast) and use that for grid/scheduler.
iris/ops/matmul_all_scatter.py
Outdated
| Fused matrix multiplication and all-scatter using scatter pattern. | ||
|
|
||
| Computes: output = all_scatter(A @ B_local) along N dimension | ||
|
|
||
| Each rank has B_local of shape (K, N_local) where N_local = N / world_size. | ||
| The operation computes C_local = A @ B_local on each rank and immediately | ||
| broadcasts each rank's column-shard tiles to all ranks via iris.x.all_gather | ||
| (dim=1), so that every rank ends up with the full C (M, N). |
There was a problem hiding this comment.
The docstring states all_scatter(A @ B_local) but the described semantics (‘every rank ends up with full C’) and the implementation both use iris.x.all_gather(dim=1). In standard collective terminology this is an all-gather of column shards, so the current wording is internally inconsistent and can confuse users. Suggest updating the docstring (and any similar text in the namespace doc) to describe the operation as ‘compute local GEMM then all-gather along N (dim=1)’ while keeping the op name if it’s part of the established API set.
| config = ops.FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) | ||
|
|
||
| if N_local < config.block_size_n: | ||
| config = ops.FusedConfig(block_size_m=32, block_size_n=N_local, block_size_k=32) | ||
|
|
||
| from iris.ops.matmul_all_scatter import matmul_all_scatter |
There was a problem hiding this comment.
This semantics test can set block_size_n=N_local for small N_local. In the other test in this file you explicitly treat 32 as the minimum supported block size and skip when N_local < 32. These two tests can therefore disagree on what configurations are valid, which risks flaky failures on larger world_size (e.g., N_local becomes 16 for N=128, world_size=8). Recommendation: align this test with the same minimum-block-size constraints (skip when N_local < 32), or choose from a known-supported set of block sizes rather than assigning block_size_n dynamically.
tests/ops/test_matmul_all_scatter.py
Outdated
| shmem.ops.matmul_all_scatter(output, A, B_local, config=config) | ||
|
|
There was a problem hiding this comment.
The new op supports an optional bias argument (and the kernel has a bias path), but the tests only exercise bias=None. Add at least one parametrized case that passes a bias tensor and validates against a reference that adds the bias in the same way as the kernel (per-row based on the implementation) to prevent regressions in the bias path.
| shmem.ops.matmul_all_scatter(output, A, B_local, config=config) | |
| # Add a per-row bias so we exercise the bias path in the kernel. | |
| bias = torch.randn((M, 1), device=output.device, dtype=output.dtype) | |
| # Apply the same per-row bias to the PyTorch reference. | |
| pytorch_output = pytorch_output + bias | |
| shmem.ops.matmul_all_scatter(output, A, B_local, bias=bias, config=config) |
tests/ops/test_matmul_all_scatter.py
Outdated
| elif dtype == torch.float32: | ||
| config = FusedConfig(block_size_m=128, block_size_n=128, block_size_k=64) |
There was a problem hiding this comment.
The elif dtype == torch.float32 branch is currently unreachable because dtype is only parametrized as float16/bfloat16 in this test. Consider removing the dead branch to reduce confusion, or extend the dtype parametrization if float32 coverage is intended.
| elif dtype == torch.float32: | |
| config = FusedConfig(block_size_m=128, block_size_n=128, block_size_k=64) |
…ttern from example 07 Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
…x variable naming Co-authored-by: mawad-amd <112003944+mawad-amd@users.noreply.github.com>
Introduces
matmul_all_scatter— a fused GEMM+AllScatter operation — to theiris.opsAPI, completing the set of tensor-parallel GEMM+CCL primitives alongsidematmul_all_gather,matmul_reduce_scatter, andmatmul_all_reduce.Semantics
Each rank holds a replicated
A (M, K)and a column-shardedB_shard (K, N_shard)whereN_shard = N / world_size. The fused kernel computesC_shard = A @ B_shardand immediately scatters each rank's column stripe to all other ranks, so every rank ends up with the fullC (M, N).Equivalent to:
Implemented as a single persistent kernel using the new
iris.x.all_scattertile-level primitive — no intermediate buffer needed.Changes
iris/x/all_scatter.py(new) — Tile-levelall_scatter(tile, dst_view, ctx)primitive. Each rank pushes its pre-computed tile to all ranks at its column-stripe offset in the global output, usingiris.store+DeviceContext. Analogous toiris.x.all_gatherbut for column-scatter semantics.iris/x/__init__.py— Exportsall_scatter; adds usage example in module docstring.iris/ops/matmul_all_scatter.py— Triton kernel_fused_matmul_all_scatter_kernelwith persistent tile scheduling viatritonblas.ScheduleContext(M, N_shard, K). After computing each local GEMM tile, the kernel callsiris.x.all_scatterto scatter the result to all ranks. Host-sidematmul_all_scatter()andmatmul_all_scatter_preamble(). Variable naming updated throughout (B_shard,N_shard,C,view_A,view_B).iris/ops/__init__.py— Imports new op, addsOpsNamespace.matmul_all_scatter()method, updates__all__and module docstring.tests/ops/test_matmul_all_scatter.py— Parametrized correctness tests across dtypes and problem sizes, validated against a PyTorchdist.all_gatherreference. Semantics test aligned to skip whenN_shard < 32.Usage
The
iris.x.all_scatterprimitive can also be used directly in custom kernels:Original prompt
💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.