Skip to content

Commit 4592326

Browse files
committed
improve kronecker_factor_update_fn logic
Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent b211d91 commit 4592326

File tree

1 file changed

+23
-33
lines changed

1 file changed

+23
-33
lines changed

emerging_optimizers/soap/soap.py

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
from functools import partial
1516
from itertools import chain
1617
from typing import Callable, Iterable, List, Optional, Tuple, Union
1718

@@ -197,35 +198,34 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
197198
# Exponential moving average of squared gradient values
198199
state["exp_avg_sq"] = torch.zeros_like(grad)
199200

201+
if "Q" not in state:
202+
state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
203+
204+
# Define kronecker_factor_update_fn based on whether to use KL-Shampoo here
205+
# because it needs access to state and group
206+
kronecker_factor_update_fn = partial(
207+
update_kronecker_factors, precondition_1d=group["precondition_1d"]
208+
)
209+
if group["use_kl_shampoo"]:
210+
kronecker_factor_update_fn = partial(
211+
update_kronecker_factors_kl_shampoo,
212+
eigenbasis_list=state["Q"],
213+
eps=group["eps"],
214+
)
215+
200216
# Initialize kronecker factor matrices
201217
if "GG" not in state:
202218
state["GG"] = init_kronecker_factors(
203219
grad,
204220
precondition_1d=group["precondition_1d"],
205221
)
206222

207-
assert "Q" not in state, "Q should not be initialized yet"
208-
state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
209-
210223
# Update preconditioner matrices with gradient statistics,
211224
# do not use shampoo_beta for EMA at first step
212225
with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
213-
kronecker_factor_update_kwargs = dict(
214-
kronecker_factor_list=state["GG"],
215-
grad=grad,
216-
shampoo_beta=0.0,
226+
kronecker_factor_update_fn(
227+
kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=group["shampoo_beta"]
217228
)
218-
if group["use_kl_shampoo"]:
219-
update_kronecker_factors_kl_shampoo(
220-
**kronecker_factor_update_kwargs,
221-
eigenbasis_list=state["Q"],
222-
eps=group["eps"],
223-
)
224-
else:
225-
update_kronecker_factors(
226-
**kronecker_factor_update_kwargs,
227-
precondition_1d=group["precondition_1d"],
228-
)
229229

230230
# Increment step counter
231231
state["step"] += 1
@@ -299,21 +299,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
299299

300300
torch.cuda.nvtx.range_push("update_kronecker_factors")
301301
with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
302-
if group["use_kl_shampoo"]:
303-
update_kronecker_factors_kl_shampoo(
304-
kronecker_factor_list=state["GG"],
305-
grad=grad,
306-
shampoo_beta=shampoo_beta,
307-
eigenbasis_list=state["Q"],
308-
eps=group["eps"],
309-
)
310-
else:
311-
update_kronecker_factors(
312-
kronecker_factor_list=state["GG"],
313-
grad=grad,
314-
shampoo_beta=shampoo_beta,
315-
precondition_1d=group["precondition_1d"],
316-
)
302+
kronecker_factor_update_fn(
303+
kronecker_factor_list=state["GG"],
304+
grad=grad,
305+
shampoo_beta=0.0,
306+
)
317307
torch.cuda.nvtx.range_pop()
318308

319309
# If current step is the last step to skip preconditioning, initialize eigenbases and

0 commit comments

Comments (0)