
Batch U projection all-reduces into single NCCL collective #72

Merged
RobotSail merged 1 commit into Red-Hat-AI-Innovation-Team:main from stmcgovern:batch-u-allreduce
Mar 9, 2026

Conversation

@stmcgovern
Contributor

@stmcgovern stmcgovern commented Mar 3, 2026

Closes #71

Batch all per-target U projection coefficient all-reduces into a single
dist.all_reduce call. Reduces 224 NCCL kernel launches to 1 per step
(Llama-8B). Data volume unchanged — this targets launch latency, not bandwidth.

Benchmark (2× H200, NVLink, 224 targets):

|                   | Individual | Batched |
| ----------------- | ---------- | ------- |
| Wall time         | 12.4 ms    | 7.1 ms  |
| Speedup           |            | 1.8×    |
| NCCL launches     | 224        | 1       |
| Per-call overhead | 56 μs      |         |
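The per-call overhead row is consistent with the other figures under the assumption (not stated in the PR) that it is derived as the individual wall time spread across the 224 separate launches:

```python
# Assumed derivation of the table's per-call overhead figure:
# individual-path wall time divided by the number of NCCL launches.
launches = 224
individual_wall_ms = 12.4
per_call_us = individual_wall_ms * 1000 / launches
print(round(per_call_us, 1))  # ≈ 55.4 μs, matching the reported ~56 μs
```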

Single-GPU path unchanged. Transient memory only (~1.4 GB for Llama-8B,
freed after project_gradients returns).

Changes

  • project_gradient_to_orthogonal_space gains skip_u kwarg (default False)
  • project_gradients restructured into compute → batch all-reduce → apply
    for the distributed path
  • 5 new tests including test_batched_path_matches_unbatched (bitwise equality
    vs unbatched path)
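The restructured distributed path can be sketched roughly as below. This is an illustrative reconstruction, not the repo's actual API: the helper name `project_gradients_batched` and the `targets` dict structure are hypothetical.

```python
import torch
import torch.distributed as dist

def project_gradients_batched(targets):
    """Sketch of compute -> batch all-reduce -> apply.

    Each target carries U_high (d x k, orthonormal columns) and a
    gradient dU (d x n). Names are illustrative, not the repo's API.
    """
    # Phase 1: compute local coefficients; no collectives yet.
    coeffs = [t["u_high"].T @ t["grad"] for t in targets]   # k x n each
    shapes = [c.shape for c in coeffs]
    flat = torch.cat([c.reshape(-1) for c in coeffs])       # one flat buffer

    # Phase 2: a single all-reduce replaces one call per target.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(flat, op=dist.ReduceOp.SUM)

    # Phase 3: offset-split and apply each projection in place.
    offset = 0
    for t, shape in zip(targets, shapes):
        n = shape[0] * shape[1]
        coeff = flat[offset:offset + n].view(shape)
        t["grad"] -= t["u_high"] @ coeff                    # remove U_high component
        offset += n
```

On a single process the all-reduce is skipped and the result reduces to the plain per-target projection, which is why the single-GPU path is unchanged.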

Summary by CodeRabbit

Release Notes

  • New Features

    • Added optional parameter to control gradient projection behavior in distributed training scenarios.
    • Implemented batched all-reduce optimization for improved distributed training performance across multiple modules.
  • Tests

    • Added comprehensive test coverage for batched gradient projection behavior and multi-module optimization scenarios.

Restructure distributed project_gradients into three phases:

  1. Compute local U coefficients U_high^T @ dU per target (no collective)
  2. cat + single dist.all_reduce (224 kernel launches → 1 for Llama-8B)
  3. Offset-split and apply per-target projections

Add skip_u keyword to project_gradient_to_orthogonal_space so V
projections reuse the shared function without code duplication.
Single-GPU path unchanged.
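A minimal sketch of what the skip_u split might look like, assuming the usual orthogonal-complement projections on both sides; the real function's signature and internals may differ:

```python
import torch

def project_gradient_to_orthogonal_space(grad, u_high, v_high, *, skip_u=False):
    # Hypothetical shape of the shared helper. With skip_u=True the U-side
    # projection is left to the batched all-reduce path and only the V-side
    # projection runs here; with skip_u=False both are applied locally.
    if not skip_u:
        grad = grad - u_high @ (u_high.T @ grad)     # remove U_high component
    grad = grad - (grad @ v_high) @ v_high.T         # remove V_high component
    return grad
```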

Batching is exact: all_reduce(SUM) distributes over concatenation,
and per-target coefficient matrices are independent.
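The exactness claim can be checked directly: when the per-element addition order is unchanged, summing then concatenating equals concatenating then summing. An illustrative two-rank example:

```python
import torch

torch.manual_seed(0)
# Two "ranks" each hold the same set of per-target coefficient shapes.
rank0 = [torch.randn(3, 4), torch.randn(2, 5)]
rank1 = [torch.randn(3, 4), torch.randn(2, 5)]

# Per-target all-reduces: sum each pair, then concatenate.
per_target = torch.cat([(a + b).reshape(-1) for a, b in zip(rank0, rank1)])

# Batched all-reduce: concatenate first, then sum once.
batched = (torch.cat([t.reshape(-1) for t in rank0])
           + torch.cat([t.reshape(-1) for t in rank1]))

assert torch.equal(per_target, batched)  # bitwise equal
```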

Transient memory only — the batched tensor (~1.4 GB for Llama-8B
bf16) is freed after phase 3.
@coderabbitai

coderabbitai bot commented Mar 3, 2026

No actionable comments were generated in the recent review. 🎉


📥 Commits

Reviewing files that changed from the base of the PR and between 4d6dc87 and 14b17b8.

📒 Files selected for processing (2)
  • src/mini_trainer/osft_utils.py
  • tests/test_osft.py

📝 Walkthrough


This PR optimizes distributed training by batching U projection all-reduce operations into a single NCCL collective. The implementation introduces a skip_u parameter to project_gradient_to_orthogonal_space, restructures ModelWithOSFT.project_gradients into compute/batch/apply phases for distributed cases, and validates the batched approach through comprehensive tests ensuring mathematical equivalence.

Changes

  • Core Implementation (src/mini_trainer/osft_utils.py): Added skip_u keyword-only parameter to project_gradient_to_orthogonal_space() with conditional U projection control. Restructured project_gradients() into three distributed phases: compute local U coefficients, batch-flatten and single all-reduce, then apply per-module updates. V projections keep per-module handling with an optional all-reduce.
  • Test Coverage (tests/test_osft.py): Introduced TestBatchedUAllReduce test suite validating skip_u flag behavior, V projection correctness, orthogonality across multiple targets, coefficient flatten/cat/split round-tripping, and parity between batched and unbatched distributed paths with a mocked all-reduce.
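The flatten/cat/split round-trip that the tests exercise can be sketched as follows (shapes illustrative, not the repo's test code):

```python
import torch

# Flatten per-target coefficient matrices into one buffer, then recover them.
coeffs = [torch.randn(3, 4), torch.randn(5, 2), torch.randn(1, 7)]
flat = torch.cat([c.reshape(-1) for c in coeffs])          # one flat buffer
sizes = [c.numel() for c in coeffs]                        # split sizes
restored = [chunk.view(c.shape)
            for chunk, c in zip(torch.split(flat, sizes), coeffs)]

assert all(torch.equal(a, b) for a, b in zip(coeffs, restored))
```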

Sequence Diagram(s)

sequenceDiagram
    participant Client as Training Loop
    participant Compute as Compute Phase
    participant Batch as Batch & AllReduce
    participant Apply as Apply Phase
    
    Client->>Compute: for each OSFT module
    Compute->>Compute: U_high^T @ dU (local coefficients)
    Compute-->>Batch: flatten coefficients
    
    Batch->>Batch: torch.cat(all flattened coeffs)
    Batch->>Batch: dist.all_reduce(batched tensor, SUM)
    Batch-->>Apply: split all-reduced tensor
    
    Apply->>Apply: for each module
    Apply->>Apply: apply U projection: dU -= U_high @ proj_coeff
    Apply->>Apply: apply V projection (skip_u=True)
    Apply-->>Client: completed gradients
sequenceDiagram
    participant Client as Training Loop
    participant Modules as OSFT Modules (224x)
    participant AllReduce as NCCL AllReduce
    
    Note over Client,AllReduce: Before: Unbatched (224 separate all-reduces)
    loop for each module
        Client->>Modules: compute U_high^T @ dU
        Modules->>AllReduce: all_reduce(proj_coeff)
        AllReduce-->>Modules: receive reduced coeff
        Modules->>Modules: apply projection
    end
    
    Note over Client,AllReduce: After: Batched (1 all-reduce)
    Client->>Modules: compute U_high^T @ dU (all modules)
    Modules-->>Client: collect flattened coeffs
    Client->>AllReduce: all_reduce(cat(all_coeffs))
    AllReduce-->>Client: receive single reduced tensor
    Client->>Modules: split & apply projections

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

  • PR #47: Modifies distributed gradient-projection logic with U/V projection and all-reduce handling in the same core file.
  • PR #50: Touches OSFT orthogonalization logic and test coverage with related SVD projection mechanisms.

Suggested reviewers

  • NikhilNayak-debug
  • RobotSail

Poem

🐰 A rabbit's ode to batched projections:
Two-twenty-four became just one,
NCCL kernels now can run—
With batched all-reduces bright,
We orthogonalize the night!
SVD spreads joy so wide,
Safe in math, our guide! ✨

🚥 Pre-merge checks: ✅ 5 passed

| Check name | Status | Explanation |
| --- | --- | --- |
| Description check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title directly and specifically describes the main change: batching U projection all-reduces into a single NCCL collective operation. |
| Linked Issues check | ✅ Passed | The PR implements all coding requirements from issue #71: adds the skip_u parameter, restructures distributed project_gradients into compute-batch-apply phases, implements coefficient batching and all-reduce, and includes comprehensive test coverage for batched behavior. |
| Out of Scope Changes check | ✅ Passed | All code changes are directly scoped to issue #71: enhancing project_gradient_to_orthogonal_space, restructuring project_gradients for batched all-reduces, and adding tests validating the batched distributed path. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 92.31%, above the required 80.00% threshold. |



@codecov

codecov bot commented Mar 5, 2026

Codecov Report

❌ Patch coverage is 90.47619% with 4 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/mini_trainer/osft_utils.py | 90.47% | 4 Missing ⚠️ |


Collaborator

@RobotSail RobotSail left a comment

Tested, and it works as expected; I am seeing a training speedup of around 2%. Thanks for this change, LGTM!

@RobotSail RobotSail merged commit 3300833 into Red-Hat-AI-Innovation-Team:main Mar 9, 2026
14 checks passed
