Cache hip.get_num_xcc() to eliminate 14ms Config overhead #528
Open
Conversation
`Config.__post_init__` called `iris.hip.get_num_xcc()` on every collective invocation, adding ~14ms of ctypes/HIP runtime overhead per call. This caused all collectives to show constant 10-12ms latency regardless of message size (100-200x slower than RCCL).

Fix: `lru_cache` on the XCC query, since it never changes during a process. Also adds simplified PULL-model kernels for all_gather, reduce_scatter, and all_to_all that use linear tile indexing instead of the GROUP_SIZE_M swizzle + chiplet_transform, which caused poor codegen on MI300X.

Results on MI300X 8-GPU (before -> after):
- reduce_scatter: 12,000us -> 58us at 128KB, beats RCCL by 3-6x
- all_reduce: beats RCCL at 512KB-8MB (0.88-0.99x)
- all_gather: 12,000us -> 58us at 32KB (1.08x RCCL)
- all_to_all: 10,000us -> 55us at 128KB (1.39x RCCL)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
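The caching fix can be sketched outside the library; in this minimal, self-contained sketch, `get_num_xcc` is a stand-in for the real ctypes call into the HIP runtime, and the counter shows that the expensive query runs only once no matter how many times a config is built:

```python
import functools

call_count = 0  # tracks how often the expensive query actually runs


def get_num_xcc():
    """Stand-in for the ctypes call into the HIP runtime (~14ms each)."""
    global call_count
    call_count += 1
    return 8  # e.g. an MI300X reports 8 XCDs


@functools.lru_cache(maxsize=1)
def _cached_num_xcc():
    """Cache the XCC count since it never changes during a process."""
    return get_num_xcc()


# Every collective invocation now hits the cache instead of the runtime.
values = [_cached_num_xcc() for _ in range(1000)]
print(call_count)  # the underlying query ran exactly once
```

The same pattern applies to any process-constant hardware query; `maxsize=1` suffices because the function takes no arguments.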
Contributor
Pull request overview
Note: Copilot was unable to run its full agentic suite in this review.
This PR reduces per-collective overhead by caching the HIP XCC query and adds/uses simplified “PULL”-model Triton kernels (targeted at improving MI300X codegen/perf).
Changes:
- Cache `iris.hip.get_num_xcc()` via `lru_cache` to avoid repeated ctypes runtime calls in `Config.__post_init__`.
- Add new PULL-model kernels for `all_gather`, `all_to_all`, and `reduce_scatter` using linear tile indexing.
- Switch defaults to prefer the new PULL/simple kernels in several launch paths and config.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| iris/ccl/config.py | Adds cached XCC query and updates all-gather variant options/default. |
| iris/ccl/triton/all_gather.py | Adds persistent_all_gather_pull and wires it to config variant "pull". |
| iris/ccl/triton/all_to_all.py | Adds persistent_all_to_all_pull and switches launcher to use it by default. |
| iris/ccl/triton/reduce_scatter.py | Adds persistent_reduce_scatter_simple and switches launcher to use it by default. |
Comment on lines +182 to +194:

```python
# Block distribution: each rank handles contiguous chunk of tiles
tiles_per_rank = tl.cdiv(total_tiles, world_size)
start_tile = group_rank * tiles_per_rank
remaining = total_tiles - start_tile
remaining = tl.maximum(remaining, 0)
max_tiles = tl.minimum(tiles_per_rank, remaining)

for tile_offset in range(pid, max_tiles, COMM_SMS):
    tile_id = start_tile + tile_offset

    pid_m = tile_id // num_pid_n
    pid_n = tile_id % num_pid_n
```
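The block distribution in this hunk can be checked outside Triton. A plain-Python sketch (with `math.ceil` standing in for `tl.cdiv` and an explicit loop over program IDs standing in for the launch grid; the helper name `tiles_for` is made up for illustration) shows that every tile is visited exactly once across ranks and SMs:

```python
import math


def tiles_for(group_rank, pid, total_tiles, world_size, comm_sms, num_pid_n):
    """Mirror the kernel's linear tile indexing for one (rank, pid) pair."""
    tiles_per_rank = math.ceil(total_tiles / world_size)  # tl.cdiv
    start_tile = group_rank * tiles_per_rank
    remaining = max(total_tiles - start_tile, 0)          # tl.maximum
    max_tiles = min(tiles_per_rank, remaining)            # tl.minimum
    out = []
    for tile_offset in range(pid, max_tiles, comm_sms):
        tile_id = start_tile + tile_offset
        pid_m, pid_n = divmod(tile_id, num_pid_n)         # linear indexing
        out.append((tile_id, pid_m, pid_n))
    return out


# 100 tiles, 8 ranks, 4 comm SMs per rank: every tile is covered exactly once,
# including the short tail chunk on the last rank.
covered = sorted(
    t for r in range(8) for p in range(4)
    for t, _, _ in tiles_for(r, p, 100, 8, 4, 10)
)
print(covered == list(range(100)))
```

The `tl.maximum`/`tl.minimum` clamping is what keeps the last rank from walking past `total_tiles` when the tile count does not divide evenly by `world_size`.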
Comment on lines +205 to +206:

```python
output_offset = rm[:, None] * stride_out_m + rn[None, :] * stride_out_n
```
Comment on lines +267 to +269:

```python
# Use PULL model by default — PUSH model has poor performance on MI300X
kernel_fn = persistent_all_to_all_pull
```
|
Comment on lines +239 to +241:

```python
# Use simplified kernel by default — the two_shot variant has poor codegen on MI300X
kernel_fn = persistent_reduce_scatter_simple
```
|
```diff
 chunk_size: int | None = None
 use_gluon: bool = False
-all_gather_variant: str = "persistent"
+all_gather_variant: str = "pull"
```
Comment on lines +9 to +15:

```python
import functools
import iris


@functools.lru_cache(maxsize=1)
def _cached_num_xcc():
    """Cache the XCC count since it never changes during a process."""
```
Comment on lines 106 to +109:

```diff
 def __post_init__(self):
     """Validate and auto-detect num_xcds if not set."""
     if self.num_xcds is None:
-        self.num_xcds = iris.hip.get_num_xcc()
+        self.num_xcds = _cached_num_xcc()
```
Comment on lines +296 to +302:

```python
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    COMM_SMS: tl.constexpr,
    NUM_XCDS: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
):
```
Comment on lines +306 to +307:

```python
    Uses simple linear tile indexing (no swizzle/chiplet transform) for clean
    codegen on MI300X. GROUP_SIZE_M/NUM_XCDS/CHUNK_SIZE accepted but unused.
```
Summary
`Config.__post_init__` called `iris.hip.get_num_xcc()` (a ctypes call into the HIP runtime) on every collective invocation, adding ~14ms overhead per call. This caused all collectives to show constant 10-12ms latency regardless of message size. `functools.lru_cache` on the XCC query fixes this, since it never changes during a process lifetime.

Results (MI300X, 8 GPUs, bfloat16)
ratio = iris/RCCL, bold = iris wins
reduce_scatter beats RCCL by 3-6x at medium/large sizes. all_reduce beats RCCL at 512KB-8MB.
Test plan
🤖 Generated with Claude Code