Skip to content

Commit 863806a

Browse files
committed
update: test_galore_projection_type
1 parent 598e35e commit 863806a

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

tests/test_optimizer_parameters.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22
import torch
33
from torch import nn
44

5-
from pytorch_optimizer import SAM, WSAM, Lookahead, PCGrad, Ranger21, SafeFP16Optimizer, load_optimizer
5+
from pytorch_optimizer import (
6+
SAM,
7+
WSAM,
8+
GaLoreProjector,
9+
Lookahead,
10+
PCGrad,
11+
Ranger21,
12+
SafeFP16Optimizer,
13+
load_optimizer,
14+
)
615
from tests.constants import PULLBACK_MOMENTUM
716
from tests.utils import Example, simple_parameter, simple_zero_rank_parameter
817

@@ -254,3 +263,16 @@ def test_ranger_parameters():
254263
# test lookahead step `k`
255264
with pytest.raises(ValueError):
256265
opt(None, k=-1)
266+
267+
268+
def test_galore_projection_type():
269+
p = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32)
270+
271+
with pytest.raises(NotImplementedError):
272+
GaLoreProjector(projection_type='invalid').project(p, 1)
273+
274+
with pytest.raises(NotImplementedError):
275+
GaLoreProjector(projection_type='invalid').project_back(p)
276+
277+
with pytest.raises(ValueError):
278+
GaLoreProjector.get_orthogonal_matrix(p, 1, projection_type='std')

0 commit comments

Comments
 (0)