
Cache hip.get_num_xcc() to eliminate 14ms Config overhead#528

Open
mawad-amd wants to merge 2 commits into main from muhaawad/fix-config-latency

Conversation

@mawad-amd
Collaborator

Summary

  • Root cause: Config.__post_init__ called iris.hip.get_num_xcc() (a ctypes call into the HIP runtime) on every collective invocation, adding ~14ms overhead per call. This caused all collectives to show constant 10-12ms latency regardless of message size.
  • Fix: functools.lru_cache on the XCC query since it never changes during a process lifetime.
  • Added simplified PULL-model kernels for all_gather, reduce_scatter, and all_to_all that use linear tile indexing instead of GROUP_SIZE_M swizzle + chiplet_transform, which caused poor codegen on MI300X.
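
The caching pattern behind the fix can be sketched in isolation. This is a minimal stand-in, not the actual `iris.hip` code: `cached_num_xcc` and the hard-coded count of 8 are hypothetical, and a global counter substitutes for the real ctypes call so the caching behavior is observable.

```python
import functools

CALLS = 0  # counts how often the "expensive" query actually runs


@functools.lru_cache(maxsize=1)
def cached_num_xcc():
    """Stand-in for a ctypes query whose result never changes
    during the process lifetime, so one cache slot suffices."""
    global CALLS
    CALLS += 1
    return 8  # hypothetical XCC count


# The underlying query runs once; every later call is a cache hit.
first = cached_num_xcc()
second = cached_num_xcc()
assert first == second == 8
assert CALLS == 1
```

With `maxsize=1` the cache holds exactly the one result it will ever need, so repeated `Config` construction no longer pays the runtime-query cost.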

Results (MI300X, 8 GPUs, bfloat16)

| Size | all_gather | reduce_scatter | all_to_all | all_reduce |
| --- | --- | --- | --- | --- |
| 32KB | 1.08x | 1.10x | 1.50x | 1.11x |
| 128KB | 1.19x | 1.01x | 1.39x | 1.10x |
| 512KB | 1.37x | 0.81x | 1.49x | 0.93x |
| 2MB | 1.62x | 0.57x | 1.58x | 0.90x |
| 8MB | 1.87x | 0.35x | 1.62x | 0.99x |
| 32MB | 2.04x | 0.23x | 1.77x | 1.22x |
| 128MB | 2.13x | 0.17x | 1.97x | 1.15x |

ratio = Iris latency / RCCL latency; lower is better, so ratios below 1.0 mean Iris beats RCCL.

reduce_scatter beats RCCL by 3-6x at medium/large sizes. all_reduce beats RCCL at 512KB-8MB.

Test plan

  • All 4 collectives pass correctness checks (8 GPUs, MI300X)
  • Performance benchmarked against RCCL across 7 message sizes
  • CI tests pass

🤖 Generated with Claude Code

Config.__post_init__ called iris.hip.get_num_xcc() on every collective
invocation, adding ~14ms of ctypes/HIP runtime overhead per call. This
caused all collectives to show constant 10-12ms latency regardless of
message size (100-200x slower than RCCL).

Fix: lru_cache on the XCC query since it never changes during a process.

Also adds simplified PULL-model kernels for all_gather, reduce_scatter,
and all_to_all that use linear tile indexing instead of GROUP_SIZE_M
swizzle + chiplet_transform, which caused poor codegen on MI300X.

Results on MI300X 8-GPU (before -> after):
- reduce_scatter: 12,000us -> 58us at 128KB, beats RCCL by 3-6x
- all_reduce: beats RCCL at 512KB-8MB (0.88-0.99x)
- all_gather: 12,000us -> 58us at 32KB (1.08x RCCL)
- all_to_all: 10,000us -> 55us at 128KB (1.39x RCCL)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings May 2, 2026 10:26
@mawad-amd mawad-amd requested review from BKP and neoblizz as code owners May 2, 2026 10:26
@github-actions github-actions Bot added in-progress We are working on it iris Iris project issue labels May 2, 2026
Contributor

Copilot AI left a comment


Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR reduces per-collective overhead by caching the HIP XCC query and adds/uses simplified “PULL”-model Triton kernels (targeted at improving MI300X codegen/perf).

Changes:

  • Cache iris.hip.get_num_xcc() via lru_cache to avoid repeated ctypes runtime calls in Config.__post_init__.
  • Add new PULL-model kernels for all_gather, all_to_all, and reduce_scatter using linear tile indexing.
  • Switch defaults to prefer the new PULL/simple kernels in several launch paths and config.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.

| File | Description |
| --- | --- |
| iris/ccl/config.py | Adds cached XCC query and updates all-gather variant options/default. |
| iris/ccl/triton/all_gather.py | Adds persistent_all_gather_pull and wires it to config variant "pull". |
| iris/ccl/triton/all_to_all.py | Adds persistent_all_to_all_pull and switches the launcher to use it by default. |
| iris/ccl/triton/reduce_scatter.py | Adds persistent_reduce_scatter_simple and switches the launcher to use it by default. |

Comment on lines +182 to +194
# Block distribution: each rank handles contiguous chunk of tiles
tiles_per_rank = tl.cdiv(total_tiles, world_size)
start_tile = group_rank * tiles_per_rank
remaining = total_tiles - start_tile
remaining = tl.maximum(remaining, 0)
max_tiles = tl.minimum(tiles_per_rank, remaining)

for tile_offset in range(pid, max_tiles, COMM_SMS):
    tile_id = start_tile + tile_offset

    pid_m = tile_id // num_pid_n
    pid_n = tile_id % num_pid_n
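
The block distribution in this hunk can be checked with a small host-side sketch. Plain Python mirrors of the Triton ops (`math.ceil` for `tl.cdiv`, `max`/`min` for `tl.maximum`/`tl.minimum`); `rank_tiles` and the sizes are hypothetical, for illustration only.

```python
import math


def rank_tiles(group_rank, world_size, total_tiles):
    """Mirror the kernel's partition: rank r owns the contiguous
    range [r * tiles_per_rank, r * tiles_per_rank + max_tiles)."""
    tiles_per_rank = math.ceil(total_tiles / world_size)  # tl.cdiv
    start_tile = group_rank * tiles_per_rank
    remaining = max(total_tiles - start_tile, 0)          # tl.maximum
    max_tiles = min(tiles_per_rank, remaining)            # tl.minimum
    return list(range(start_tile, start_tile + max_tiles))


# Every tile is covered exactly once, even when world_size does not
# divide total_tiles evenly (the last rank gets the short remainder).
total_tiles, world_size = 70, 8
covered = [t for r in range(world_size)
           for t in rank_tiles(r, world_size, total_tiles)]
assert covered == list(range(total_tiles))
```

The clamping via `remaining` is what keeps the last rank from reading past the tile grid when `total_tiles` is not a multiple of `world_size`.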

Comment on lines +205 to +206
output_offset = rm[:, None] * stride_out_m + rn[None, :] * stride_out_n

Comment on lines +267 to +269
# Use PULL model by default — PUSH model has poor performance on MI300X
kernel_fn = persistent_all_to_all_pull

Comment on lines +239 to +241
# Use simplified kernel by default — the two_shot variant has poor codegen on MI300X
kernel_fn = persistent_reduce_scatter_simple

Comment thread iris/ccl/config.py
     chunk_size: int | None = None
     use_gluon: bool = False
-    all_gather_variant: str = "persistent"
+    all_gather_variant: str = "pull"
Comment thread iris/ccl/config.py
Comment on lines +9 to +15
import functools
import iris


@functools.lru_cache(maxsize=1)
def _cached_num_xcc():
    """Cache the XCC count since it never changes during a process."""
Comment thread iris/ccl/config.py
Comment on lines 106 to +109
 def __post_init__(self):
     """Validate and auto-detect num_xcds if not set."""
     if self.num_xcds is None:
-        self.num_xcds = iris.hip.get_num_xcc()
+        self.num_xcds = _cached_num_xcc()
Comment on lines +296 to +302
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    COMM_SMS: tl.constexpr,
    NUM_XCDS: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
):
Comment on lines +306 to +307
Uses simple linear tile indexing (no swizzle/chiplet transform) for clean
codegen on MI300X. GROUP_SIZE_M/NUM_XCDS/CHUNK_SIZE accepted but unused.
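
The linear indexing this docstring describes maps a 1-D tile id straight to a row-major `(pid_m, pid_n)` pair, in contrast to the `GROUP_SIZE_M` swizzle that the old kernels used. A host-side sketch of the two orderings, assuming the swizzle follows the grouping scheme from Triton's matmul tutorial (the function names and grid sizes here are illustrative, not Iris APIs):

```python
def linear_tile(tile_id, num_pid_n):
    """Linear indexing: plain row-major walk over the tile grid."""
    return tile_id // num_pid_n, tile_id % num_pid_n


def grouped_tile(tile_id, num_pid_m, num_pid_n, group_size_m):
    """GROUP_SIZE_M swizzle (Triton matmul-tutorial style): tiles are
    visited column-major within groups of group_size_m rows."""
    num_pid_in_group = group_size_m * num_pid_n
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * group_size_m
    group_rows = min(num_pid_m - first_pid_m, group_size_m)
    pid_m = first_pid_m + (tile_id % num_pid_in_group) % group_rows
    pid_n = (tile_id % num_pid_in_group) // group_rows
    return pid_m, pid_n


# Both orderings cover the same 4x4 tile grid exactly once; they differ
# only in visit order, which is what affects codegen and locality.
m, n = 4, 4
lin = {linear_tile(t, n) for t in range(m * n)}
grp = {grouped_tile(t, m, n, 2) for t in range(m * n)}
assert lin == grp == {(i, j) for i in range(m) for j in range(n)}
```

Keeping the unused `GROUP_SIZE_M`/`NUM_XCDS`/`CHUNK_SIZE` parameters in the signature lets the simple kernel stay drop-in compatible with the existing launch paths.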
