3030 os .getenv ("OSFT_CACHE_CLEAR_INTERVAL" , 5 )
3131) # Clear GPU cache every N parameters during matrix reconstruction
3232
# Optional optimization (disabled by default): keep the fully all-gathered
# V_high around instead of re-gathering it on every training step. Because
# V_high never receives gradients (requires_grad=False), the cached copy
# can never go stale, so the cache is exact.
#
# Background: the V-side projection is computed in factored form,
# dV -= (dV @ V_high^T) @ V_high. With FSDP2 sharding V_high along dim 0,
# that expression needs the full (k_high × M) V_high — which is exactly
# what gets cached after the first all-gather.
#
# The cache is a full replica on every FSDP2 rank rather than a shard:
# roughly 5.1 GB per rank on Llama-8B (all 224 targets in bf16), and for
# Llama-70B+ it no longer fits in GPU memory at all — hence the opt-in
# default of "0". Enable with "1" only after verifying memory headroom.
OSFT_CACHE_V = os.environ.get("OSFT_CACHE_V", "0") == "1"
46+
3347
3448def _supports_use_batch () -> bool :
3549 """Check if torch.distributed send/recv_object_list support the use_batch parameter (PyTorch 2.9+)."""
@@ -514,11 +528,34 @@ def reconstruct_weight_matrix(
514528 return reconstructed
515529
516530
517- def project_gradient_to_orthogonal_space (svd_dict : SVDDecompositionDict ):
531+ def project_gradient_to_orthogonal_space (
532+ svd_dict : SVDDecompositionDict ,
533+ cache_holder : "nn.Module | None" = None ,
534+ ):
518535 """
519- Projects the gradient of the low-rank parameters (U_low, V_low) to be orthogonal to the frozen high-rank subspace.
536+ Projects the gradient of the low-rank parameters (U_low, V_low) to be
537+ orthogonal to the frozen high-rank subspace.
538+
539+ Both projections use the factored form:
540+ dU -= U_high @ (U_high^T @ dU)
541+ dV -= (dV @ V_high^T) @ V_high
520542
521- This step ensures that learning new tasks does not interfere with previously learned representations by enforcing an orthogonality constraint.
543+ Under FSDP2, U_high is dim-0 sharded along N (the large dimension), so
544+ U_high^T @ dU contracts over the sharded dim → partial sum → all-reduce
545+ of a small (k_high, k_low) matrix.
546+
547+ V_high is dim-0 sharded along k_high (the small dimension), so the
548+ factored form requires an all-gather of V_high to get the full
549+ (k_high, M) tensor. This is M/k_high fewer bytes than the Gram
550+ matrix all-reduce (k_high × M vs M × M) — 2x for square weights,
551+ 7x for down_proj where k_high = min(N, M) × (1 - URR).
552+
553+ Args:
554+ svd_dict: Dictionary containing the SVD decomposition components.
555+ cache_holder: Optional module on which to cache the all-gathered
556+ V_high. V_high is frozen, so the cache is exact. When provided
557+ and OSFT_CACHE_V is enabled, V_high is all-gathered once and
558+ reused on subsequent steps.
522559
523560 TODO(osilkin): Add mixed-precision gradients here
524561 """
@@ -551,21 +588,37 @@ def project_gradient_to_orthogonal_space(svd_dict: SVDDecompositionDict):
551588 else :
552589 dU .copy_ (local_dU )
553590
554- # Repeat projection for V_low using V_high
591+ # Project V_low gradients: dV -= (dV @ V_high^T) @ V_high
592+ # All-gather V_high from FSDP2 shards (or use cache) — see docstring for cost analysis.
555593 if svd_dict ["V_low" ].grad is not None :
556594 dV = svd_dict ["V_low" ].grad
557595 local_V_high = getattr (V_high , "to_local" , lambda : V_high )()
558596 local_dV = getattr (dV , "to_local" , lambda : dV )()
559597
560- # Compute Gram matrix G = V_high^T @ V_high for global projection across row-sharded V_high
561- # Assumes column dimension is consistent across ranks (row sharding over singular vectors)
562- G_local = torch .mm (local_V_high .transpose (0 , 1 ), local_V_high )
563- if dist .is_initialized () and dist .get_world_size () > 1 :
564- dist .all_reduce (G_local , op = dist .ReduceOp .SUM )
598+ # V_high is frozen — reuse cached all-gathered tensor when available.
599+ can_cache = OSFT_CACHE_V and cache_holder is not None
600+ cached = getattr (cache_holder , "_osft_v_high_full" , None ) if can_cache else None
565601
566- # Apply projection: dV = dV - dV @ G (use local shard of dV)
567- update = torch .mm (local_dV , G_local )
568- local_dV .add_ (update , alpha = - 1.0 )
602+ if cached is not None :
603+ V_high_full = cached
604+ else :
605+ if dist .is_initialized () and dist .get_world_size () > 1 :
606+ world_size = dist .get_world_size ()
607+ V_high_full = torch .empty (
608+ local_V_high .shape [0 ] * world_size , local_V_high .shape [1 ],
609+ dtype = local_V_high .dtype , device = local_V_high .device ,
610+ )
611+ dist .all_gather_into_tensor (V_high_full , local_V_high )
612+ else :
613+ V_high_full = local_V_high
614+ if can_cache :
615+ # .detach() ensures plain Tensor, not nn.Parameter — avoids
616+ # nn.Module.__setattr__ registering it into state_dict.
617+ cache_holder ._osft_v_high_full = V_high_full .detach ()
618+
619+ # Two local matmuls — no (M, M) intermediate
620+ coeff = torch .mm (local_dV , V_high_full .transpose (0 , 1 )) # (k_low/P, k_high)
621+ local_dV .addmm_ (coeff , V_high_full , alpha = - 1.0 ) # (k_low/P, M)
569622
570623 if hasattr (dV , "_local_tensor" ):
571624 dV ._local_tensor .copy_ (local_dV )
@@ -966,6 +1019,10 @@ def _reset_osft_metadata(self):
9661019 self .osft_paramspec_registry = {}
9671020 self ._osft_handles = {}
9681021 self .osft_params = {}
1022+ # Clear any cached all-gathered V_high — V_high changes on reinit.
1023+ for module in self .modules ():
1024+ if hasattr (module , "_osft_v_high_full" ):
1025+ del module ._osft_v_high_full
9691026
9701027 @staticmethod
9711028 def _load_non_distributed (
@@ -1982,7 +2039,14 @@ def project_gradients(self):
19822039 with the high-rank subspace encoding prior task knowledge.
19832040
19842041 This method should be called after backpropagation and before optimizer step.
2042+
2043+ When ``OSFT_CACHE_V=1`` is set, the all-gathered V_high tensor is
2044+ cached on each module after the first step. V_high is frozen, so
2045+ the cache is exact. This eliminates per-step V all-gather traffic.
2046+ Default is off because the cache is replicated on every FSDP2 rank
2047+ (~5.1 GB for Llama-8B, infeasible for 70B+).
19852048 """
2049+ caches_populated_this_call = 0
19862050 for module in self .modules ():
19872051 # Only process real OSFT-attached linear modules, not the top-level container
19882052 if (
@@ -1991,11 +2055,27 @@ def project_gradients(self):
19912055 and hasattr (module , "osft_S_high" )
19922056 and hasattr (module , "osft_V_high" )
19932057 ):
2058+ had_cache = hasattr (module , "_osft_v_high_full" )
19942059 try :
19952060 svd_dict = self .get_svd_dict_for_module (module )
19962061 except ValueError as err :
19972062 raise ValueError (f"error in projecting gradients for module: { module } " ) from err
1998- project_gradient_to_orthogonal_space (svd_dict )
2063+ project_gradient_to_orthogonal_space (svd_dict , cache_holder = module )
2064+ if not had_cache and hasattr (module , "_osft_v_high_full" ):
2065+ caches_populated_this_call += 1
2066+
2067+ if caches_populated_this_call > 0 :
2068+ total_bytes = sum (
2069+ module ._osft_v_high_full .nelement () * module ._osft_v_high_full .element_size ()
2070+ for module in self .modules ()
2071+ if hasattr (module , "_osft_v_high_full" )
2072+ )
2073+ log_rank_0 (
2074+ f"Cached { caches_populated_this_call } V_high tensors "
2075+ f"({ total_bytes / 1e9 :.2f} GB). "
2076+ f"Subsequent steps skip V all-gathers. "
2077+ f"Set OSFT_CACHE_V=0 to disable."
2078+ )
19992079
20002080 def prepare_state_dict_for_save (self , state_dict ):
20012081 """Reconstruct dense weights into ``state_dict`` for saving with memory optimization."""