Skip to content

Commit ffd5788

Browse files
committed
update: test cases
1 parent d017c6c commit ffd5788

File tree

2 files changed

+3
-13
lines changed

2 files changed

+3
-13
lines changed

tests/test_optimizer_parameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22
import torch
33
from torch import nn
44

5+
from pytorch_optimizer.base.exception import NoClosureError
56
from pytorch_optimizer.optimizer import (
7+
BSAM,
68
SAM,
79
WSAM,
810
Lookahead,
911
LookSAM,
10-
BSAM,
1112
PCGrad,
1213
Ranger21,
1314
SafeFP16Optimizer,
1415
load_optimizer,
1516
)
16-
from pytorch_optimizer.base.exception import NoClosureError
1717
from pytorch_optimizer.optimizer.galore_utils import GaLoreProjector
1818
from tests.constants import PULLBACK_MOMENTUM
1919
from tests.utils import Example, simple_parameter

tests/test_optimizers.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,7 @@
66
from torch import nn
77

88
from pytorch_optimizer.base.exception import NoClosureError, ZeroParameterSizeError
9-
from pytorch_optimizer.optimizer import (
10-
BSAM,
11-
SAM,
12-
WSAM,
13-
DynamicLossScaler,
14-
LookSAM,
15-
Muon,
16-
load_optimizer,
17-
)
9+
from pytorch_optimizer.optimizer import DynamicLossScaler, Muon, load_optimizer
1810
from pytorch_optimizer.optimizer.alig import l2_projection
1911
from pytorch_optimizer.optimizer.grokfast import gradfilter_ema, gradfilter_ma
2012
from pytorch_optimizer.optimizer.scion import build_lmo_norm
@@ -244,12 +236,10 @@ def test_hessian_optimizer(optimizer_name):
244236
optimizer = load_optimizer(optimizer_name)([param], **parameters)
245237
optimizer.zero_grad(set_to_none=True)
246238

247-
# Hutchinson (internal) estimator
248239
sphere_loss(param).backward(create_graph=True)
249240
optimizer.step()
250241
optimizer.zero_grad(set_to_none=True)
251242

252-
# External estimator
253243
sphere_loss(param).backward()
254244
optimizer.step(hessian=torch.zeros_like(param).unsqueeze(0))
255245

0 commit comments

Comments
 (0)