File tree Expand file tree Collapse file tree 2 files changed +41
-21
lines changed Expand file tree Collapse file tree 2 files changed +41
-21
lines changed Original file line number Diff line number Diff line change @@ -38,24 +38,3 @@ def test_load_optimizers_valid(valid_optimizer_names):
3838def test_load_optimizers_invalid (invalid_optimizer_names ):
3939 with pytest .raises (NotImplementedError ):
4040 load_optimizers (invalid_optimizer_names )
41-
42-
43- @pytest .mark .parametrize ('optimizer_names' , VALID_OPTIMIZER_NAMES )
44- def test_learning_rate (optimizer_names ):
45- with pytest .raises (ValueError ):
46- optimizer = load_optimizers (optimizer_names )
47- optimizer (None , lr = - 1e-2 )
48-
49-
50- @pytest .mark .parametrize ('optimizer_names' , VALID_OPTIMIZER_NAMES )
51- def test_epsilon (optimizer_names ):
52- with pytest .raises (ValueError ):
53- optimizer = load_optimizers (optimizer_names )
54- optimizer (None , eps = - 1e-6 )
55-
56-
57- @pytest .mark .parametrize ('optimizer_names' , VALID_OPTIMIZER_NAMES )
58- def test_weight_decay (optimizer_names ):
59- with pytest .raises (ValueError ):
60- optimizer = load_optimizers (optimizer_names )
61- optimizer (None , weight_decay = - 1e-3 )
Original file line number Diff line number Diff line change 1+ from typing import List
2+
3+ import pytest
4+
5+ from pytorch_optimizer import load_optimizers
6+
7+ VALID_OPTIMIZER_NAMES : List [str ] = [
8+ 'adamp' ,
9+ 'sgdp' ,
10+ 'madgrad' ,
11+ 'ranger' ,
12+ 'ranger21' ,
13+ 'radam' ,
14+ 'adabound' ,
15+ 'adahessian' ,
16+ 'adabelief' ,
17+ 'diffgrad' ,
18+ 'diffrgrad' ,
19+ 'lamb' ,
20+ ]
21+
22+
23+ @pytest .mark .parametrize ('optimizer_names' , VALID_OPTIMIZER_NAMES )
24+ def test_learning_rate (optimizer_names ):
25+ with pytest .raises (ValueError ):
26+ optimizer = load_optimizers (optimizer_names )
27+ optimizer (None , lr = - 1e-2 )
28+
29+
30+ @pytest .mark .parametrize ('optimizer_names' , VALID_OPTIMIZER_NAMES )
31+ def test_epsilon (optimizer_names ):
32+ with pytest .raises (ValueError ):
33+ optimizer = load_optimizers (optimizer_names )
34+ optimizer (None , eps = - 1e-6 )
35+
36+
37+ @pytest .mark .parametrize ('optimizer_names' , VALID_OPTIMIZER_NAMES )
38+ def test_weight_decay (optimizer_names ):
39+ with pytest .raises (ValueError ):
40+ optimizer = load_optimizers (optimizer_names )
41+ optimizer (None , weight_decay = - 1e-3 )
You can’t perform that action at this time.
0 commit comments