Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 166 additions & 0 deletions benchmarks/bench_v_proj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Benchmark: Gram vs factored V projection under FSDP2-style sharding.

Measures the V projection step of project_gradient_to_orthogonal_space()
with Llama-8B shapes. Compares three modes:

Gram: all-reduce (M, M) Gram matrix, then dV -= dV @ G
Factored: all-gather (k_high, M) V_high, then dV -= (dV @ V^T) @ V
Cached: reuse V_high from prior step (no comm), same matmuls

Usage:
python bench_v_proj.py # requires 2 GPUs
"""

import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def bench_target(name, k_high, k_low, M, P, dev, n_iters=100):
    """Benchmark one OSFT target shape. Returns dict of timings.

    Times three variants of the V projection dV -= dV @ (V^T V):
      Gram:     all-reduce an (M, M) Gram matrix, then one matmul.
      Factored: all-gather V_high (k_high, M), then two matmuls.
      Cached:   reuse the previously all-gathered V_high — no comm.

    Args:
        name: Target label stored in the result dict (e.g. "q_proj").
        k_high: Total rows of V_high across all ranks.
        k_low: Total rows of the dV gradient across all ranks.
        M: Column dimension shared by V_high and dV.
        P: World size. k_high and k_low are assumed divisible by P —
           TODO confirm callers never pass non-divisible shapes.
        dev: CUDA device string for this rank (e.g. "cuda:0").
        n_iters: Timed iterations per variant (mean is reported).

    Returns:
        Dict with per-variant mean times in milliseconds, the bytes
        moved by each collective (bf16 = 2 bytes/element), and the
        float32 max abs difference between the Gram and factored
        results as a correctness check.

    Note:
        Reads the module-level ``rank`` global (set in ``run()``) to
        select this rank's shard of V_high.
    """
    local_k_high = k_high // P
    local_k_low = k_low // P

    # V_high has orthonormal rows (from SVD) — create via QR in float32.
    # Q of the QR is (M, k_high) with orthonormal columns, so Q.T has
    # orthonormal rows; the [:k_high] slice is a no-op kept for clarity.
    V_full_f32 = torch.linalg.qr(
        torch.randn(M, k_high, device=dev)
    )[0].T[:k_high]  # (k_high, M) orthonormal rows
    V_full = V_full_f32.to(torch.bfloat16)

    # Each rank's shard of V_high (contiguous dim-0 row slice, FSDP2-style)
    local_V = V_full[rank * local_k_high:(rank + 1) * local_k_high].contiguous()
    local_dV = torch.randn(local_k_low, M, device=dev, dtype=torch.bfloat16)
    dV_save = local_dV.clone()  # pristine gradient, restored before every iteration

    # Pre-allocate the all-gather output once; reused by every iteration
    V_ag = torch.empty_like(V_full)

    # Warmup: exercise both comm+matmul patterns so CUDA kernels, NCCL
    # communicators, and allocator pools are initialized before timing.
    for _ in range(10):
        local_dV.copy_(dV_save)
        G = torch.mm(local_V.T, local_V)
        dist.all_reduce(G)
        local_dV.add_(torch.mm(local_dV, G), alpha=-1.0)

        local_dV.copy_(dV_save)
        dist.all_gather_into_tensor(V_ag, local_V)
        coeff = torch.mm(local_dV, V_ag.T)
        local_dV.addmm_(coeff, V_ag, alpha=-1.0)
    torch.cuda.synchronize()
    dist.barrier()

    # --- Gram form: all-reduce (M, M), then dV -= dV @ G ---
    # synchronize + barrier before t0 so no rank starts the timed region
    # with queued work or ahead of its peers.
    torch.cuda.synchronize()
    dist.barrier()
    t0 = time.perf_counter()
    for _ in range(n_iters):
        local_dV.copy_(dV_save)
        G = torch.mm(local_V.T, local_V)
        dist.all_reduce(G)
        local_dV.add_(torch.mm(local_dV, G), alpha=-1.0)
    torch.cuda.synchronize()
    gram_ms = (time.perf_counter() - t0) / n_iters * 1000

    # --- Factored form: all-gather (k_high, M), then two matmuls ---
    torch.cuda.synchronize()
    dist.barrier()
    t0 = time.perf_counter()
    for _ in range(n_iters):
        local_dV.copy_(dV_save)
        dist.all_gather_into_tensor(V_ag, local_V)
        coeff = torch.mm(local_dV, V_ag.T)
        local_dV.addmm_(coeff, V_ag, alpha=-1.0)
    torch.cuda.synchronize()
    fact_ms = (time.perf_counter() - t0) / n_iters * 1000

    # --- Cached: no comm, just matmuls (V_ag still holds the gathered V) ---
    torch.cuda.synchronize()
    dist.barrier()
    t0 = time.perf_counter()
    for _ in range(n_iters):
        local_dV.copy_(dV_save)
        coeff = torch.mm(local_dV, V_ag.T)
        local_dV.addmm_(coeff, V_ag, alpha=-1.0)
    torch.cuda.synchronize()
    cached_ms = (time.perf_counter() - t0) / n_iters * 1000

    # Correctness: verify in float32 (bf16 accumulation differs by ~5%).
    # Both forms compute the same projection; only associativity differs.
    dV_f32 = dV_save.float()
    V_f32 = V_full_f32
    gram_f32 = dV_f32 - dV_f32 @ (V_f32.T @ V_f32)
    fact_f32 = dV_f32 - (dV_f32 @ V_f32.T) @ V_f32
    max_diff = (fact_f32 - gram_f32).abs().max().item()

    return {
        "name": name, "k_high": k_high, "M": M,
        # bytes moved by the collective: bf16 = 2 bytes per element
        "gram_bytes": M * M * 2, "fact_bytes": k_high * M * 2,
        "gram_ms": gram_ms, "fact_ms": fact_ms, "cached_ms": cached_ms,
        "max_diff": max_diff,
    }


# Module-level rank, assigned per-process in run(); bench_target reads it
# to select this rank's shard of V_high (simpler than threading it through
# every call in a throwaway benchmark).
rank = 0


def run(rank_, world_size):
    """Per-process entry point spawned by mp.spawn.

    Initializes a NCCL process group (rendezvous via env vars on
    localhost:29500), benchmarks each target shape, and has rank 0
    print a per-target and aggregate report.

    Args:
        rank_: This process's rank, supplied by mp.spawn.
        world_size: Total number of processes / GPUs.
    """
    # Publish the rank to the module-level global read by bench_target.
    global rank
    rank = rank_
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="29500",
                      RANK=str(rank), WORLD_SIZE=str(world_size))
    dist.init_process_group("nccl")
    torch.cuda.set_device(rank)
    torch.manual_seed(42)  # same seed on all ranks → identical V_full everywhere
    dev = f"cuda:{rank}"

    # Llama-8B shapes, URR=0.5 → k_high = min(N, M) * 0.5 = 2048
    targets = [
        # name, k_high, k_low, M
        ("down_proj", 2048, 2048, 14336),  # (4096, 14336) → ratio = 7x
        ("q_proj", 2048, 2048, 4096),  # (4096, 4096) → ratio = 2x
    ]

    results = []
    for name, k_high, k_low, M in targets:
        r = bench_target(name, k_high, k_low, M, world_size, dev)
        results.append(r)

    # Only rank 0 reports; all ranks ran the same collectives so the
    # timings are representative.
    if rank == 0:
        gpu = torch.cuda.get_device_name(0)
        print(f"V projection benchmark — {world_size}× {gpu}, Llama-8B shapes, bf16")
        print("=" * 70)
        print()
        for r in results:
            ratio = r["M"] / r["k_high"]
            print(f" {r['name']:10s} V_high ({r['k_high']}, {r['M']})"
                  f" M/k_high = {ratio:.0f}x")
            print(f" {'':10s} Gram Factored Cached")
            print(f" {'':10s} ---------- ---------- ----------")
            print(f" {'comm':10s} all-reduce all-gather none")
            print(f" {'bytes':10s} {r['gram_bytes']/1e6:>7.0f} MB {r['fact_bytes']/1e6:>7.0f} MB {0:>7.0f} MB")
            print(f" {'time':10s} {r['gram_ms']:>7.2f} ms {r['fact_ms']:>7.2f} ms {r['cached_ms']:>7.2f} ms")
            print(f" {'speedup':10s} {'—':>10s} {r['gram_ms']/r['fact_ms']:>7.1f}x {r['gram_ms']/r['cached_ms']:>7.1f}x")
            ok = "ok" if r["max_diff"] < 1e-4 else "FAIL"
            print(f" {'correct':10s} — f32 max diff = {r['max_diff']:.1e} ({ok})")
            print()

        # Aggregate: 32 layers × 7 targets (4 q/k/v/o + 2 gate/up + 1 down).
        # NOTE(review): this approximates all 6 non-down targets with the
        # q_proj timing — k/v (GQA) and gate/up shapes differ; confirm the
        # approximation is acceptable for the estimate.
        sq = next(r for r in results if r["name"] == "q_proj")
        dp = next(r for r in results if r["name"] == "down_proj")
        gram_tot = (6 * sq["gram_ms"] + dp["gram_ms"]) * 32
        fact_tot = (6 * sq["fact_ms"] + dp["fact_ms"]) * 32
        cached_tot = (6 * sq["cached_ms"] + dp["cached_ms"]) * 32
        print(f" Aggregate (32 layers × 7 targets = 224):")
        print(f" {'':10s} Gram Factored Cached")
        print(f" {'total':10s} {gram_tot:>7.0f} ms {fact_tot:>7.0f} ms {cached_tot:>7.0f} ms")
        print(f" {'speedup':10s} {'—':>10s} {gram_tot/fact_tot:>7.1f}x {gram_tot/cached_tot:>7.1f}x")
        print()

    dist.destroy_process_group()


if __name__ == "__main__":
    # Spawn 2 worker processes (one per GPU); args=(2,) passes
    # world_size=2 to run(). Requires 2 CUDA devices.
    mp.spawn(run, args=(2,), nprocs=2, join=True)
106 changes: 93 additions & 13 deletions src/mini_trainer/osft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@
os.getenv("OSFT_CACHE_CLEAR_INTERVAL", 5)
) # Clear GPU cache every N parameters during matrix reconstruction

# Opt-in: cache the all-gathered (full) V_high to avoid repeating the
# all-gather on every step. V_high is frozen (requires_grad=False), so the
# cache is exact.
#
# V projection uses the factored form: dV -= (dV @ V_high^T) @ V_high.
# Under FSDP2, V_high is dim-0 sharded, so the factored form requires an
# all-gather of V_high (k_high × M). Caching stores the all-gathered result.
#
# Default OFF because the cache is REPLICATED on every FSDP2 rank (not
# sharded), adding ~5.1 GB per rank for Llama-8B (all 224 targets in bf16).
# For Llama-70B+ the cache exceeds GPU memory. Set to "1" only after
# confirming sufficient memory headroom.
OSFT_CACHE_V = os.getenv("OSFT_CACHE_V", "0") == "1"


def _supports_use_batch() -> bool:
"""Check if torch.distributed send/recv_object_list support the use_batch parameter (PyTorch 2.9+)."""
Expand Down Expand Up @@ -514,11 +528,34 @@ def reconstruct_weight_matrix(
return reconstructed


def project_gradient_to_orthogonal_space(svd_dict: SVDDecompositionDict):
def project_gradient_to_orthogonal_space(
svd_dict: SVDDecompositionDict,
cache_holder: "nn.Module | None" = None,
):
"""
Projects the gradient of the low-rank parameters (U_low, V_low) to be orthogonal to the frozen high-rank subspace.
Projects the gradient of the low-rank parameters (U_low, V_low) to be
orthogonal to the frozen high-rank subspace.

Both projections use the factored form:
dU -= U_high @ (U_high^T @ dU)
dV -= (dV @ V_high^T) @ V_high

This step ensures that learning new tasks does not interfere with previously learned representations by enforcing an orthogonality constraint.
Under FSDP2, U_high is dim-0 sharded along N (the large dimension), so
U_high^T @ dU contracts over the sharded dim → partial sum → all-reduce
of a small (k_high, k_low) matrix.

V_high is dim-0 sharded along k_high (the small dimension), so the
factored form requires an all-gather of V_high to get the full
(k_high, M) tensor. This is M/k_high fewer bytes than the Gram
matrix all-reduce (k_high × M vs M × M) — 2x for square weights,
7x for down_proj where k_high = min(N, M) × (1 - URR).

Args:
svd_dict: Dictionary containing the SVD decomposition components.
cache_holder: Optional module on which to cache the all-gathered
V_high. V_high is frozen, so the cache is exact. When provided
and OSFT_CACHE_V is enabled, V_high is all-gathered once and
reused on subsequent steps.

TODO(osilkin): Add mixed-precision gradients here
"""
Expand Down Expand Up @@ -551,21 +588,37 @@ def project_gradient_to_orthogonal_space(svd_dict: SVDDecompositionDict):
else:
dU.copy_(local_dU)

# Repeat projection for V_low using V_high
# Project V_low gradients: dV -= (dV @ V_high^T) @ V_high
# All-gather V_high from FSDP2 shards (or use cache) — see docstring for cost analysis.
if svd_dict["V_low"].grad is not None:
dV = svd_dict["V_low"].grad
local_V_high = getattr(V_high, "to_local", lambda: V_high)()
local_dV = getattr(dV, "to_local", lambda: dV)()

# Compute Gram matrix G = V_high^T @ V_high for global projection across row-sharded V_high
# Assumes column dimension is consistent across ranks (row sharding over singular vectors)
G_local = torch.mm(local_V_high.transpose(0, 1), local_V_high)
if dist.is_initialized() and dist.get_world_size() > 1:
dist.all_reduce(G_local, op=dist.ReduceOp.SUM)
# V_high is frozen — reuse cached all-gathered tensor when available.
can_cache = OSFT_CACHE_V and cache_holder is not None
cached = getattr(cache_holder, "_osft_v_high_full", None) if can_cache else None

# Apply projection: dV = dV - dV @ G (use local shard of dV)
update = torch.mm(local_dV, G_local)
local_dV.add_(update, alpha=-1.0)
if cached is not None:
V_high_full = cached
else:
if dist.is_initialized() and dist.get_world_size() > 1:
world_size = dist.get_world_size()
V_high_full = torch.empty(
local_V_high.shape[0] * world_size, local_V_high.shape[1],
dtype=local_V_high.dtype, device=local_V_high.device,
)
dist.all_gather_into_tensor(V_high_full, local_V_high)
Comment on lines +605 to +611
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, let's look at the file structure and the specific lines mentioned
cd /tmp/repository 2>/dev/null || cd . 
wc -l src/mini_trainer/osft_utils.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 118


🏁 Script executed:

# Read the context around lines 605-611
sed -n '590,625p' src/mini_trainer/osft_utils.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 1780


🏁 Script executed:

# Find where local_V_high is computed and used
rg -n "local_V_high" src/mini_trainer/osft_utils.py -B 3 -A 3

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 1176


🏁 Script executed:

# Find rank_high and how it's set
rg -n "rank_high" src/mini_trainer/osft_utils.py -B 2 -A 2

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 3148


🏁 Script executed:

# Check the function signature and beginning to understand the data flow
sed -n '550,650p' src/mini_trainer/osft_utils.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 4562


🏁 Script executed:

# Check how V_high is created and if it's FSDP-distributed
rg -n "V_high" src/mini_trainer/osft_utils.py | head -50

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 3640


🏁 Script executed:

# Look for FSDP or distributed tensor patterns
rg -n "FSDP\|fully_shard\|to_local" src/mini_trainer/osft_utils.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Understand how SVD decomposition creates V_high and its shape
sed -n '450,520p' src/mini_trainer/osft_utils.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 3270


🏁 Script executed:

# Check if there's any padding or divisibility check already in place
rg -n "ceil\|divisib\|pad\|even\|uneven" src/mini_trainer/osft_utils.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Search for the definition of the function containing this code
rg -n "def project_gradient_to_orthogonal_space" src/mini_trainer/osft_utils.py -A 100 | head -120

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 5000


🏁 Script executed:

# Check how V_high is registered with FSDP2 or distributed tensor wrapping
rg -n "FSDP\|fully_shard\|distribute\|shard" src/mini_trainer/osft_utils.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Look for where osft_V_high is registered and how it's sharded
rg -n "osft_V_high\|register_parameter\|register_buffer" src/mini_trainer/osft_utils.py -B 2 -A 2

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Check if there's any test or example showing how rank_high varies
rg -n "rank_high\|rank_ratio" src/mini_trainer/osft_utils.py | head -30

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Search for auto_generate_target_osft_config to understand rank_ratio
rg -n "auto_generate_target_osft_config\|rank_ratio" src/mini_trainer/osft_utils.py -A 5

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Check torch.distributed documentation behavior for all_gather_into_tensor
# Let's see if there are any comments about shape expectations
grep -n "all_gather" src/mini_trainer/osft_utils.py -B 5 -A 5

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 697


🏁 Script executed:

# Let's search for where V_high gets registered with FSDP or how it's initialized
fd -name "*.py" src/mini_trainer | xargs grep -l "FSDP\|fully_shard" | head -5

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 313


🏁 Script executed:

# Check if there's any rank_ratio or top_k configuration that could make rank_high non-divisible
rg -n "rank_ratio\|top_k\|rank_high\|URR" src/mini_trainer/osft_utils.py -B 2 -A 2 | head -60

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Let's understand the actual shape of V_high before FSDP sharding
# Find where V_high enters svd_dict
sed -n '470,495p' src/mini_trainer/osft_utils.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 1179


🏁 Script executed:

# Search for any tests or usage examples that show different rank configurations
find . -name "*test*.py" -o -name "*example*.py" | xargs grep -l "rank_ratio\|rank_high\|OSFT" 2>/dev/null | head -3

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 208


🏁 Script executed:

# Check PyTorch's all_gather_into_tensor documentation or comments in code
python3 << 'EOF'
import torch
import torch.distributed as dist

# Verify all_gather_into_tensor behavior with unequal input shapes
help_text = """
torch.distributed.all_gather_into_tensor gathers tensors from all ranks into a single output tensor.
Key constraint: All input tensors must have the SAME shape on all ranks.
If shapes differ, the operation will fail or produce incorrect results.
"""
print(help_text)
EOF

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 200


🏁 Script executed:

# Check the test files for OSFT usage and rank configurations
cat ./regression_tests/test_osft_orthogonalization.py | head -100

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 3548


🏁 Script executed:

# Look for imports or setup of FSDP in the test files
grep -n "FSDP\|fully_shard\|world_size\|rank" ./regression_tests/test_osft_orthogonalization.py | head -20

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 982


🏁 Script executed:

# Check if there's documentation about how V_high sharding works
rg -n "dim-0 shard\|k_high\|world_size" src/mini_trainer/osft_utils.py -B 2 -A 2

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Check how rank_high values are computed in tests
grep -n "rank_high\|top_k\|rank_ratio" ./regression_tests/test_osft_*.py

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 590


🏁 Script executed:

# Look for configuration objects or rank calculations
rg -n "class.*Config\|def.*rank\|rank.*=" src/mini_trainer/osft_utils.py | head -30

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Check how rank_high is calculated from rank_ratio in setup_model
rg -n "rank_ratio\|top_k" src/mini_trainer/ -B 3 -A 3

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 65


🏁 Script executed:

# Search for setup_model_for_training to understand rank_ratio usage
find . -name "setup_model_for_training.py" -o -name "*setup*.py" | head -3

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 167


🏁 Script executed:

# Let's look for how top_k is determined from rank_ratio
grep -r "rank_ratio\|top_k" --include="*.py" src/mini_trainer/ | head -20

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 1783


🏁 Script executed:

# Check the actual setup_model function
fd "setup_model" src/mini_trainer -x cat {}

Repository: Red-Hat-AI-Innovation-Team/mini_trainer

Length of output: 47667


Add padding to local_V_high before all_gather_into_tensor() to handle uneven rank sharding.

The allocation at line 608 assumes every rank has local_V_high.shape[0] == ceil(rank_high / world_size), but this is only guaranteed when rank_high % world_size == 0. Since rank_high is computed from rank_ratio using int(np.floor(...)), non-divisible cases are common. With uneven sharding, all_gather_into_tensor() will fail on the shape mismatch or reassemble a tensor with truncated or misplaced rows.

Pad each local shard to ceil(rank_high / world_size) rows and slice the result back to svd_dict["rank_high"] after gathering.

Suggested fix
+                full_k_high = svd_dict["rank_high"]
                 world_size = dist.get_world_size()
-                V_high_full = torch.empty(
-                    local_V_high.shape[0] * world_size, local_V_high.shape[1],
-                    dtype=local_V_high.dtype, device=local_V_high.device,
-                )
-                dist.all_gather_into_tensor(V_high_full, local_V_high)
+                rows_per_rank = math.ceil(full_k_high / world_size)
+                padded_local_V_high = torch.zeros(
+                    rows_per_rank, local_V_high.shape[1],
+                    dtype=local_V_high.dtype, device=local_V_high.device,
+                )
+                padded_local_V_high[: local_V_high.shape[0]].copy_(local_V_high)
+                gathered = torch.empty(
+                    rows_per_rank * world_size, local_V_high.shape[1],
+                    dtype=local_V_high.dtype, device=local_V_high.device,
+                )
+                dist.all_gather_into_tensor(gathered, padded_local_V_high)
+                V_high_full = gathered[:full_k_high]

Add import math at the top of the file if not already present.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mini_trainer/osft_utils.py` around lines 605 - 611, The current
all_gather_into_tensor assumes equal per-rank rows for local_V_high; pad
local_V_high to target_rows = math.ceil(svd_dict["rank_high"] / world_size)
before calling dist.all_gather_into_tensor so each rank provides the same shape
(use torch.nn.functional.pad or torch.zeros on the same device/dtype and
concatenate), perform the all_gather_into_tensor into V_high_full sized
(target_rows * world_size, cols), then slice V_high_full[:svd_dict["rank_high"],
:] to restore the original total rank and assign back into the downstream
variable; ensure import math is present at file top if missing.

else:
V_high_full = local_V_high
if can_cache:
# .detach() ensures plain Tensor, not nn.Parameter — avoids
# nn.Module.__setattr__ registering it into state_dict.
cache_holder._osft_v_high_full = V_high_full.detach()

# Two local matmuls — no (M, M) intermediate
coeff = torch.mm(local_dV, V_high_full.transpose(0, 1)) # (k_low/P, k_high)
local_dV.addmm_(coeff, V_high_full, alpha=-1.0) # (k_low/P, M)

if hasattr(dV, "_local_tensor"):
dV._local_tensor.copy_(local_dV)
Expand Down Expand Up @@ -966,6 +1019,10 @@ def _reset_osft_metadata(self):
self.osft_paramspec_registry = {}
self._osft_handles = {}
self.osft_params = {}
# Clear any cached all-gathered V_high — V_high changes on reinit.
for module in self.modules():
if hasattr(module, "_osft_v_high_full"):
del module._osft_v_high_full

@staticmethod
def _load_non_distributed(
Expand Down Expand Up @@ -1982,7 +2039,14 @@ def project_gradients(self):
with the high-rank subspace encoding prior task knowledge.

This method should be called after backpropagation and before optimizer step.

When ``OSFT_CACHE_V=1`` is set, the all-gathered V_high tensor is
cached on each module after the first step. V_high is frozen, so
the cache is exact. This eliminates per-step V all-gather traffic.
Default is off because the cache is replicated on every FSDP2 rank
(~5.1 GB for Llama-8B, infeasible for 70B+).
"""
caches_populated_this_call = 0
for module in self.modules():
# Only process real OSFT-attached linear modules, not the top-level container
if (
Expand All @@ -1991,11 +2055,27 @@ def project_gradients(self):
and hasattr(module, "osft_S_high")
and hasattr(module, "osft_V_high")
):
had_cache = hasattr(module, "_osft_v_high_full")
try:
svd_dict = self.get_svd_dict_for_module(module)
except ValueError as err:
raise ValueError(f"error in projecting gradients for module: {module}") from err
project_gradient_to_orthogonal_space(svd_dict)
project_gradient_to_orthogonal_space(svd_dict, cache_holder=module)
if not had_cache and hasattr(module, "_osft_v_high_full"):
caches_populated_this_call += 1

if caches_populated_this_call > 0:
total_bytes = sum(
module._osft_v_high_full.nelement() * module._osft_v_high_full.element_size()
for module in self.modules()
if hasattr(module, "_osft_v_high_full")
)
log_rank_0(
f"Cached {caches_populated_this_call} V_high tensors "
f"({total_bytes / 1e9:.2f} GB). "
f"Subsequent steps skip V all-gathers. "
f"Set OSFT_CACHE_V=0 to disable."
)

def prepare_state_dict_for_save(self, state_dict):
"""Reconstruct dense weights into ``state_dict`` for saving with memory optimization."""
Expand Down
Loading