Replace Gram matrix V projection with factored form #74
stmcgovern wants to merge 1 commit into Red-Hat-AI-Innovation-Team:main from
Conversation
Replace `dV -= dV @ (V_high^T @ V_high)` with the factored form `dV -= (dV @ V_high^T) @ V_high`. Under FSDP2 this replaces an (M, M) all-reduce with a (k_high, M) all-gather: M/k_high fewer bytes (2x for square weights, 7x for down_proj), and all-gather is cheaper per byte. Add opt-in caching of the all-gathered V_high (`OSFT_CACHE_V=1`, default off). V_high is frozen, so the cache is exact. Includes bench_v_proj.py benchmark and 10 new tests.
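The equivalence behind the swap is plain associativity, and it can be checked in a few lines. The sketch below (NumPy, illustrative only, not the PR's actual code) verifies that both forms produce the same projected gradient while the factored form's intermediate shrinks from (M, M) to (M, k_high):

```python
# Sketch (NumPy, not the PR's code): the Gram-matrix projection and the
# factored projection differ only in how the matrix product is associated,
# so the results are identical; only the intermediate tensor shrinks.
import numpy as np

rng = np.random.default_rng(0)
M, k_high = 64, 8
dV = rng.standard_normal((M, M))
# V_high: k_high orthonormal rows (frozen in OSFT), shape (k_high, M)
V_high = np.linalg.qr(rng.standard_normal((M, k_high)))[0].T

gram = dV - dV @ (V_high.T @ V_high)      # builds an (M, M) intermediate
factored = dV - (dV @ V_high.T) @ V_high  # builds an (M, k_high) intermediate

assert np.allclose(gram, factored)
```

Under FSDP2 the communicated tensor shrinks the same way, which is where the M/k_high byte saving quoted above comes from.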
📝 Walkthrough
This pull request switches the V projection in OSFT from the Gram-matrix form to the factored form, introduces optional caching of the all-gathered V_high matrices to avoid redundant all-gathers, and adds a distributed benchmark with comprehensive test coverage to validate the optimization.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Grad as Gradient<br/>Computation
    participant Proj as project_gradients
    participant Cache as Cache<br/>Holder
    participant Comm as Distributed<br/>Comm
    participant SVD as SVD<br/>Projector
    Grad->>Proj: trigger projection
    Proj->>Cache: check OSFT_CACHE_V & cache_holder
    alt Cache enabled & holder exists
        Cache-->>Cache: has _osft_v_high_full?
        alt Not cached
            Proj->>Comm: all_gather_into_tensor(V_high_full)
            Comm-->>Proj: V_high_full (k_high × M)
            Proj->>Cache: store _osft_v_high_full on module
        else Already cached
            Cache-->>Proj: retrieve _osft_v_high_full
        end
    else Cache disabled or no holder
        Proj->>Comm: all_gather_into_tensor(V_high_full)
        Comm-->>Proj: V_high_full
    end
    Proj->>SVD: project_gradient_to_orthogonal_space(cache_holder)
    SVD->>SVD: coeff = local_dV @ V_high_full.T
    SVD->>SVD: local_dV -= coeff @ V_high_full
    SVD-->>Proj: updated dV (local, no all-reduce)
    Proj->>Proj: log cache metrics
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes
🚥 Pre-merge checks: ✅ 5 passed
Actionable comments posted: 2
🧹 Nitpick comments (1)
benchmarks/bench_v_proj.py (1)
22-35: Fail fast on non-divisible benchmark shapes.
`k_high // P` and `k_low // P` silently discard remainder rows, and the slice at line 34 only covers `P * local_k_high` rows. If this helper gets reused with a non-divisible shape, it will publish timings for the wrong problem size instead of stopping early.

Possible fix
```diff
 def bench_target(name, k_high, k_low, M, P, dev, n_iters=100):
     """Benchmark one OSFT target shape. Returns dict of timings."""
+    if k_high % P != 0 or k_low % P != 0:
+        raise ValueError(
+            f"k_high={k_high} and k_low={k_low} must both be divisible by P={P}"
+        )
     local_k_high = k_high // P
     local_k_low = k_low // P
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@benchmarks/bench_v_proj.py` around lines 22 - 35, The function bench_target silently truncates k_high and k_low by integer division into local_k_high/local_k_low and then slices V_full by rank * local_k_high; add explicit validation at the start of bench_target to check that k_high % P == 0 and k_low % P == 0 (or otherwise that both divide evenly by P), and raise a clear ValueError if not so that the benchmark fails fast; update the error message to reference the offending values (k_high, k_low, P) so callers can correct input shapes and avoid publishing wrong timings caused by the partial-row truncation in the local_V slice.
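The silent truncation described above is easy to see with toy numbers (hypothetical shapes, not from the benchmark itself):

```python
# Illustrative only: with a non-divisible shape, integer division drops
# remainder rows, so the benchmark would time a smaller problem than the
# k_high it reports.
k_high, P = 10, 4            # hypothetical non-divisible shape
local_k_high = k_high // P   # 2 rows per rank after floor division
covered_rows = P * local_k_high

assert covered_rows == 8     # only 8 of the 10 rows are covered
assert covered_rows < k_high # 2 rows silently dropped
```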
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/mini_trainer/osft_utils.py`:
- Around line 605-611: The current all_gather_into_tensor assumes equal per-rank
rows for local_V_high; pad local_V_high to target_rows =
math.ceil(svd_dict["rank_high"] / world_size) before calling
dist.all_gather_into_tensor so each rank provides the same shape (use
torch.nn.functional.pad or torch.zeros on the same device/dtype and
concatenate), perform the all_gather_into_tensor into V_high_full sized
(target_rows * world_size, cols), then slice V_high_full[:svd_dict["rank_high"],
:] to restore the original total rank and assign back into the downstream
variable; ensure import math is present at file top if missing.
In `@tests/test_osft.py`:
- Around line 1626-1640: In test_cache_disabled_by_default, pin the module-level
OSFT_CACHE_V constant to False so the test is hermetic: before calling
self._create_simple_osft_model(), use the pytest monkeypatch fixture to
monkeypatch.setattr(<module_that_defines_OSFT_CACHE_V>, "OSFT_CACHE_V", False)
(or equivalent setattr) so the constant is forced off for the duration of the
test; rely on monkeypatch to automatically restore the original value after the
test and keep the rest of the test logic (model.train(), project_gradients(),
and the assertions) unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 20a6d6ef-db54-486a-a9a4-6bc05401b8a6
📒 Files selected for processing (3)
- benchmarks/bench_v_proj.py
- src/mini_trainer/osft_utils.py
- tests/test_osft.py
```python
if dist.is_initialized() and dist.get_world_size() > 1:
    world_size = dist.get_world_size()
    V_high_full = torch.empty(
        local_V_high.shape[0] * world_size, local_V_high.shape[1],
        dtype=local_V_high.dtype, device=local_V_high.device,
    )
    dist.all_gather_into_tensor(V_high_full, local_V_high)
```
🧩 Analysis chain (collapsed: shell scripts run against Red-Hat-AI-Innovation-Team/mini_trainer to verify how `local_V_high` is sharded, how `rank_high` is derived from `rank_ratio`, and that `all_gather_into_tensor` requires equal input shapes on all ranks)
Add padding to `local_V_high` before `all_gather_into_tensor()` to handle uneven rank sharding.
The allocation at line 608 assumes every rank has `local_V_high.shape[0] == ceil(rank_high / world_size)`, but this is only guaranteed when `rank_high % world_size == 0`. Since `rank_high` is computed from `rank_ratio` using `int(np.floor(...))`, non-divisible cases are common. With uneven sharding, `all_gather_into_tensor()` will fail or rebuild a truncated tensor.
Pad each local shard to `ceil(rank_high / world_size)` rows and slice the result back to `svd_dict["rank_high"]` after gathering.
Suggested fix

```diff
+full_k_high = svd_dict["rank_high"]
 world_size = dist.get_world_size()
-V_high_full = torch.empty(
-    local_V_high.shape[0] * world_size, local_V_high.shape[1],
-    dtype=local_V_high.dtype, device=local_V_high.device,
-)
-dist.all_gather_into_tensor(V_high_full, local_V_high)
+rows_per_rank = math.ceil(full_k_high / world_size)
+padded_local_V_high = torch.zeros(
+    rows_per_rank, local_V_high.shape[1],
+    dtype=local_V_high.dtype, device=local_V_high.device,
+)
+padded_local_V_high[: local_V_high.shape[0]].copy_(local_V_high)
+gathered = torch.empty(
+    rows_per_rank * world_size, local_V_high.shape[1],
+    dtype=local_V_high.dtype, device=local_V_high.device,
+)
+dist.all_gather_into_tensor(gathered, padded_local_V_high)
+V_high_full = gathered[:full_k_high]
```

Add `import math` at the top of the file if not already present.
```python
def test_cache_disabled_by_default(self):
    """Cache should not be populated when OSFT_CACHE_V is at its default (off)."""
    model = self._create_simple_osft_model()
    model.train()

    x = torch.randn(4, 32)
    loss = model.linear(x).pow(2).sum()
    loss.backward()
    model.project_gradients()

    for module in model.modules():
        assert not hasattr(module, "_osft_v_high_full"), (
            "Cache should not be populated when OSFT_CACHE_V is disabled"
        )
```
Make this default-off test hermetic.
`OSFT_CACHE_V` is read at import time, so this test will flip if the suite is launched with `OSFT_CACHE_V=1` in the environment. Pin the module constant to False here instead of relying on ambient process state.
Possible fix

```diff
-    def test_cache_disabled_by_default(self):
+    def test_cache_disabled_by_default(self, monkeypatch):
         """Cache should not be populated when OSFT_CACHE_V is at its default (off)."""
+        monkeypatch.setattr(osft_module, "OSFT_CACHE_V", False)
         model = self._create_simple_osft_model()
         model.train()
```
Fixes #73