Problem
`project_gradient_to_orthogonal_space` computes the V projection via the Gram matrix form:

```python
G = V_high.T @ V_high  # (M, M) partial sum
dist.all_reduce(G)     # 411 MB for down_proj
dV -= dV @ G
```

The Gram matrix is M × M regardless of k_high. For down_proj
(M = 14336, k_high = 2048), the all-reduce payload is 411 MB, 7x larger than
the V_high tensor itself (59 MB).
Proposed fix
Replace it with the algebraically identical factored form (which git history suggests was previously used in the single-GPU path):

```python
dist.all_gather_into_tensor(V_high_full, local_V_high)  # 59 MB for down_proj
coeff = local_dV @ V_high_full.T                        # (k_low/P, k_high), local
local_dV -= coeff @ V_high_full                         # (k_low/P, M), local
```

The improvement is the communication strategy, not the math: the all-gather
message is k_high × M versus M × M for the all-reduce, a factor of M / k_high
fewer bytes (2-7x depending on the target shape), and all-gather is cheaper
per byte than all-reduce.
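That the two forms agree follows from associativity: dV @ (V_highᵀ @ V_high) = (dV @ V_highᵀ) @ V_high. A minimal single-process NumPy check, assuming V_high has orthonormal rows as produced by an SVD (the small shapes here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
M, k_high, k_low = 64, 8, 4

# Frozen high-subspace directions: (k_high, M) with orthonormal rows.
V_high = np.linalg.qr(rng.standard_normal((M, k_high)))[0].T

dV = rng.standard_normal((k_low, M))  # gradient rows to project

# Gram form: materializes an (M, M) matrix.
G = V_high.T @ V_high                 # (M, M)
gram_result = dV - dV @ G

# Factored form: never forms the (M, M) Gram matrix.
coeff = dV @ V_high.T                 # (k_low, k_high)
factored_result = dV - coeff @ V_high # (k_low, M)

assert np.allclose(gram_result, factored_result)
```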
V_high is frozen, so the all-gathered result can be cached across steps
(OSFT_CACHE_V=1, default off).
This is opt-in because the cache holds the full unsharded matrix, so the extra memory has to be budgeted for.
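The caching logic can be sketched as below; the helper name, cache layout, and the injected `all_gather_fn` (standing in for `dist.all_gather_into_tensor`) are illustrative, not the repo's actual code. Only `OSFT_CACHE_V` comes from the proposal above.

```python
import os

_V_HIGH_CACHE = {}  # hypothetical cache, keyed by parameter name

def gather_v_high(name, local_V_high, all_gather_fn):
    """Return the full (unsharded) V_high, all-gathering only on cache miss.

    OSFT_CACHE_V=1 opts in; default off, since the cache pins the full
    unsharded matrix in memory. Reuse is safe because V_high is frozen.
    """
    cache_on = os.environ.get("OSFT_CACHE_V") == "1"
    if cache_on and name in _V_HIGH_CACHE:
        return _V_HIGH_CACHE[name]      # no communication on cache hit
    full = all_gather_fn(local_V_high)  # one all-gather per miss
    if cache_on:
        _V_HIGH_CACHE[name] = full
    return full
```

With caching enabled, the 59 MB all-gather for down_proj happens once instead of every step.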