Skip to content

Commit 097f977

Browse files
committed
update: test_no_gradients
1 parent 0f53d02 commit 097f977

File tree

1 file changed

+4
-2
lines changed

1 file changed: +4 additions, -2 deletions

tests/test_gradients.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,12 +2,12 @@
22
import torch
33

44
from pytorch_optimizer.base.exception import NoSparseGradientError
5-
from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, load_optimizer
5+
from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, OrthoGrad, load_optimizer
66
from tests.constants import NO_SPARSE_OPTIMIZERS, SPARSE_OPTIMIZERS, VALID_OPTIMIZER_NAMES
77
from tests.utils import build_environment, simple_parameter, simple_sparse_parameter, sphere_loss
88

99

10-
@pytest.mark.parametrize('optimizer_name', [*VALID_OPTIMIZER_NAMES, 'lookahead', 'trac'])
10+
@pytest.mark.parametrize('optimizer_name', [*VALID_OPTIMIZER_NAMES, 'lookahead', 'trac', 'orthograd'])
1111
def test_no_gradients(optimizer_name):
1212
if optimizer_name in {'lomo', 'adalomo', 'adammini', 'demo'}:
1313
pytest.skip(f'skip {optimizer_name} optimizer.')
@@ -28,6 +28,8 @@ def test_no_gradients(optimizer_name):
2828
optimizer = Lookahead(load_optimizer('adamw')(params), k=1)
2929
elif optimizer_name == 'trac':
3030
optimizer = TRAC(load_optimizer('adamw')(params))
31+
elif optimizer_name == 'orthograd':
32+
optimizer = OrthoGrad(load_optimizer('adamw')(params))
3133
else:
3234
optimizer = load_optimizer(optimizer_name)(params)
3335

0 commit comments

Comments (0)