Summary
Group.split() was recently added to the NCCL backend in #3172 for multi-parallelism support. The JACCL and Ring backends currently throw std::runtime_error when split() is called. Adding split() to these backends would bring parity across all MLX distributed backends and unlock the full range of multi-group distributed patterns on Apple Silicon Thunderbolt 5 clusters.
Motivation
Group.split() is a foundational primitive for creating communication subgroups, enabling distributed patterns such as:
- Hybrid tensor + pipeline parallelism: TP within small node groups, PP between them.
- Expert parallelism: Distributing MoE experts across node subgroups instead of replicating all experts on every node. Reduces per-node memory for models like Kimi K2.5 where the vast majority of weights are unused per inference step.
- Data parallel replicas: Running independent model replicas on node subsets for concurrent serving. Each replica can have its own TP subgroup, for example.
- Context parallelism: Splitting long sequences across node subgroups, each handling a chunk of the context window. Relevant for architectures with large KV caches and large context.
- Speculative decoding: Draft model on one subgroup, verifier on another, coordinating candidate token generation and batch verification.
- Multi-model pipelines: Different models (vision encoder → LLM → reward model) on different subgroups with activation forwarding between them.
With TB5 RDMA and JACCL, fully connected topologies are limited by physical port count, and practical TP group sizes are further constrained by model dimension divisibility (attention heads, hidden sizes must divide evenly), making 2 and 4 the most common TP configurations.
| Machine | Chip | TB5 Ports | Full-Mesh TP Max (N-1 ≤ ports) |
|---|---|---|---|
| MacBook Pro | M4 Max | 3 | 4 nodes |
| Mac Studio | M4 Max | 4 | 5 nodes |
| Mac Studio | M3 Ultra | 6 | 7 nodes |
Note: Full-Mesh TP Max is a hardware ceiling.
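The port-count ceiling in the table follows from the full-mesh requirement that each node holds a direct link to every other node. A trivial helper makes the arithmetic explicit (illustrative only; `max_full_mesh_nodes` is not an MLX API):

```python
# Illustrative helper: a full mesh of n nodes needs n - 1 direct links
# per node, so the hardware ceiling is ports + 1 nodes.
def max_full_mesh_nodes(tb_ports: int) -> int:
    """Largest full-mesh group size given a per-node TB port count."""
    return tb_ports + 1

# Values from the table above (hardware ceiling only; model-dimension
# divisibility further restricts practical TP sizes).
for ports in (3, 4, 6):
    print(ports, "ports ->", max_full_mesh_nodes(ports), "nodes")
```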
Subgroup support would allow clusters of Apple Silicon machines to apply a broader spectrum of parallelism strategies such as running large models at full precision across nodes that individually can't hold them, serving concurrent usage with replica groups, or distributing MoE experts efficiently.
MLX already has the higher-level building blocks:
- `shard_linear`/`AllToShardedLinear` for TP
- `PipelineMixin` with `send`/`recv` for PP
- `sharded_load()` accepting both `pipeline_group` and `tensor_group`
The subgroup primitive allows users to compose them.
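To make the composition concrete, here is a pure-Python model of how color-based splitting would partition an 8-node cluster into TP pairs and PP chains, mirroring the `rank // 2` / `rank % 2` idiom used in the examples below (`group_by_color` is a hypothetical helper for illustration, not an MLX API):

```python
# Pure-Python model of color-based splitting: ranks that compute the
# same color land in the same subgroup, mirroring Group.split(color).
from collections import defaultdict

def group_by_color(world_size, color_fn):
    groups = defaultdict(list)
    for rank in range(world_size):
        groups[color_fn(rank)].append(rank)
    return dict(groups)

world_size = 8
tp_groups = group_by_color(world_size, lambda r: r // 2)  # pairs for TP
pp_groups = group_by_color(world_size, lambda r: r % 2)   # chains for PP

print(tp_groups)  # {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}
print(pp_groups)  # {0: [0, 2, 4, 6], 1: [1, 3, 5, 7]}
```

Each node belongs to exactly one TP group and one PP group, which is what lets TP collectives and PP send/recv coexist on the same cluster.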
Current behavior
```python
world = mx.distributed.init(backend="jaccl")
tp_group = world.split(color=world.rank() // 2)
# RuntimeError: [jaccl] Group split not supported.
```

```python
world = mx.distributed.init(backend="ring")
tp_group = world.split(color=world.rank() // 2)
# RuntimeError: [ring] Group split not supported.
```

Desired behavior
```python
world = mx.distributed.init(backend="jaccl")

# Example: hybrid TP + PP
tp_group = world.split(color=world.rank() // 2)  # TP sub-groups
pp_group = world.split(color=world.rank() % 2)   # PP chain across groups

# Example: data parallel replicas
replica_group = world.split(color=world.rank() // 4)  # independent replicas

# Example: expert parallel subgroups
ep_group = world.split(color=world.rank() % num_ep_groups)
```

Why this seems feasible
JACCL already supports point-to-point send/recv to arbitrary peers in the mesh. A sub-group would operate on a subset of existing RDMA connections rather than requiring new transport infrastructure. The NCCL implementation in #3172 can serve as a reference for the GroupImpl::split() interface contract.
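The interface contract can be modeled in a few lines. This sketch assumes `split()` follows the usual MPI-style comm-split convention (same color ⇒ same subgroup; new ranks ordered by key, ties broken by old rank) — the authoritative contract should be taken from the NCCL implementation in #3172, and the helper below is illustrative only:

```python
# Reference model of a comm-split contract (assumed MPI-style
# semantics; the real contract lives in the NCCL implementation).
def split(world_size, colors, keys=None):
    keys = keys or [0] * world_size
    groups = {}
    for rank in range(world_size):
        groups.setdefault(colors[rank], []).append(rank)
    new_rank = {}
    for color, members in groups.items():
        # New ranks ordered by key, ties broken by original rank.
        members.sort(key=lambda r: (keys[r], r))
        for i, r in enumerate(members):
            new_rank[r] = (color, i)
    return new_rank  # old rank -> (subgroup color, new rank)

# 4 nodes, TP pairs: ranks 0-1 form group 0, ranks 2-3 form group 1.
print(split(4, [r // 2 for r in range(4)]))
# {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}
```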
For Ring, the sub-group would create a smaller ring over a subset of ranks — the neighbor-only send/recv constraint naturally maps to sequential forwarding patterns like pipeline parallelism.
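Deriving the smaller ring is mechanical once the subgroup's member list is known: each member's send/recv neighbors are just the next and previous entries modulo the subgroup size. A minimal sketch (pure Python, `sub_ring_neighbors` is a hypothetical name):

```python
# Sketch: after a split, the Ring backend would form a smaller ring
# over the subgroup's members. Neighbors are the adjacent entries in
# the member list, wrapping around at the ends.
def sub_ring_neighbors(members):
    n = len(members)
    return {
        m: (members[(i + 1) % n], members[(i - 1) % n])  # (next, prev)
        for i, m in enumerate(members)
    }

# Subgroup of global ranks {1, 3, 5, 7} (color = rank % 2 on 8 nodes):
print(sub_ring_neighbors([1, 3, 5, 7]))
# {1: (3, 7), 3: (5, 1), 5: (7, 3), 7: (1, 5)}
```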
Potential workarounds
Without split(), users could launch each subgroup as a separate backend instance with its own environment configuration and manage inter-group communication at the application layer. This works but is brittle, and the functionality seems much better suited to living inside MLX itself.
Environment
- MLX 0.31.0
- macOS 26.2+
- Apple Silicon with Thunderbolt 5/RDMA support