 import torch

 from pytorch_optimizer.base.exception import NoSparseGradientError
-from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, load_optimizer
+from pytorch_optimizer.optimizer import SAM, TRAC, WSAM, AdamP, Lookahead, OrthoGrad, load_optimizer
 from tests.constants import NO_SPARSE_OPTIMIZERS, SPARSE_OPTIMIZERS, VALID_OPTIMIZER_NAMES
 from tests.utils import build_environment, simple_parameter, simple_sparse_parameter, sphere_loss


-@pytest.mark.parametrize('optimizer_name', [*VALID_OPTIMIZER_NAMES, 'lookahead', 'trac'])
+@pytest.mark.parametrize('optimizer_name', [*VALID_OPTIMIZER_NAMES, 'lookahead', 'trac', 'orthograd'])
 def test_no_gradients(optimizer_name):
     if optimizer_name in {'lomo', 'adalomo', 'adammini', 'demo'}:
         pytest.skip(f'skip {optimizer_name} optimizer.')
@@ -28,6 +28,8 @@ def test_no_gradients(optimizer_name):
         optimizer = Lookahead(load_optimizer('adamw')(params), k=1)
     elif optimizer_name == 'trac':
         optimizer = TRAC(load_optimizer('adamw')(params))
+    elif optimizer_name == 'orthograd':
+        optimizer = OrthoGrad(load_optimizer('adamw')(params))
     else:
         optimizer = load_optimizer(optimizer_name)(params)
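A minimal sketch of the wrapper pattern this diff exercises: like Lookahead and TRAC, OrthoGrad is constructed around an already-built base optimizer, here AdamW via load_optimizer, exactly as in the new test branch above. The toy parameter, loss, and training loop below are illustrative assumptions and not part of the test file; they assume OrthoGrad exposes the standard zero_grad/step interface like the other wrappers in the test.

```python
import torch

from pytorch_optimizer.optimizer import OrthoGrad, load_optimizer

# A single toy parameter standing in for a model (assumption for illustration).
param = torch.nn.Parameter(torch.randn(2))

# Wrap a base optimizer (AdamW, as in the test) with OrthoGrad.
optimizer = OrthoGrad(load_optimizer('adamw')([param]))

for _ in range(3):
    optimizer.zero_grad()
    loss = (param ** 2).sum()  # simple sphere-style loss, as used elsewhere in the tests
    loss.backward()
    optimizer.step()
```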