diff --git a/benchmarks/bench_v_proj.py b/benchmarks/bench_v_proj.py new file mode 100644 index 0000000..1526826 --- /dev/null +++ b/benchmarks/bench_v_proj.py @@ -0,0 +1,166 @@ +"""Benchmark: Gram vs factored V projection under FSDP2-style sharding. + +Measures the V projection step of project_gradient_to_orthogonal_space() +with Llama-8B shapes. Compares three modes: + + Gram: all-reduce (M, M) Gram matrix, then dV -= dV @ G + Factored: all-gather (k_high, M) V_high, then dV -= (dV @ V^T) @ V + Cached: reuse V_high from prior step (no comm), same matmuls + +Usage: + python bench_v_proj.py # requires 2 GPUs +""" + +import os +import time + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + + +def bench_target(name, k_high, k_low, M, P, dev, n_iters=100): + """Benchmark one OSFT target shape. Returns dict of timings.""" + local_k_high = k_high // P + local_k_low = k_low // P + + # V_high has orthonormal rows (from SVD) — create via QR in float32 + V_full_f32 = torch.linalg.qr( + torch.randn(M, k_high, device=dev) + )[0].T[:k_high] # (k_high, M) orthonormal rows + V_full = V_full_f32.to(torch.bfloat16) + + # Each rank's shard of V_high + local_V = V_full[rank * local_k_high:(rank + 1) * local_k_high].contiguous() + local_dV = torch.randn(local_k_low, M, device=dev, dtype=torch.bfloat16) + dV_save = local_dV.clone() + + # Pre-allocate for all-gather + V_ag = torch.empty_like(V_full) + + # Warmup + for _ in range(10): + local_dV.copy_(dV_save) + G = torch.mm(local_V.T, local_V) + dist.all_reduce(G) + local_dV.add_(torch.mm(local_dV, G), alpha=-1.0) + + local_dV.copy_(dV_save) + dist.all_gather_into_tensor(V_ag, local_V) + coeff = torch.mm(local_dV, V_ag.T) + local_dV.addmm_(coeff, V_ag, alpha=-1.0) + torch.cuda.synchronize() + dist.barrier() + + # --- Gram form: all-reduce (M, M), then dV -= dV @ G --- + torch.cuda.synchronize() + dist.barrier() + t0 = time.perf_counter() + for _ in range(n_iters): + local_dV.copy_(dV_save) + G = torch.mm(local_V.T, local_V) + dist.all_reduce(G) + local_dV.add_(torch.mm(local_dV, G), alpha=-1.0) + torch.cuda.synchronize() + gram_ms = (time.perf_counter() - t0) / n_iters * 1000 + + # --- Factored form: all-gather (k_high, M), then two matmuls --- + torch.cuda.synchronize() + dist.barrier() + t0 = time.perf_counter() + for _ in range(n_iters): + local_dV.copy_(dV_save) + dist.all_gather_into_tensor(V_ag, local_V) + coeff = torch.mm(local_dV, V_ag.T) + local_dV.addmm_(coeff, V_ag, alpha=-1.0) + torch.cuda.synchronize() + fact_ms = (time.perf_counter() - t0) / n_iters * 1000 + + # --- Cached: no comm, just matmuls --- + torch.cuda.synchronize() + dist.barrier() + t0 = time.perf_counter() + for _ in range(n_iters): + local_dV.copy_(dV_save) + coeff = torch.mm(local_dV, V_ag.T) + local_dV.addmm_(coeff, V_ag, alpha=-1.0) + torch.cuda.synchronize() + cached_ms = (time.perf_counter() - t0) / n_iters * 1000 + + # Correctness: verify in float32 (bf16 accumulation differs by ~5%) + dV_f32 = dV_save.float() + V_f32 = V_full_f32 + gram_f32 = dV_f32 - dV_f32 @ (V_f32.T @ V_f32) + fact_f32 = dV_f32 - (dV_f32 @ V_f32.T) @ V_f32 + max_diff = (fact_f32 - gram_f32).abs().max().item() + + return { + "name": name, "k_high": k_high, "M": M, + "gram_bytes": M * M * 2, "fact_bytes": k_high * M * 2, + "gram_ms": gram_ms, "fact_ms": fact_ms, "cached_ms": cached_ms, + "max_diff": max_diff, + } + + +# Global for shard indexing +rank = 0 + + +def run(rank_, world_size): + global rank + rank = rank_ + os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="29500", + RANK=str(rank), WORLD_SIZE=str(world_size)) + dist.init_process_group("nccl") + torch.cuda.set_device(rank) + torch.manual_seed(42) + dev = f"cuda:{rank}" + + # Llama-8B shapes, URR=0.5 + targets = [ + # name, k_high, k_low, M + ("down_proj", 2048, 2048, 14336), # (4096, 14336) → ratio = 7x + ("q_proj", 2048, 2048, 4096), # (4096, 4096) → ratio = 2x + ] + + results = [] + for name, k_high, k_low, M in targets: + r = bench_target(name, k_high, k_low, M, world_size, dev) + results.append(r) + + if rank == 0: + gpu = torch.cuda.get_device_name(0) + print(f"V projection benchmark — {world_size}× {gpu}, Llama-8B shapes, bf16") + print("=" * 70) + print() + for r in results: + ratio = r["M"] / r["k_high"] + print(f" {r['name']:10s} V_high ({r['k_high']}, {r['M']})" + f" M/k_high = {ratio:.0f}x") + print(f" {'':10s} Gram Factored Cached") + print(f" {'':10s} ---------- ---------- ----------") + print(f" {'comm':10s} all-reduce all-gather none") + print(f" {'bytes':10s} {r['gram_bytes']/1e6:>7.0f} MB {r['fact_bytes']/1e6:>7.0f} MB {0:>7.0f} MB") + print(f" {'time':10s} {r['gram_ms']:>7.2f} ms {r['fact_ms']:>7.2f} ms {r['cached_ms']:>7.2f} ms") + print(f" {'speedup':10s} {'—':>10s} {r['gram_ms']/r['fact_ms']:>7.1f}x {r['gram_ms']/r['cached_ms']:>7.1f}x") + ok = "ok" if r["max_diff"] < 1e-4 else "FAIL" + print(f" {'correct':10s} — f32 max diff = {r['max_diff']:.1e} ({ok})") + print() + + # Aggregate: 32 layers × 7 targets (4 q/k/v/o + 2 gate/up + 1 down) + sq = next(r for r in results if r["name"] == "q_proj") + dp = next(r for r in results if r["name"] == "down_proj") + gram_tot = (6 * sq["gram_ms"] + dp["gram_ms"]) * 32 + fact_tot = (6 * sq["fact_ms"] + dp["fact_ms"]) * 32 + cached_tot = (6 * sq["cached_ms"] + dp["cached_ms"]) * 32 + print(f" Aggregate (32 layers × 7 targets = 224):") + print(f" {'':10s} Gram Factored Cached") + print(f" {'total':10s} {gram_tot:>7.0f} ms {fact_tot:>7.0f} ms {cached_tot:>7.0f} ms") + print(f" {'speedup':10s} {'—':>10s} {gram_tot/fact_tot:>7.1f}x {gram_tot/cached_tot:>7.1f}x") + print() + + dist.destroy_process_group() + + +if __name__ == "__main__": + mp.spawn(run, args=(2,), nprocs=2, join=True) diff --git a/src/mini_trainer/osft_utils.py b/src/mini_trainer/osft_utils.py index 63d0837..4b7e996 100644 --- a/src/mini_trainer/osft_utils.py +++ b/src/mini_trainer/osft_utils.py @@ -30,6 +30,20 @@ os.getenv("OSFT_CACHE_CLEAR_INTERVAL", 5) ) # Clear GPU cache every N parameters during matrix reconstruction +# Opt-in: cache the all-gathered (full) V_high to avoid repeating the +# all-gather on every step. V_high is frozen (requires_grad=False), so the +# cache is exact. +# +# V projection uses the factored form: dV -= (dV @ V_high^T) @ V_high. +# Under FSDP2, V_high is dim-0 sharded, so the factored form requires an +# all-gather of V_high (k_high × M). Caching stores the all-gathered result. +# +# Default OFF because the cache is REPLICATED on every FSDP2 rank (not +# sharded), adding ~5.1 GB per rank for Llama-8B (all 224 targets in bf16). +# For Llama-70B+ the cache exceeds GPU memory. Set to "1" only after +# confirming sufficient memory headroom. +OSFT_CACHE_V = os.getenv("OSFT_CACHE_V", "0") == "1" + def _supports_use_batch() -> bool: """Check if torch.distributed send/recv_object_list support the use_batch parameter (PyTorch 2.9+).""" @@ -514,11 +528,34 @@ def reconstruct_weight_matrix( return reconstructed -def project_gradient_to_orthogonal_space(svd_dict: SVDDecompositionDict): +def project_gradient_to_orthogonal_space( + svd_dict: SVDDecompositionDict, + cache_holder: "nn.Module | None" = None, +): """ - Projects the gradient of the low-rank parameters (U_low, V_low) to be orthogonal to the frozen high-rank subspace. + Projects the gradient of the low-rank parameters (U_low, V_low) to be + orthogonal to the frozen high-rank subspace. + + Both projections use the factored form: + dU -= U_high @ (U_high^T @ dU) + dV -= (dV @ V_high^T) @ V_high - This step ensures that learning new tasks does not interfere with previously learned representations by enforcing an orthogonality constraint. + Under FSDP2, U_high is dim-0 sharded along N (the large dimension), so + U_high^T @ dU contracts over the sharded dim → partial sum → all-reduce + of a small (k_high, k_low) matrix. + + V_high is dim-0 sharded along k_high (the small dimension), so the + factored form requires an all-gather of V_high to get the full + (k_high, M) tensor. This is M/k_high fewer bytes than the Gram + matrix all-reduce (k_high × M vs M × M) — 2x for square weights, + 7x for down_proj where k_high = min(N, M) × (1 - URR). + + Args: + svd_dict: Dictionary containing the SVD decomposition components. + cache_holder: Optional module on which to cache the all-gathered + V_high. V_high is frozen, so the cache is exact. When provided + and OSFT_CACHE_V is enabled, V_high is all-gathered once and + reused on subsequent steps. TODO(osilkin): Add mixed-precision gradients here """ @@ -551,21 +588,37 @@ def project_gradient_to_orthogonal_space(svd_dict: SVDDecompositionDict): else: dU.copy_(local_dU) - # Repeat projection for V_low using V_high + # Project V_low gradients: dV -= (dV @ V_high^T) @ V_high + # All-gather V_high from FSDP2 shards (or use cache) — see docstring for cost analysis. if svd_dict["V_low"].grad is not None: dV = svd_dict["V_low"].grad local_V_high = getattr(V_high, "to_local", lambda: V_high)() local_dV = getattr(dV, "to_local", lambda: dV)() - # Compute Gram matrix G = V_high^T @ V_high for global projection across row-sharded V_high - # Assumes column dimension is consistent across ranks (row sharding over singular vectors) - G_local = torch.mm(local_V_high.transpose(0, 1), local_V_high) - if dist.is_initialized() and dist.get_world_size() > 1: - dist.all_reduce(G_local, op=dist.ReduceOp.SUM) + # V_high is frozen — reuse cached all-gathered tensor when available. + can_cache = OSFT_CACHE_V and cache_holder is not None + cached = getattr(cache_holder, "_osft_v_high_full", None) if can_cache else None - # Apply projection: dV = dV - dV @ G (use local shard of dV) - update = torch.mm(local_dV, G_local) - local_dV.add_(update, alpha=-1.0) + if cached is not None: + V_high_full = cached + else: + if dist.is_initialized() and dist.get_world_size() > 1: + world_size = dist.get_world_size() + V_high_full = torch.empty( + local_V_high.shape[0] * world_size, local_V_high.shape[1], + dtype=local_V_high.dtype, device=local_V_high.device, + ) + dist.all_gather_into_tensor(V_high_full, local_V_high) + else: + V_high_full = local_V_high + if can_cache: + # .detach() ensures plain Tensor, not nn.Parameter — avoids + # nn.Module.__setattr__ registering it into state_dict. + cache_holder._osft_v_high_full = V_high_full.detach() + + # Two local matmuls — no (M, M) intermediate + coeff = torch.mm(local_dV, V_high_full.transpose(0, 1)) # (k_low/P, k_high) + local_dV.addmm_(coeff, V_high_full, alpha=-1.0) # (k_low/P, M) if hasattr(dV, "_local_tensor"): dV._local_tensor.copy_(local_dV) @@ -966,6 +1019,10 @@ def _reset_osft_metadata(self): self.osft_paramspec_registry = {} self._osft_handles = {} self.osft_params = {} + # Clear any cached all-gathered V_high — V_high changes on reinit. + for module in self.modules(): + if hasattr(module, "_osft_v_high_full"): + del module._osft_v_high_full @staticmethod def _load_non_distributed( @@ -1982,7 +2039,14 @@ def project_gradients(self): with the high-rank subspace encoding prior task knowledge. This method should be called after backpropagation and before optimizer step. + + When ``OSFT_CACHE_V=1`` is set, the all-gathered V_high tensor is + cached on each module after the first step. V_high is frozen, so + the cache is exact. This eliminates per-step V all-gather traffic. + Default is off because the cache is replicated on every FSDP2 rank + (~5.1 GB for Llama-8B, infeasible for 70B+). """ + caches_populated_this_call = 0 for module in self.modules(): # Only process real OSFT-attached linear modules, not the top-level container if ( @@ -1991,11 +2055,27 @@ def project_gradients(self): and hasattr(module, "osft_S_high") and hasattr(module, "osft_V_high") ): + had_cache = hasattr(module, "_osft_v_high_full") try: svd_dict = self.get_svd_dict_for_module(module) except ValueError as err: raise ValueError(f"error in projecting gradients for module: {module}") from err - project_gradient_to_orthogonal_space(svd_dict) + project_gradient_to_orthogonal_space(svd_dict, cache_holder=module) + if not had_cache and hasattr(module, "_osft_v_high_full"): + caches_populated_this_call += 1 + + if caches_populated_this_call > 0: + total_bytes = sum( + module._osft_v_high_full.nelement() * module._osft_v_high_full.element_size() + for module in self.modules() + if hasattr(module, "_osft_v_high_full") + ) + log_rank_0( + f"Cached {caches_populated_this_call} V_high tensors " + f"({total_bytes / 1e9:.2f} GB). " + f"Subsequent steps skip V all-gathers. " + f"Set OSFT_CACHE_V=0 to disable." + ) def prepare_state_dict_for_save(self, state_dict): """Reconstruct dense weights into ``state_dict`` for saving with memory optimization.""" diff --git a/tests/test_osft.py b/tests/test_osft.py index 300cbf9..c3725f8 100644 --- a/tests/test_osft.py +++ b/tests/test_osft.py @@ -1398,6 +1398,327 @@ def test_orthogonality_tracker(self): assert "param2" in summary +class TestVProjectionCache: + """Test V_high caching in factored V projection.""" + + def _create_simple_osft_model(self, hidden_size=32, rank_ratio=0.5): + """Create a simple model with OSFT for testing.""" + + class SimpleModel(nn.Module): + def __init__(self, config, **kwargs): + super().__init__() + self.config = config + self.linear = nn.Linear(hidden_size, hidden_size, bias=False) + self.dtype = torch.float32 + nn.init.normal_(self.linear.weight, mean=0.0, std=0.02) + + OSFTModelClass = create_osft_model_class(SimpleModel) + config = MagicMock() + config.vocab_size = 1000 + osft_config = {"linear.weight": int(hidden_size * rank_ratio)} + + model = OSFTModelClass( + config=config, + osft_config={}, + initialize_osft=False, + upcast_dtype=torch.float32, + output_dtype=torch.float32, + ) + model.osft_config = osft_config + model.osft_unfreeze_rank_ratio = rank_ratio + model.reinitialize_osft(decompose_existing_weights=True) + return model + + def test_cache_populated_after_first_projection(self, monkeypatch): + """V_high cache should be set on the module after the first call.""" + monkeypatch.setattr(osft_module, "OSFT_CACHE_V", True) + model = self._create_simple_osft_model() + model.train() + + # Before projection: no cache + for module in model.modules(): + if hasattr(module, "osft_params"): + assert not hasattr(module, "_osft_v_high_full") + + # Forward + backward + x = torch.randn(4, 32) + loss = model.linear(x).pow(2).sum() + loss.backward() + + model.project_gradients() + + # After projection: cache should exist + cached_count = 0 + for module in model.modules(): + if hasattr(module, "_osft_v_high_full"): + cached_count += 1 + V = module._osft_v_high_full + assert V.ndim == 2 + # Cached V_high should match V_high shape (k_high, M) + assert V.shape == module.osft_V_high.shape + assert cached_count > 0 + + def test_cached_v_high_matches_original(self, monkeypatch): + """Cached V_high should be identical to the original V_high (single GPU).""" + monkeypatch.setattr(osft_module, "OSFT_CACHE_V", True) + model = self._create_simple_osft_model() + model.train() + + x = torch.randn(4, 32) + loss = model.linear(x).pow(2).sum() + loss.backward() + model.project_gradients() + + for module in model.modules(): + if hasattr(module, "_osft_v_high_full") and hasattr(module, "osft_V_high"): + cached = module._osft_v_high_full + original = module.osft_V_high + assert torch.equal(cached, original), "Cached V_high differs from original" + + def test_cache_smaller_than_gram(self, monkeypatch): + """Cached V_high (k_high, M) should be smaller than the Gram matrix (M, M).""" + monkeypatch.setattr(osft_module, "OSFT_CACHE_V", True) + model = self._create_simple_osft_model(hidden_size=32, rank_ratio=0.5) + model.train() + + x = torch.randn(4, 32) + loss = model.linear(x).pow(2).sum() + loss.backward() + model.project_gradients() + + for module in model.modules(): + if hasattr(module, "_osft_v_high_full") and hasattr(module, "osft_V_high"): + V = module._osft_v_high_full + M = V.shape[1] + cache_elements = V.nelement() # k_high * M + gram_elements = M * M # M * M + assert cache_elements < gram_elements, ( + f"Cache ({cache_elements}) should be smaller than Gram ({gram_elements})" + ) + + def test_factored_matches_gram_projection(self): + """Factored V projection should produce the same result as Gram-based projection.""" + torch.manual_seed(42) + model = self._create_simple_osft_model(hidden_size=32, rank_ratio=0.5) + model.train() + + # Get V_high for manual Gram-based projection + V_high = None + for module in model.modules(): + if hasattr(module, "osft_V_high"): + V_high = module.osft_V_high.clone() + break + assert V_high is not None + + # Forward + backward + x = torch.randn(4, 32) + loss = model.linear(x).pow(2).sum() + loss.backward() + + # Save raw gradient before projection + for module in model.modules(): + if hasattr(module, "osft_V_high") and hasattr(module, "osft_params"): + raw_dV = module.osft_params.V_low.grad.clone() + break + + # Apply factored projection (the actual code path) + model.project_gradients() + + for module in model.modules(): + if hasattr(module, "osft_V_high") and hasattr(module, "osft_params"): + factored_result = module.osft_params.V_low.grad.clone() + break + + # Compute Gram-based projection manually for comparison + G = V_high.T @ V_high + gram_result = raw_dV - raw_dV @ G + + assert torch.allclose(factored_result, gram_result, atol=1e-6), ( + f"Factored and Gram projections differ: max diff = " + f"{(factored_result - gram_result).abs().max().item()}" + ) + + def test_factored_matches_gram_rectangular(self): + """Factored and Gram agree for rectangular weights (the down_proj-like 7x case). + + For down_proj, N < M so k_high = min(N,M)/2 is small relative to M. + This exercises the high-ratio regime where the factored form saves 7x. + """ + torch.manual_seed(42) + # Mimic down_proj shape ratio: N=16, M=56 → k_high=8, M=56 → ratio=7x + N, M = 16, 56 + k_high = N // 2 # 8 + k_low = N - k_high # 8 + + # Create orthonormal V_high (k_high, M) via QR + V_high = torch.linalg.qr(torch.randn(M, k_high))[0].T # (k_high, M) + assert V_high.shape == (k_high, M) + + dV = torch.randn(k_low, M) + + # Gram form: dV - dV @ (V_high^T @ V_high) + G = V_high.T @ V_high # (M, M) = (56, 56) + gram_result = dV - dV @ G + + # Factored form: dV - (dV @ V_high^T) @ V_high + coeff = dV @ V_high.T # (k_low, k_high) = (8, 8) + factored_result = dV - coeff @ V_high + + # Verify sizes confirm the 7x ratio + assert G.nelement() == M * M # 3136 + assert V_high.nelement() == k_high * M # 448 + assert G.nelement() / V_high.nelement() == M / k_high # 7.0 + + assert torch.allclose(factored_result, gram_result, atol=1e-6), ( + f"Factored and Gram differ for rectangular case: max diff = " + f"{(factored_result - gram_result).abs().max().item()}" + ) + + def test_projection_identical_with_and_without_cache(self, monkeypatch): + """Projection results should be identical whether or not the cache is used.""" + monkeypatch.setattr(osft_module, "OSFT_CACHE_V", True) + torch.manual_seed(42) + model = self._create_simple_osft_model() + model.train() + + # Step 1: populates cache + x = torch.randn(4, 32) + loss = model.linear(x).pow(2).sum() + loss.backward() + model.project_gradients() + + osft_params = [p for n, p in model.named_parameters() if "osft_params" in n] + optimizer = torch.optim.SGD(osft_params, lr=1e-3) + optimizer.step() + optimizer.zero_grad() + + # Step 2: uses cache — save projected gradient + x2 = torch.randn(4, 32) + loss2 = model.linear(x2).pow(2).sum() + loss2.backward() + model.project_gradients() + + grad_with_cache = {} + for module in model.modules(): + if hasattr(module, "osft_V_high"): + grad_with_cache["V_low"] = module.osft_params.V_low.grad.clone() + grad_with_cache["U_low"] = module.osft_params.U_low.grad.clone() + + # Now clear cache and redo the same step + optimizer.zero_grad() + for module in model.modules(): + if hasattr(module, "_osft_v_high_full"): + del module._osft_v_high_full + + loss2b = model.linear(x2).pow(2).sum() + loss2b.backward() + model.project_gradients() + + for module in model.modules(): + if hasattr(module, "osft_V_high"): + assert torch.equal(module.osft_params.V_low.grad, grad_with_cache["V_low"]), ( + "V_low gradient differs with vs without cache" + ) + assert torch.equal(module.osft_params.U_low.grad, grad_with_cache["U_low"]), ( + "U_low gradient differs with vs without cache" + ) + + def test_cache_disabled_by_default(self): + """Cache should not be populated when OSFT_CACHE_V is at its default (off).""" + + model = self._create_simple_osft_model() + model.train() + + x = torch.randn(4, 32) + loss = model.linear(x).pow(2).sum() + loss.backward() + model.project_gradients() + + for module in model.modules(): + assert not hasattr(module, "_osft_v_high_full"), ( + "Cache should not be populated when OSFT_CACHE_V is disabled" + ) + + def test_orthogonality_maintained_with_cache(self, monkeypatch): + """Orthogonality must hold across multiple steps with caching active.""" + monkeypatch.setattr(osft_module, "OSFT_CACHE_V", True) + model = self._create_simple_osft_model() + model.train() + tracker = OrthogonalityTracker(margin_deg=1.0) + + osft_params = [p for n, p in model.named_parameters() if "osft_params" in n] + optimizer = torch.optim.AdamW(osft_params, lr=1e-4) + optim_wrapper(optimizer, model) + + for step in range(1, 11): + x = torch.randn(4, 32) + loss = model.linear(x).pow(2).sum() + loss.backward() + + # project_gradients is called inside optim_wrapper's step() + optimizer.step() + + for module in model.modules(): + if ( + hasattr(module, "osft_params") + and hasattr(module, "osft_U_high") + and hasattr(module, "osft_S_high") + and hasattr(module, "osft_V_high") + ): + check_parameter_orthogonality(model, module, step, tracker) + + optimizer.zero_grad() + + assert tracker.is_successful(), f"Orthogonality violated with V_high caching:\n{tracker.get_summary()}" + + def test_cache_cleared_by_reset_osft_metadata(self, monkeypatch): + """_reset_osft_metadata must clear cached V_high tensors. + + reinitialize_osft calls _reset_osft_metadata, which creates new V_high + tensors. Any cached all-gathered V_high from the old decomposition + would be stale. This test verifies the cache is cleared. + """ + monkeypatch.setattr(osft_module, "OSFT_CACHE_V", True) + model = self._create_simple_osft_model() + model.train() + + # Populate cache + x = torch.randn(4, 32) + loss = model.linear(x).pow(2).sum() + loss.backward() + model.project_gradients() + + # Verify cache exists + cached_modules = [m for m in model.modules() if hasattr(m, "_osft_v_high_full")] + assert len(cached_modules) > 0, "Cache should be populated before reset" + + # _reset_osft_metadata is the mechanism reinitialize_osft uses + model._reset_osft_metadata() + + # Cache must be gone + for module in model.modules(): + assert not hasattr(module, "_osft_v_high_full"), "Cache was not cleared by _reset_osft_metadata" + + def test_cache_not_in_state_dict(self, monkeypatch): + """Cached V_high tensors must not appear in model state_dict.""" + monkeypatch.setattr(osft_module, "OSFT_CACHE_V", True) + model = self._create_simple_osft_model() + model.train() + + x = torch.randn(4, 32) + loss = model.linear(x).pow(2).sum() + loss.backward() + model.project_gradients() + + # Verify cache exists + assert any(hasattr(m, "_osft_v_high_full") for m in model.modules()) + + # Verify it's not in state_dict + sd = model.state_dict() + for key in sd: + assert "_osft_v_high_full" not in key, f"Cache leaked into state_dict: {key}" + + class TestLazyInitTokenizerAlignment: """Ensure memory-efficient loading aligns tokenizers before broadcasting."""