
Replace Gram matrix V projection with factored form#74

Open
stmcgovern wants to merge 1 commit into Red-Hat-AI-Innovation-Team:main from stmcgovern:factored-v-projection

Conversation

@stmcgovern
Contributor

@stmcgovern stmcgovern commented Mar 6, 2026

Fixes #73
Replace dV -= dV @ (V_high^T @ V_high) with the factored form dV -= (dV @ V_high^T) @ V_high. Under FSDP2 this replaces an (M, M) all-reduce with a (k_high, M) all-gather — M/k_high fewer bytes (2x for square weights, 7x for down_proj) and all-gather is cheaper per byte.

Add opt-in caching of the all-gathered V_high (OSFT_CACHE_V=1, default off). V_high is frozen, so the cache is exact.

Includes bench_v_proj.py benchmark and 10 new tests.
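The two forms are algebraically identical because matrix multiplication is associative; only the intermediate shapes differ (an (M, M) Gram matrix versus a (rows, k_high) coefficient matrix). A minimal pure-Python sketch of the equivalence, with illustrative helper names rather than the repo's actual code:

```python
def matmul(a, b):
    """Naive matrix multiply for small illustrative matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def transpose(m):
    return [list(row) for row in zip(*m)]

# Tiny example: dV is (2, 4) = (rows, M); V_high is (1, 4) = (k_high, M).
dV = [[1.0, 2.0, 3.0, 4.0],
      [0.5, -1.0, 2.5, 0.0]]
V_high = [[0.5, 0.5, 0.5, 0.5]]  # a single (unnormalized) row basis vector

Vt = transpose(V_high)

# Gram form: materializes the (M, M) = (4, 4) matrix V_high^T @ V_high.
gram = matmul(dV, matmul(Vt, V_high))

# Factored form: only a (rows, k_high) = (2, 1) intermediate is built.
factored = matmul(matmul(dV, Vt), V_high)

# Exact equality here relies on the power-of-two example values;
# in floating point the two orderings generally agree only to rounding error.
assert gram == factored  # dV @ (V^T @ V) == (dV @ V^T) @ V
```

Under FSDP2, shrinking the intermediate is what turns the (M, M) communication into a (k_high, M) one.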

Summary by CodeRabbit

  • New Features

    • Optional V_high caching mechanism for optimizing distributed gradient operations
    • Environment flag to enable or disable caching behavior
    • Enhanced monitoring and logging for cache performance metrics and memory usage
  • Tests

    • Comprehensive test coverage for caching functionality, including correctness validation and cache behavior verification

@coderabbitai

coderabbitai bot commented Mar 6, 2026

📝 Walkthrough

Walkthrough

This pull request implements a switch from Gram matrix to factored form for V projection in OSFT, introduces optional caching of V_high matrices to avoid redundant all-gathers, and adds distributed benchmarking with comprehensive test coverage to validate the optimization.

Changes

Benchmarking Infrastructure (benchmarks/bench_v_proj.py)
New distributed benchmark script measuring three V projection execution modes (Gram, Factored, Cached) across simulated FSDP2 sharding, with a warmup phase, correctness validation, and aggregate reporting over 224 configurations.

OSFT Core Caching & Projection (src/mini_trainer/osft_utils.py)
Introduces optional V_high caching via the OSFT_CACHE_V flag; refactors project_gradient_to_orthogonal_space to accept an optional cache_holder; implements the per-module _osft_v_high_full cache lifecycle with all-gather optimization; updates the gradient projection logic and adds reset/cleanup paths.

Cache & Factored Form Tests (tests/test_osft.py)
Comprehensive 11-test suite validating V_high cache population, correctness matching the Gram form, cache size efficiency, orthogonality preservation, cache disabling, and lifecycle (reset/state_dict exclusion) behavior.

Sequence Diagram

sequenceDiagram
    participant Grad as Gradient<br/>Computation
    participant Proj as project_gradients
    participant Cache as Cache<br/>Holder
    participant Comm as Distributed<br/>Comm
    participant SVD as SVD<br/>Projector

    Grad->>Proj: trigger projection
    Proj->>Cache: check OSFT_CACHE_V & cache_holder
    alt Cache enabled & holder exists
        Cache-->>Cache: has _osft_v_high_full?
        alt Not cached
            Proj->>Comm: all_gather_into_tensor(V_high_full)
            Comm-->>Proj: V_high_full (k_high × M)
            Proj->>Cache: store _osft_v_high_full on module
        else Already cached
            Cache-->>Proj: retrieve _osft_v_high_full
        end
    else Cache disabled or no holder
        Proj->>Comm: all_gather_into_tensor(V_high_full)
        Comm-->>Proj: V_high_full
    end
    Proj->>SVD: project_gradient_to_orthogonal_space(cache_holder)
    SVD->>SVD: coeff = local_dV @ V_high_full.T
    SVD->>SVD: local_dV -= coeff @ V_high_full
    SVD-->>Proj: updated dV (local, no all-reduce)
    Proj->>Proj: log cache metrics

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Suggested reviewers

  • NikhilNayak-debug
  • RobotSail

Poem

🐰 A cache was born beneath the factored sky,
Where Gram's bulky matrix said goodbye.
All-gather whispers (seven times less!),
Cached V_high tensors pass every test—
Hops of speedup, bunny blessed! 🚀

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
  • Description Check: Passed. Check skipped: CodeRabbit's high-level summary is enabled.
  • Title Check: Passed. The title accurately summarizes the main change: replacing Gram matrix V projection with a factored form to reduce communication overhead.
  • Linked Issues Check: Passed. All code requirements from issue #73 are met: the factored form replaces the Gram computation, all-gather replaces all-reduce, optional caching via the OSFT_CACHE_V flag is implemented, and the communication savings (k_high×M vs M×M) are achieved.
  • Out of Scope Changes Check: Passed. All changes directly support the core objective: osft_utils.py implements factored projection and caching, bench_v_proj.py benchmarks the optimization, and tests validate correctness and cache behavior.
  • Docstring Coverage: Passed. Docstring coverage is 91.30%, which exceeds the required threshold of 80.00%.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
benchmarks/bench_v_proj.py (1)

22-35: Fail fast on non-divisible benchmark shapes.

k_high // P and k_low // P silently discard remainder rows, and the slice at Line 34 only covers P * local_k_high rows. If this helper gets reused with a non-divisible shape, it will publish timings for the wrong problem size instead of stopping early.

Possible fix
 def bench_target(name, k_high, k_low, M, P, dev, n_iters=100):
     """Benchmark one OSFT target shape.  Returns dict of timings."""
+    if k_high % P != 0 or k_low % P != 0:
+        raise ValueError(
+            f"k_high={k_high} and k_low={k_low} must both be divisible by P={P}"
+        )
     local_k_high = k_high // P
     local_k_low = k_low // P
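To see the silent truncation the comment describes, run the arithmetic with a non-divisible toy shape (these numbers are illustrative, not the benchmark's defaults):

```python
k_high, P = 10, 4            # 10 rows sharded across 4 simulated ranks
local_k_high = k_high // P   # floor division gives 2 rows per rank
covered = P * local_k_high   # the benchmark slices only ever touch 8 rows
assert covered == 8 and covered < k_high  # two rows silently dropped

# The proposed guard surfaces this before any timing happens:
def check_divisible(k_high, k_low, P):
    if k_high % P != 0 or k_low % P != 0:
        raise ValueError(
            f"k_high={k_high} and k_low={k_low} must both be divisible by P={P}"
        )

try:
    check_divisible(10, 8, 4)
except ValueError as exc:
    caught = str(exc)
assert "k_high=10" in caught  # fails fast instead of timing the wrong size
```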
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_v_proj.py` around lines 22 - 35, The function bench_target
silently truncates k_high and k_low by integer division into
local_k_high/local_k_low and then slices V_full by rank * local_k_high; add
explicit validation at the start of bench_target to check that k_high % P == 0
and k_low % P == 0 (or otherwise that both divide evenly by P), and raise a
clear ValueError if not so that the benchmark fails fast; update the error
message to reference the offending values (k_high, k_low, P) so callers can
correct input shapes and avoid publishing wrong timings caused by the
partial-row truncation in the local_V slice.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 20a6d6ef-db54-486a-a9a4-6bc05401b8a6

📥 Commits

Reviewing files that changed from the base of the PR and between 4d6dc87 and 0e9b3fe.

📒 Files selected for processing (3)
  • benchmarks/bench_v_proj.py
  • src/mini_trainer/osft_utils.py
  • tests/test_osft.py

Comment on lines +605 to +611
if dist.is_initialized() and dist.get_world_size() > 1:
    world_size = dist.get_world_size()
    V_high_full = torch.empty(
        local_V_high.shape[0] * world_size, local_V_high.shape[1],
        dtype=local_V_high.dtype, device=local_V_high.device,
    )
    dist.all_gather_into_tensor(V_high_full, local_V_high)

⚠️ Potential issue | 🔴 Critical


Add padding to local_V_high before all_gather_into_tensor() to handle uneven rank sharding.

The allocation at line 608 assumes every rank has local_V_high.shape[0] == ceil(rank_high / world_size), but this is only guaranteed when rank_high % world_size == 0. Since rank_high is computed from rank_ratio using int(np.floor(...)), non-divisible cases are common. With uneven sharding, all_gather_into_tensor() will fail or rebuild a truncated tensor.

Pad each local shard to ceil(rank_high / world_size) rows and slice the result back to svd_dict["rank_high"] after gathering.

Suggested fix
+                full_k_high = svd_dict["rank_high"]
                 world_size = dist.get_world_size()
-                V_high_full = torch.empty(
-                    local_V_high.shape[0] * world_size, local_V_high.shape[1],
-                    dtype=local_V_high.dtype, device=local_V_high.device,
-                )
-                dist.all_gather_into_tensor(V_high_full, local_V_high)
+                rows_per_rank = math.ceil(full_k_high / world_size)
+                padded_local_V_high = torch.zeros(
+                    rows_per_rank, local_V_high.shape[1],
+                    dtype=local_V_high.dtype, device=local_V_high.device,
+                )
+                padded_local_V_high[: local_V_high.shape[0]].copy_(local_V_high)
+                gathered = torch.empty(
+                    rows_per_rank * world_size, local_V_high.shape[1],
+                    dtype=local_V_high.dtype, device=local_V_high.device,
+                )
+                dist.all_gather_into_tensor(gathered, padded_local_V_high)
+                V_high_full = gathered[:full_k_high]

Add import math at the top of the file if not already present.
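The pad, gather, slice pattern can be checked in a single process by simulating the per-rank shards. This is a sketch, not the repo's code: `all_gather` here is plain list concatenation standing in for `dist.all_gather_into_tensor`, and it assumes ceil-based dim-0 sharding, where only the last rank's shard can be short.

```python
import math

def shard_rows(rows, world_size):
    """Split rows the way ceil-based dim-0 sharding does."""
    per = math.ceil(len(rows) / world_size)
    return [rows[r * per:(r + 1) * per] for r in range(world_size)]

def padded_all_gather(shards, rank_high, pad_row):
    """Pad every shard to the same length, 'gather' (concat), slice back."""
    world_size = len(shards)
    target = math.ceil(rank_high / world_size)
    padded = [s + [pad_row] * (target - len(s)) for s in shards]
    gathered = [row for shard in padded for row in shard]  # concat = gather
    return gathered[:rank_high]  # drop the trailing padding rows

# rank_high = 10 rows over 3 ranks: shards of 4, 4, 2 -- uneven.
rows = [[float(i)] * 2 for i in range(10)]
shards = shard_rows(rows, world_size=3)
assert [len(s) for s in shards] == [4, 4, 2]

rebuilt = padded_all_gather(shards, rank_high=10, pad_row=[0.0, 0.0])
assert rebuilt == rows  # padding plus slicing restores the original V_high
```

Note the slice at the end only works because, under ceil-based sharding, the padding rows land after all real rows; a different sharding policy would need a per-rank unpad instead.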

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mini_trainer/osft_utils.py` around lines 605 - 611, The current
all_gather_into_tensor assumes equal per-rank rows for local_V_high; pad
local_V_high to target_rows = math.ceil(svd_dict["rank_high"] / world_size)
before calling dist.all_gather_into_tensor so each rank provides the same shape
(use torch.nn.functional.pad or torch.zeros on the same device/dtype and
concatenate), perform the all_gather_into_tensor into V_high_full sized
(target_rows * world_size, cols), then slice V_high_full[:svd_dict["rank_high"],
:] to restore the original total rank and assign back into the downstream
variable; ensure import math is present at file top if missing.

Comment on lines +1626 to +1640
def test_cache_disabled_by_default(self):
    """Cache should not be populated when OSFT_CACHE_V is at its default (off)."""

    model = self._create_simple_osft_model()
    model.train()

    x = torch.randn(4, 32)
    loss = model.linear(x).pow(2).sum()
    loss.backward()
    model.project_gradients()

    for module in model.modules():
        assert not hasattr(module, "_osft_v_high_full"), (
            "Cache should not be populated when OSFT_CACHE_V is disabled"
        )

⚠️ Potential issue | 🟡 Minor

Make this default-off test hermetic.

OSFT_CACHE_V is read at import time, so this test will flip if the suite is launched with OSFT_CACHE_V=1 in the environment. Pin the module constant to False here instead of relying on ambient process state.

Possible fix
-    def test_cache_disabled_by_default(self):
+    def test_cache_disabled_by_default(self, monkeypatch):
         """Cache should not be populated when OSFT_CACHE_V is at its default (off)."""
+        monkeypatch.setattr(osft_module, "OSFT_CACHE_V", False)

         model = self._create_simple_osft_model()
         model.train()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/test_osft.py` around lines 1626 - 1640, In
test_cache_disabled_by_default, pin the module-level OSFT_CACHE_V constant to
False so the test is hermetic: before calling self._create_simple_osft_model(),
use the pytest monkeypatch fixture to
monkeypatch.setattr(<module_that_defines_OSFT_CACHE_V>, "OSFT_CACHE_V", False)
(or equivalent setattr) so the constant is forced off for the duration of the
test; rely on monkeypatch to automatically restore the original value after the
test and keep the rest of the test logic (model.train(), project_gradients(),
and the assertions) unchanged.
