Summary
A proposal for RDMA weight streaming: a master node reads model weights from disk and broadcasts only the slices each rank needs via all_sum, eliminating the requirement to store the full model on every node. This builds on the existing mlx_lm.share and sharded_load infrastructure.
Motivation
Currently, distributed inference in MLX requires the full model to be stored on every participating node:
- TP (Tensor Parallel): sharded_load downloads all weight files to every node, then model.shard() selects the needed slices at eval time. Each node stores the full model on disk even though it only uses 1/N of the weights in memory.
- PP (Pipeline Parallel): sharded_load already optimizes this; it downloads only the files needed for the local pipeline stage. This is the closest to what we're proposing.
- EP (Expert Parallel, MoE): No built-in support. Each node stores and loads the full model.
For large models this is a significant constraint:
| Model | FP16 Size | 4-Node TP4 Disk Usage (current) | With Weight Streaming |
|---|---|---|---|
| Llama 405B | 810 GB | 3.2 TB (810 GB × 4) | 810 GB (master only) |
| DeepSeek V3 | 670 GB | 2.7 TB | 670 GB |
| Kimi K2.5 | 612 GB | 2.4 TB | 612 GB |
With mlx_lm.share demonstrating 5+ GB/s broadcast throughput over Thunderbolt 5 RDMA, streaming weights at startup is fast enough to be practical.
What Exists Today
| Component | Status | What It Does |
|---|---|---|
| mlx_lm.share | Working | Broadcasts entire directories via all_sum at 5+ GB/s |
| sharded_load (PP) | Working | Downloads only the files each pipeline stage needs |
| sharded_load (TP) | Working but redundant | Downloads full model to every node, shards in memory |
| model.shard() | Working | Knows how to slice weight matrices for TP |
| model.pipeline() | Working | Knows which layers belong to which pipeline stage |
| MoE expert placement | Not implemented | No built-in expert-level distribution |
Proposed Approach
Phase 1: TP Weight Streaming (Low-Hanging Fruit)
The master node (rank 0) reads each safetensor file, and instead of broadcasting the entire file, slices the weight matrices according to the TP sharding plan and broadcasts only the relevant slice to each rank.
Current TP flow:
Every node: read full model from disk → shard() → eval (loads 1/N into RAM)
Disk: N copies of full model
Proposed TP flow:
Rank 0: read full model from disk → slice weights → broadcast slices via all_sum
Rank 1-N: receive only their slices → load into RAM
Disk: 1 copy on rank 0 only
This requires:
- Rank 0 inspects the model's shard plan (which model.shard() already computes)
- For each weight file, rank 0 reads it and extracts each rank's slice
- Each slice is broadcast via all_sum (rank 0 sends real data, the others send zeros, the same pattern as mlx_lm.share)
- Each rank receives only its portion and loads it directly into the model
The safetensors.index.json file already maps weight names to files, and model.shard() already knows the slicing dimensions. The main engineering work is connecting these two systems.
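The zero-fill broadcast trick can be simulated in a single process to show the pattern. This is a sketch, not MLX code: `simulated_all_sum` and `stream_tp_slices` are hypothetical names, and a real implementation would have each rank call `mx.distributed.all_sum` on its own contribution instead of summing a Python list.

```python
import numpy as np

def simulated_all_sum(contributions):
    # Stand-in for the all_sum collective: every rank ends up holding the
    # elementwise sum of all ranks' contributions.
    return sum(contributions)

def stream_tp_slices(full_weight, world_size, axis=0):
    """Rank 0 holds the full weight; for each destination rank it broadcasts
    only that rank's TP slice via the zero-fill trick (rank 0 contributes
    real data, every other rank contributes zeros of the same shape)."""
    n = full_weight.shape[axis] // world_size
    received = {}
    for dest in range(world_size):
        # The slice rank `dest` should end up with.
        idx = [slice(None)] * full_weight.ndim
        idx[axis] = slice(dest * n, (dest + 1) * n)
        shard = full_weight[tuple(idx)]
        # Per-rank contributions: data on rank 0, zeros everywhere else.
        contribs = [shard if r == 0 else np.zeros_like(shard)
                    for r in range(world_size)]
        # After all_sum every rank holds rank 0's data; only `dest` keeps it.
        received[dest] = simulated_all_sum(contribs)
    return received

W = np.arange(16, dtype=np.float32).reshape(4, 4)
shards = stream_tp_slices(W, world_size=2, axis=0)
```

The key property is that the sum of one real contribution and N-1 zero contributions equals the real data on every rank, so no dedicated point-to-point send primitive is needed.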
Phase 2: PP Weight Streaming
PP is partially solved — sharded_load already knows which files each stage needs and only downloads those. The missing piece is sourcing files from a master node via RDMA instead of from disk/HuggingFace:
Current PP flow:
Each node: download only needed files from HF → load
Disk: partial model per node (already efficient)
Proposed PP flow:
Rank 0: read files from disk → broadcast only the files each stage needs
Rank 1-N: receive only their stage's files → load
Disk: 1 copy on rank 0 only
This is simpler than TP streaming since it operates at the file level (no matrix slicing). sharded_load already computes the file-to-stage mapping.
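The file-to-stage mapping can be sketched from the index file alone. This is a simplified illustration, not the actual sharded_load logic: layers are split evenly across stages, and non-layer weights (embeddings, head, norms) are assigned to stage 0 purely as an assumption of this sketch.

```python
def files_for_stage(index, stage, num_stages, num_layers):
    """Return the safetensors shard files one pipeline stage needs, given a
    parsed model.safetensors.index.json (its "weight_map" maps tensor names
    like "model.layers.12.mlp.up_proj.weight" to shard filenames)."""
    per_stage = num_layers // num_stages
    lo, hi = stage * per_stage, (stage + 1) * per_stage
    needed = set()
    for name, fname in index["weight_map"].items():
        parts = name.split(".")
        if "layers" in parts:
            layer = int(parts[parts.index("layers") + 1])
            if lo <= layer < hi:          # layer belongs to this stage
                needed.add(fname)
        elif stage == 0:
            # Simplifying assumption: non-layer weights live with stage 0.
            needed.add(fname)
    return sorted(needed)

index = {"weight_map": {
    "model.embed_tokens.weight": "model-00001.safetensors",
    "model.layers.0.mlp.up_proj.weight": "model-00001.safetensors",
    "model.layers.1.mlp.up_proj.weight": "model-00002.safetensors",
}}
```

Rank 0 would then broadcast only each stage's file list, which is why this phase needs no matrix slicing at all.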
Phase 3: MoE Expert Placement
This is the most impactful but also most complex. For MoE models (Mixtral, DeepSeek V3, Kimi K2.5), experts are independent weight blocks that can be distributed across nodes:
Proposed EP flow:
Rank 0: read model → distribute expert subsets to each node
Each node: holds 1/N of experts in memory
Inference: router selects active experts → each node runs its local experts → all_sum collects results
Benefits:
- A 4-node cluster with 512 GB each can serve a model with up to ~2 TB of expert weights
- Only active experts compute per token — idle experts consume zero GPU time
- Could extend to demand-paging: evict cold experts, stream hot ones at 5 GB/s
This requires deeper integration with the model architecture — the router needs to know which experts are on which nodes, and the forward pass needs all_sum collectives after expert computation.
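The expert-parallel forward pass described above can be simulated with numpy: each node applies only its resident experts, contributes zeros for remote ones, and the final sum over nodes stands in for the all_sum collective. `ep_forward` and the single-matmul experts are illustrative assumptions, not an MLX API.

```python
import numpy as np

def ep_forward(x, experts, placement, selected, num_nodes):
    """Simulate one MoE layer under expert parallelism.
    experts:   {expert_id: weight matrix}
    placement: {expert_id: node index holding that expert}
    selected:  expert ids chosen by the router for this token."""
    partials = []
    for node in range(num_nodes):
        partial = np.zeros_like(x)
        for e in selected:
            if placement[e] == node:       # expert resident on this node
                partial = partial + x @ experts[e]
            # remote experts contribute zeros from this node
        partials.append(partial)
    # Stand-in for all_sum: identical to running every selected expert locally.
    return sum(partials)

rng = np.random.default_rng(0)
d = 4
experts = {e: rng.standard_normal((d, d)) for e in range(4)}
placement = {0: 0, 1: 0, 2: 1, 3: 1}       # experts 0-1 on node 0, 2-3 on node 1
x = rng.standard_normal((1, d))
y = ep_forward(x, experts, placement, selected=[1, 2], num_nodes=2)
```

Because unselected experts never appear in any node's partial, idle experts cost nothing, which is the property the benefits list relies on.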
Phase 4: Dynamic Expert Demand-Paging (Speculative)
The most advanced version: experts are streamed on-demand during inference based on routing patterns. Hot experts stay resident, cold experts are evicted, and when a cold expert is needed, it's streamed from the master at 5 GB/s (~50ms for a typical expert block).
This is speculative and may not be practical for real-time inference, but for batch/offline workloads the latency could be acceptable.
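A minimal sketch of the residency policy, assuming a plain LRU heuristic (the proposal does not specify one). `ExpertCache` is a hypothetical name; `fetch` stands in for streaming an absent expert from the master over RDMA (at 5 GB/s, the ~50 ms figure above implies roughly a 250 MB expert block).

```python
from collections import OrderedDict

class ExpertCache:
    """LRU cache of resident experts for demand-paged MoE inference."""
    def __init__(self, capacity, fetch):
        self.capacity = capacity
        self.fetch = fetch               # expert_id -> weights (RDMA stream)
        self.resident = OrderedDict()    # ordered coldest-first

    def get(self, expert_id):
        if expert_id in self.resident:
            self.resident.move_to_end(expert_id)   # hot expert stays resident
            return self.resident[expert_id]
        if len(self.resident) >= self.capacity:
            self.resident.popitem(last=False)      # evict the coldest expert
        weights = self.fetch(expert_id)            # stream from the master
        self.resident[expert_id] = weights
        return weights

fetched = []
def fake_fetch(e):
    fetched.append(e)
    return f"weights-{e}"

cache = ExpertCache(capacity=2, fetch=fake_fetch)
cache.get(0); cache.get(1); cache.get(0)
cache.get(2)        # capacity reached: evicts expert 1, the coldest
```

Whether the per-miss latency is tolerable depends entirely on how skewed the routing distribution is, which is why this phase only clearly fits batch/offline workloads.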
Benefits Beyond Disk Space
- Model switching without pre-staging: swap from one model to another by streaming from the master node. No need to copy hundreds of GB to every node first.
- Serve models larger than any single node's storage: a node with 1 TB NVMe can participate in serving a 2 TB model if it only holds its shard.
- Heterogeneous clusters ([BUG] Distributed inference OOMs on machines with different RAM size #1804): nodes with different RAM sizes can receive proportional weight slices. A 64 GB node gets fewer layers/experts than a 512 GB node, with the master orchestrating the placement.
- Faster cold start for PP: pipeline stages can begin processing as soon as their layers arrive. Stage 0 starts inference while stages 2-3 are still receiving weights.
- Dynamic reconfiguration: switch from TP2 to TP4 by re-streaming with different slicing. No file management, no pre-staging, just a different sharding plan.
- Single source of truth: model updates, quantization changes, and fine-tune checkpoints only need to exist on the master node.
Hardware Context
Tested on a 5-node cluster of Mac Studio M3 Ultra (512 GB each) connected via Thunderbolt 5 full mesh (4 cables per node):
- mlx_lm.share measured at 5.2 GB/s broadcast to 3 nodes simultaneously
- Our independent all_sum-based broadcast measured 3.7 GB/s (without async_eval pipelining)
- Successfully broadcast 812 GB (Llama 405B FP16) to 3 nodes with verified checksum integrity
- RDMA throughput does not degrade with additional receivers (each TB5 cable is an independent 80 Gbps link)
At these speeds, streaming a 200 GB TP4 slice takes ~38 seconds (200 GB at 5.2 GB/s), a one-time startup cost that eliminates terabytes of redundant storage.
Questions for the Team
- Is this direction something the MLX team is already exploring or interested in?
- For TP weight streaming, would it make sense to extend sharded_load or create a new loading path?
- For MoE expert placement, are there plans to add expert parallelism to MLX's distributed primitives?
- The Group.split() proposal ([Enhancement] Group.split() support for JACCL and Ring backends (parity with NCCL #3172) #3205) would be useful for hybrid TP+EP; is that on the roadmap?
Happy to contribute implementation work if this aligns with the project's direction.