Skip to content

Commit 0e9b3fe

Browse files
committed
Replace Gram matrix V projection with factored form
Replace dV -= dV @ (V_high^T @ V_high) with the factored form dV -= (dV @ V_high^T) @ V_high. Under FSDP2 this replaces an (M, M) all-reduce with a (k_high, M) all-gather — M/k_high fewer bytes (2x for square weights, 7x for down_proj) and all-gather is cheaper per byte. Add opt-in caching of the all-gathered V_high (OSFT_CACHE_V=1, default off). V_high is frozen, so the cache is exact. Includes bench_v_proj.py benchmark and 10 new tests.
1 parent 4d6dc87 commit 0e9b3fe

File tree

3 files changed

+580
-13
lines changed

3 files changed

+580
-13
lines changed

benchmarks/bench_v_proj.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
"""Benchmark: Gram vs factored V projection under FSDP2-style sharding.
2+
3+
Measures the V projection step of project_gradient_to_orthogonal_space()
4+
with Llama-8B shapes. Compares three modes:
5+
6+
Gram: all-reduce (M, M) Gram matrix, then dV -= dV @ G
7+
Factored: all-gather (k_high, M) V_high, then dV -= (dV @ V^T) @ V
8+
Cached: reuse V_high from prior step (no comm), same matmuls
9+
10+
Usage:
11+
python bench_v_proj.py # requires 2 GPUs
12+
"""
13+
14+
import os
15+
import time
16+
17+
import torch
18+
import torch.distributed as dist
19+
import torch.multiprocessing as mp
20+
21+
22+
def bench_target(name, k_high, k_low, M, P, dev, n_iters=100):
    """Time the three V-projection variants for one OSFT target shape.

    Args:
        name: Label for the target (e.g. "q_proj").
        k_high: Global count of frozen high-rank singular vectors.
        k_low: Global count of trainable low-rank singular vectors.
        M: Column dimension of V (the weight's second dimension).
        P: Shard count (world size); k_high and k_low split P ways.
        dev: CUDA device string for this rank.
        n_iters: Timed iterations per variant.

    Returns:
        Dict with per-variant mean step times (ms), per-step comm byte
        counts, and a float32 max |Gram - factored| correctness check.
    """
    shard_k_high = k_high // P
    shard_k_low = k_low // P

    # V_high rows are orthonormal in the real code (they come from an SVD);
    # build an equivalent basis via QR in float32, then cast to bf16.
    q_mat = torch.linalg.qr(torch.randn(M, k_high, device=dev))[0]
    V_full_f32 = q_mat.T[:k_high]  # (k_high, M) orthonormal rows
    V_full = V_full_f32.to(torch.bfloat16)

    # This rank's dim-0 shard of V_high, plus a reusable gradient shard.
    row0 = rank * shard_k_high
    shard_V = V_full[row0:row0 + shard_k_high].contiguous()
    shard_dV = torch.randn(shard_k_low, M, device=dev, dtype=torch.bfloat16)
    dV_orig = shard_dV.clone()

    # Destination buffer for the all-gathered V_high.
    V_gathered = torch.empty_like(V_full)

    def gram_step():
        # dV -= dV @ (V^T V): all-reduce the (M, M) Gram matrix.
        shard_dV.copy_(dV_orig)
        gram = torch.mm(shard_V.T, shard_V)
        dist.all_reduce(gram)
        shard_dV.add_(torch.mm(shard_dV, gram), alpha=-1.0)

    def factored_step():
        # dV -= (dV @ V^T) @ V: all-gather the (k_high, M) V_high.
        shard_dV.copy_(dV_orig)
        dist.all_gather_into_tensor(V_gathered, shard_V)
        proj = torch.mm(shard_dV, V_gathered.T)
        shard_dV.addmm_(proj, V_gathered, alpha=-1.0)

    def cached_step():
        # Same matmuls as factored, reusing the already-gathered V_high.
        shard_dV.copy_(dV_orig)
        proj = torch.mm(shard_dV, V_gathered.T)
        shard_dV.addmm_(proj, V_gathered, alpha=-1.0)

    def time_variant(step):
        # Align all ranks, then time n_iters back-to-back steps; one final
        # sync so the mean includes all queued GPU work.
        torch.cuda.synchronize()
        dist.barrier()
        start = time.perf_counter()
        for _ in range(n_iters):
            step()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / n_iters * 1000

    # Warm both comm paths before timing (cached reuses factored's buffer).
    for _ in range(10):
        gram_step()
        factored_step()
    torch.cuda.synchronize()
    dist.barrier()

    gram_ms = time_variant(gram_step)
    fact_ms = time_variant(factored_step)
    cached_ms = time_variant(cached_step)

    # Correctness: verify in float32 (bf16 accumulation differs by ~5%).
    dV32 = dV_orig.float()
    via_gram = dV32 - dV32 @ (V_full_f32.T @ V_full_f32)
    via_factored = dV32 - (dV32 @ V_full_f32.T) @ V_full_f32
    max_diff = (via_factored - via_gram).abs().max().item()

    return {
        "name": name, "k_high": k_high, "M": M,
        "gram_bytes": M * M * 2, "fact_bytes": k_high * M * 2,
        "gram_ms": gram_ms, "fact_ms": fact_ms, "cached_ms": cached_ms,
        "max_diff": max_diff,
    }
103+
104+
105+
# Module-level rank, overwritten by run() in each spawned process;
# bench_target() reads it to pick this process's shard of V_high.
rank = 0
107+
108+
109+
def run(rank_, world_size):
    """Per-process entry point: init NCCL, run benchmarks, report on rank 0.

    Args:
        rank_: Rank supplied by mp.spawn; published to the module-level
            ``rank`` global that bench_target() reads for shard indexing.
        world_size: Total number of processes/GPUs.
    """
    global rank
    rank = rank_
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="29500",
                      RANK=str(rank), WORLD_SIZE=str(world_size))
    dist.init_process_group("nccl")
    torch.cuda.set_device(rank)
    # Identical seed on every rank so V_full matches across shards.
    torch.manual_seed(42)
    dev = f"cuda:{rank}"

    # Llama-8B shapes, URR=0.5: (name, k_high, k_low, M)
    targets = [
        ("down_proj", 2048, 2048, 14336),  # (4096, 14336) → ratio = 7x
        ("q_proj", 2048, 2048, 4096),      # (4096, 4096) → ratio = 2x
    ]
    results = [
        bench_target(name, k_high, k_low, M, world_size, dev)
        for name, k_high, k_low, M in targets
    ]

    if rank == 0:
        gpu = torch.cuda.get_device_name(0)
        print(f"V projection benchmark — {world_size}× {gpu}, Llama-8B shapes, bf16")
        print("=" * 70)
        print()
        for stats in results:
            ratio = stats["M"] / stats["k_high"]
            print(f" {stats['name']:10s} V_high ({stats['k_high']}, {stats['M']})"
                  f" M/k_high = {ratio:.0f}x")
            print(f" {'':10s} Gram Factored Cached")
            print(f" {'':10s} ---------- ---------- ----------")
            print(f" {'comm':10s} all-reduce all-gather none")
            print(f" {'bytes':10s} {stats['gram_bytes']/1e6:>7.0f} MB {stats['fact_bytes']/1e6:>7.0f} MB {0:>7.0f} MB")
            print(f" {'time':10s} {stats['gram_ms']:>7.2f} ms {stats['fact_ms']:>7.2f} ms {stats['cached_ms']:>7.2f} ms")
            print(f" {'speedup':10s} {'—':>10s} {stats['gram_ms']/stats['fact_ms']:>7.1f}x {stats['gram_ms']/stats['cached_ms']:>7.1f}x")
            ok = "ok" if stats["max_diff"] < 1e-4 else "FAIL"
            print(f" {'correct':10s} — f32 max diff = {stats['max_diff']:.1e} ({ok})")
            print()

        # Aggregate: 32 layers × 7 targets (4 q/k/v/o + 2 gate/up + 1 down).
        # NOTE(review): gate/up share q_proj's V shape here; k/v are not
        # benchmarked separately, so q_proj stands in for all six — confirm
        # this approximation is intended.
        square = next(r for r in results if r["name"] == "q_proj")
        down = next(r for r in results if r["name"] == "down_proj")
        gram_tot = (6 * square["gram_ms"] + down["gram_ms"]) * 32
        fact_tot = (6 * square["fact_ms"] + down["fact_ms"]) * 32
        cached_tot = (6 * square["cached_ms"] + down["cached_ms"]) * 32
        print(" Aggregate (32 layers × 7 targets = 224):")
        print(f" {'':10s} Gram Factored Cached")
        print(f" {'total':10s} {gram_tot:>7.0f} ms {fact_tot:>7.0f} ms {cached_tot:>7.0f} ms")
        print(f" {'speedup':10s} {'—':>10s} {gram_tot/fact_tot:>7.1f}x {gram_tot/cached_tot:>7.1f}x")
        print()

    dist.destroy_process_group()
163+
164+
165+
if __name__ == "__main__":
    # One process per GPU; mp.spawn supplies the rank as run()'s first arg.
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)

src/mini_trainer/osft_utils.py

Lines changed: 93 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,20 @@
3030
os.getenv("OSFT_CACHE_CLEAR_INTERVAL", 5)
3131
) # Clear GPU cache every N parameters during matrix reconstruction
3232

33+
# Opt-in: cache the all-gathered (full) V_high to avoid repeating the
34+
# all-gather on every step. V_high is frozen (requires_grad=False), so the
35+
# cache is exact.
36+
#
37+
# V projection uses the factored form: dV -= (dV @ V_high^T) @ V_high.
38+
# Under FSDP2, V_high is dim-0 sharded, so the factored form requires an
39+
# all-gather of V_high (k_high × M). Caching stores the all-gathered result.
40+
#
41+
# Default OFF because the cache is REPLICATED on every FSDP2 rank (not
42+
# sharded), adding ~5.1 GB per rank for Llama-8B (all 224 targets in bf16).
43+
# For Llama-70B+ the cache exceeds GPU memory. Set to "1" only after
44+
# confirming sufficient memory headroom.
45+
OSFT_CACHE_V = os.getenv("OSFT_CACHE_V", "0") == "1"
46+
3347

3448
def _supports_use_batch() -> bool:
3549
"""Check if torch.distributed send/recv_object_list support the use_batch parameter (PyTorch 2.9+)."""
@@ -514,11 +528,34 @@ def reconstruct_weight_matrix(
514528
return reconstructed
515529

516530

517-
def project_gradient_to_orthogonal_space(svd_dict: SVDDecompositionDict):
531+
def project_gradient_to_orthogonal_space(
532+
svd_dict: SVDDecompositionDict,
533+
cache_holder: "nn.Module | None" = None,
534+
):
518535
"""
519-
Projects the gradient of the low-rank parameters (U_low, V_low) to be orthogonal to the frozen high-rank subspace.
536+
Projects the gradient of the low-rank parameters (U_low, V_low) to be
537+
orthogonal to the frozen high-rank subspace.
538+
539+
Both projections use the factored form:
540+
dU -= U_high @ (U_high^T @ dU)
541+
dV -= (dV @ V_high^T) @ V_high
520542
521-
This step ensures that learning new tasks does not interfere with previously learned representations by enforcing an orthogonality constraint.
543+
Under FSDP2, U_high is dim-0 sharded along N (the large dimension), so
544+
U_high^T @ dU contracts over the sharded dim → partial sum → all-reduce
545+
of a small (k_high, k_low) matrix.
546+
547+
V_high is dim-0 sharded along k_high (the small dimension), so the
548+
factored form requires an all-gather of V_high to get the full
549+
(k_high, M) tensor. This is M/k_high fewer bytes than the Gram
550+
matrix all-reduce (k_high × M vs M × M) — 2x for square weights,
551+
7x for down_proj where k_high = min(N, M) × (1 - URR).
552+
553+
Args:
554+
svd_dict: Dictionary containing the SVD decomposition components.
555+
cache_holder: Optional module on which to cache the all-gathered
556+
V_high. V_high is frozen, so the cache is exact. When provided
557+
and OSFT_CACHE_V is enabled, V_high is all-gathered once and
558+
reused on subsequent steps.
522559
523560
TODO(osilkin): Add mixed-precision gradients here
524561
"""
@@ -551,21 +588,37 @@ def project_gradient_to_orthogonal_space(svd_dict: SVDDecompositionDict):
551588
else:
552589
dU.copy_(local_dU)
553590

554-
# Repeat projection for V_low using V_high
591+
# Project V_low gradients: dV -= (dV @ V_high^T) @ V_high
592+
# All-gather V_high from FSDP2 shards (or use cache) — see docstring for cost analysis.
555593
if svd_dict["V_low"].grad is not None:
556594
dV = svd_dict["V_low"].grad
557595
local_V_high = getattr(V_high, "to_local", lambda: V_high)()
558596
local_dV = getattr(dV, "to_local", lambda: dV)()
559597

560-
# Compute Gram matrix G = V_high^T @ V_high for global projection across row-sharded V_high
561-
# Assumes column dimension is consistent across ranks (row sharding over singular vectors)
562-
G_local = torch.mm(local_V_high.transpose(0, 1), local_V_high)
563-
if dist.is_initialized() and dist.get_world_size() > 1:
564-
dist.all_reduce(G_local, op=dist.ReduceOp.SUM)
598+
# V_high is frozen — reuse cached all-gathered tensor when available.
599+
can_cache = OSFT_CACHE_V and cache_holder is not None
600+
cached = getattr(cache_holder, "_osft_v_high_full", None) if can_cache else None
565601

566-
# Apply projection: dV = dV - dV @ G (use local shard of dV)
567-
update = torch.mm(local_dV, G_local)
568-
local_dV.add_(update, alpha=-1.0)
602+
if cached is not None:
603+
V_high_full = cached
604+
else:
605+
if dist.is_initialized() and dist.get_world_size() > 1:
606+
world_size = dist.get_world_size()
607+
V_high_full = torch.empty(
608+
local_V_high.shape[0] * world_size, local_V_high.shape[1],
609+
dtype=local_V_high.dtype, device=local_V_high.device,
610+
)
611+
dist.all_gather_into_tensor(V_high_full, local_V_high)
612+
else:
613+
V_high_full = local_V_high
614+
if can_cache:
615+
# .detach() ensures plain Tensor, not nn.Parameter — avoids
616+
# nn.Module.__setattr__ registering it into state_dict.
617+
cache_holder._osft_v_high_full = V_high_full.detach()
618+
619+
# Two local matmuls — no (M, M) intermediate
620+
coeff = torch.mm(local_dV, V_high_full.transpose(0, 1)) # (k_low/P, k_high)
621+
local_dV.addmm_(coeff, V_high_full, alpha=-1.0) # (k_low/P, M)
569622

570623
if hasattr(dV, "_local_tensor"):
571624
dV._local_tensor.copy_(local_dV)
@@ -966,6 +1019,10 @@ def _reset_osft_metadata(self):
9661019
self.osft_paramspec_registry = {}
9671020
self._osft_handles = {}
9681021
self.osft_params = {}
1022+
# Clear any cached all-gathered V_high — V_high changes on reinit.
1023+
for module in self.modules():
1024+
if hasattr(module, "_osft_v_high_full"):
1025+
del module._osft_v_high_full
9691026

9701027
@staticmethod
9711028
def _load_non_distributed(
@@ -1982,7 +2039,14 @@ def project_gradients(self):
19822039
with the high-rank subspace encoding prior task knowledge.
19832040
19842041
This method should be called after backpropagation and before optimizer step.
2042+
2043+
When ``OSFT_CACHE_V=1`` is set, the all-gathered V_high tensor is
2044+
cached on each module after the first step. V_high is frozen, so
2045+
the cache is exact. This eliminates per-step V all-gather traffic.
2046+
Default is off because the cache is replicated on every FSDP2 rank
2047+
(~5.1 GB for Llama-8B, infeasible for 70B+).
19852048
"""
2049+
caches_populated_this_call = 0
19862050
for module in self.modules():
19872051
# Only process real OSFT-attached linear modules, not the top-level container
19882052
if (
@@ -1991,11 +2055,27 @@ def project_gradients(self):
19912055
and hasattr(module, "osft_S_high")
19922056
and hasattr(module, "osft_V_high")
19932057
):
2058+
had_cache = hasattr(module, "_osft_v_high_full")
19942059
try:
19952060
svd_dict = self.get_svd_dict_for_module(module)
19962061
except ValueError as err:
19972062
raise ValueError(f"error in projecting gradients for module: {module}") from err
1998-
project_gradient_to_orthogonal_space(svd_dict)
2063+
project_gradient_to_orthogonal_space(svd_dict, cache_holder=module)
2064+
if not had_cache and hasattr(module, "_osft_v_high_full"):
2065+
caches_populated_this_call += 1
2066+
2067+
if caches_populated_this_call > 0:
2068+
total_bytes = sum(
2069+
module._osft_v_high_full.nelement() * module._osft_v_high_full.element_size()
2070+
for module in self.modules()
2071+
if hasattr(module, "_osft_v_high_full")
2072+
)
2073+
log_rank_0(
2074+
f"Cached {caches_populated_this_call} V_high tensors "
2075+
f"({total_bytes / 1e9:.2f} GB). "
2076+
f"Subsequent steps skip V all-gathers. "
2077+
f"Set OSFT_CACHE_V=0 to disable."
2078+
)
19992079

20002080
def prepare_state_dict_for_save(self, state_dict):
20012081
"""Reconstruct dense weights into ``state_dict`` for saving with memory optimization."""

0 commit comments

Comments
 (0)