Batch U projection all-reduces into single NCCL collective #72
Conversation
Restructure distributed `project_gradients` into three phases:

1. Compute local U coefficients `U_high^T @ dU` per target (no collective)
2. `cat` + single `dist.all_reduce` (224 kernel launches → 1 for Llama-8B)
3. Offset-split and apply per-target projections

Add a `skip_u` keyword to `project_gradient_to_orthogonal_space` so V projections reuse the shared function without code duplication. Single-GPU path unchanged. Batching is exact: `all_reduce(SUM)` distributes over concatenation, and per-target coefficient matrices are independent. Transient memory only: the batched tensor (~1.4 GB for Llama-8B bf16) is freed after phase 3.
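The three phases above can be sketched as follows. This is a minimal illustration, not the repository's code: numpy stands in for torch, the `fake_all_reduce_sum` helper simulates `dist.all_reduce(SUM)` by summing per-rank tensors, and all shapes and names are invented for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

def fake_all_reduce_sum(per_rank_tensors):
    # Stand-in for dist.all_reduce(SUM): every rank ends up with the sum.
    return sum(per_rank_tensors)

# Two "ranks", three targets with differing shapes: (n, k, m) per target.
shapes = [(8, 3, 5), (6, 2, 4), (10, 4, 7)]
U_high = [rng.standard_normal((n, k)) for n, k, _ in shapes]
dU = [[rng.standard_normal((n, m)) for n, _, m in shapes] for _ in range(2)]

# Phase 1: local coefficients U_high^T @ dU per target (no collective).
coeffs = [[U_high[t].T @ dU[r][t] for t in range(3)] for r in range(2)]

# Phase 2: flatten each target's coefficients, cat, one summed all-reduce.
flat = [np.concatenate([c.ravel() for c in coeffs[r]]) for r in range(2)]
reduced = fake_all_reduce_sum(flat)

# Phase 3: offset-split the reduced tensor and apply each projection.
out, off = [], 0
for t, (n, k, m) in enumerate(shapes):
    C = reduced[off:off + k * m].reshape(k, m)
    off += k * m
    # Shown for rank 0; every rank applies the same C to its own dU.
    out.append(dU[0][t] - U_high[t] @ C)

# Exactness: batching matches one all-reduce per target.
for t in range(3):
    C_ref = fake_all_reduce_sum([coeffs[r][t] for r in range(2)])
    assert np.allclose(out[t], dU[0][t] - U_high[t] @ C_ref)
```

The key point the sketch makes concrete is that phase 3 only needs per-target offsets into the single reduced buffer, so the number of collectives is independent of the number of targets.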
Walkthrough

This PR optimizes distributed training by batching U projection all-reduce operations into a single NCCL collective. The implementation introduces a `skip_u` keyword on `project_gradient_to_orthogonal_space` and restructures `project_gradients` into compute, batched all-reduce, and apply phases.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client as Training Loop
    participant Compute as Compute Phase
    participant Batch as Batch & AllReduce
    participant Apply as Apply Phase
    Client->>Compute: for each OSFT module
    Compute->>Compute: U_high^T @ dU (local coefficients)
    Compute-->>Batch: flatten coefficients
    Batch->>Batch: torch.cat(all flattened coeffs)
    Batch->>Batch: dist.all_reduce(batched tensor, SUM)
    Batch-->>Apply: split all-reduced tensor
    Apply->>Apply: for each module
    Apply->>Apply: apply U projection: dU -= U_high @ proj_coeff
    Apply->>Apply: apply V projection (skip_u=True)
    Apply-->>Client: completed gradients
```
```mermaid
sequenceDiagram
    participant Client as Training Loop
    participant Modules as OSFT Modules (224x)
    participant AllReduce as NCCL AllReduce
    Note over Client,AllReduce: Before: Unbatched (224 separate all-reduces)
    loop for each module
        Client->>Modules: compute U_high^T @ dU
        Modules->>AllReduce: all_reduce(proj_coeff)
        AllReduce-->>Modules: receive reduced coeff
        Modules->>Modules: apply projection
    end
    Note over Client,AllReduce: After: Batched (1 all-reduce)
    Client->>Modules: compute U_high^T @ dU (all modules)
    Modules-->>Client: collect flattened coeffs
    Client->>AllReduce: all_reduce(cat(all_coeffs))
    AllReduce-->>Client: receive single reduced tensor
    Client->>Modules: split & apply projections
```
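The PR's exactness claim rests on `all_reduce(SUM)` distributing over concatenation: summing each target's tensor across ranks and then concatenating gives the same result as concatenating first and summing once. A quick numerical check, with plain numpy sums standing in for NCCL and the shapes chosen arbitrarily:

```python
import numpy as np

rng = np.random.default_rng(1)
# Per-rank coefficient tensors for two targets of different sizes, 4 "ranks".
a = [rng.standard_normal(6) for _ in range(4)]   # target A
b = [rng.standard_normal(10) for _ in range(4)]  # target B

# Unbatched: one all-reduce per target, then concatenate.
per_target = np.concatenate([sum(a), sum(b)])
# Batched: concatenate per rank, then a single all-reduce.
batched = sum(np.concatenate([a[r], b[r]]) for r in range(4))

# Elementwise addition order is identical, so the results match exactly.
assert np.array_equal(per_target, batched)
```

This mirrors what the PR's `test_batched_path_matches_unbatched` test checks (bitwise equality between the two paths).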
**RobotSail** left a comment:

Tested and it seems to work as expected; I am seeing a training speedup of around 2%. Thanks for this change, LGTM!
Closes #71
Batch all per-target U projection coefficient all-reduces into a single `dist.all_reduce` call. Reduces 224 NCCL kernel launches to 1 per step (Llama-8B). Data volume unchanged: this targets launch latency, not bandwidth.
Benchmark (2× H200, NVLink, 224 targets):
Single-GPU path unchanged. Transient memory only (~1.4 GB for Llama-8B, freed after `project_gradients` returns).

Changes

- `project_gradient_to_orthogonal_space` gains a `skip_u` kwarg (default `False`)
- `project_gradients` restructured into compute → batch all-reduce → apply for the distributed path
- `test_batched_path_matches_unbatched` (bitwise equality vs unbatched path)
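A minimal sketch of how a `skip_u` kwarg lets the batched distributed path apply the U projection itself while still reusing the shared function for the V side. The signature and projection internals here are illustrative assumptions, not the repository's implementation:

```python
import numpy as np

def project_gradient_to_orthogonal_space(dW, U_high, V_high, skip_u=False):
    """Remove gradient components lying in the frozen U/V subspaces.

    skip_u=True skips the U projection so a caller that has already
    applied all-reduced U coefficients can reuse this function for V.
    (Sketch only; U_high and V_high are assumed to have orthonormal
    columns.)
    """
    if not skip_u:
        dW = dW - U_high @ (U_high.T @ dW)   # U-side projection
    dW = dW - (dW @ V_high) @ V_high.T       # V-side projection
    return dW

rng = np.random.default_rng(2)
U = np.linalg.qr(rng.standard_normal((8, 3)))[0]  # orthonormal columns
V = np.linalg.qr(rng.standard_normal((5, 2)))[0]
dW = rng.standard_normal((8, 5))

# Single-GPU path: both projections in one call.
full = project_gradient_to_orthogonal_space(dW, U, V)
# Distributed path: U projection applied manually (post all-reduce),
# then the shared function handles V with skip_u=True.
manual_u = dW - U @ (U.T @ dW)
split = project_gradient_to_orthogonal_space(manual_u, U, V, skip_u=True)
assert np.allclose(full, split)
```

With a default of `skip_u=False`, existing single-GPU callers are unaffected, which matches the PR's "single-GPU path unchanged" claim.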