Commit ebdae45

revise Lipschitz constant computation
1 parent 16f7445 commit ebdae45

3 files changed: +51 / -45 lines

torch_staintools/functional/optimization/solver.py

Lines changed: 6 additions & 36 deletions
@@ -7,7 +7,7 @@
 from ..eps import get_eps
 import torch.nn.functional as F
 
-from ..utility import as_scalar
+from .sparse_util import as_scalar
 
 
 def coord_descent(x: torch.Tensor, z0: torch.Tensor, weight: torch.Tensor,
@@ -57,25 +57,6 @@ def cd_update(z, b):
         z = F.softshrink(b, alpha)
         return z
 
-def _lipschitz_constant(w: torch.Tensor):
-    """find the Lipscitz constant to compute the learning rate in ISTA
-
-    Args:
-        w: weights w in f(z) = ||Wz - x||^2
-
-    Returns:
-
-    """
-    # L = torch.linalg.norm(W, ord=2) ** 2
-    # W has nan
-    WtW = torch.matmul(w.t(), w)
-    WtW += torch.eye(WtW.size(0)).to(w.device) * get_eps(WtW)
-    L = torch.linalg.eigvalsh(WtW)[-1].squeeze()
-    L_is_finite = torch.isfinite(L).all()
-    L = torch.where(L_is_finite, L, torch.linalg.norm(w, ord=2) ** 2)
-    L = L.abs()
-    return L + torch.finfo(L.dtype).eps
-
 def rss_grad(z_k: torch.Tensor, x: torch.Tensor, weight: torch.Tensor):
     resid = torch.matmul(z_k, weight.T) - x
     return torch.matmul(resid, weight)
@@ -208,19 +189,6 @@ def fista_loop(
 
     return z
 
-def __collate_params(z0: torch.Tensor,
-                     x: torch.Tensor,
-                     lr: str| float,
-                     weight: torch.Tensor,
-                     alpha: float | torch.Tensor,
-                     tol: float) -> Tuple[torch.Tensor, torch.Tensor, float]:
-    if lr == 'auto':
-        L = _lipschitz_constant(weight)
-        lr = 1 / L
-    tol = z0.numel() * tol
-    alpha = as_scalar(alpha, x)
-    lr = as_scalar(lr, x)
-    return lr, alpha, tol
 
 def ista(x, z0, weight, alpha=0.01, lr: str | float = 'auto',
          maxiter: int = 50,
@@ -240,15 +208,17 @@ def ista(x, z0, weight, alpha=0.01, lr: str | float = 'auto',
     Returns:
 
     """
-    lr, alpha, tol = __collate_params(z0, x, lr, weight, alpha, tol)
+    # lr, alpha, tol = collate_params(z0, x, lr, weight, alpha, tol)
     z0 = z0.contiguous()
     x = x.contiguous()
     weight = weight.contiguous()
 
     return ista_loop(z0, x, weight, alpha, lr, tol, maxiter, positive_code)
 
 
-def fista(x, z0, weight, alpha=0.01, lr: str | float = 'auto',
+def fista(x: torch.Tensor, z0: torch.Tensor,
+          weight: torch.Tensor,
+          alpha: torch.Tensor, lr: str | float = 'auto',
          maxiter: int = 50,
          tol: float = 1e-5, positive_code: bool = False):
     """Fast ISTA solver
@@ -266,7 +236,7 @@ def fista(x, z0, weight, alpha=0.01, lr: str | float = 'auto',
     Returns:
 
    """
-    lr, alpha, tol = __collate_params(z0, x, lr, weight, alpha, tol)
+    # lr, alpha, tol = collate_params(z0, x, lr, weight, alpha, tol)
     z0 = z0.contiguous()
    x = x.contiguous()
     weight = weight.contiguous()
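Note on the change above: for f(z) = ||Wz - x||^2 the gradient is Lipschitz with constant L = lambda_max(W^T W) = ||W||_2^2, so the deleted eigvalsh-based estimate and the spectral-norm estimate that replaces it (moved to sparse_util.py, below) compute the same quantity. A minimal sketch, assuming only torch and not taken from the repo, that checks the equivalence:

    import torch

    # sketch, not repo code: the top eigenvalue of W^T W equals the squared
    # spectral norm (largest singular value) of W, so both estimates of the
    # Lipschitz constant of grad f(z) = ||Wz - x||^2 coincide
    w = torch.randn(64, 16, dtype=torch.float64)
    L_eig = torch.linalg.eigvalsh(w.T @ w)[-1]   # eigenvalues come in ascending order
    L_norm = torch.linalg.norm(w, ord=2) ** 2    # squared spectral norm
    assert torch.allclose(L_eig, L_norm)

With lr = 1 / L this satisfies the standard ISTA step-size condition lr <= 1/L, which is what makes the 'auto' learning rate safe.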

torch_staintools/functional/optimization/sparse_util.py

Lines changed: 45 additions & 2 deletions
@@ -1,8 +1,8 @@
-from typing import Optional, Literal, get_args
+from typing import Optional, Literal, get_args, Tuple
 import torch
 from torch.nn import functional as F
 from torch_staintools.constants import CONST
-
+from torch_staintools.functional.eps import get_eps
 
 METHOD_ISTA = Literal['ista']
 METHOD_FISTA = Literal['fista']
@@ -94,3 +94,46 @@ def validate_code(algorithm: METHOD_SPARSE,
     z0 = initialize_code(x, weight, mode=init, rng=rng)
     assert z0.shape == (n_samples, n_components)
     return z0
+
+
+def lipschitz_constant(w: torch.Tensor):
+    """find the Lipschitz constant to compute the learning rate in ISTA
+
+    Args:
+        w: weights w in f(z) = ||Wz - x||^2
+
+    Returns:
+
+    """
+    # L = torch.linalg.norm(W, ord=2) ** 2
+    # W has nan
+    # WtW = torch.matmul(w.t(), w)
+    # WtW += torch.eye(WtW.size(0)).to(w.device) * get_eps(WtW)
+    # L = torch.linalg.eigvalsh(WtW)[-1].squeeze()
+    # L_is_finite = torch.isfinite(L).all()
+    # L = torch.where(L_is_finite, L, torch.linalg.norm(w, ord=2) ** 2)
+    # L = L.abs()
+    L = torch.linalg.norm(w, ord=2) ** 2
+    return L + torch.finfo(L.dtype).eps
+
+
+def collate_params(z0: torch.Tensor,
+                   x: torch.Tensor,
+                   lr: str| float,
+                   weight: torch.Tensor,
+                   alpha: float | torch.Tensor,
+                   tol: float) -> Tuple[torch.Tensor, torch.Tensor, float]:
+    if lr == 'auto':
+        L = lipschitz_constant(weight)
+        lr = 1 / L
+    tol = z0.numel() * tol
+    alpha = as_scalar(alpha, x)
+    lr = as_scalar(lr, x)
+    return lr, alpha, tol
+
+
+def as_scalar(v: float | torch.Tensor, like: torch.Tensor) -> torch.Tensor:
+    if isinstance(v, torch.Tensor):
+        # will except on non-scalar
+        return v.to(device=like.device, dtype=like.dtype).reshape(())
+    return torch.tensor(v, device=like.device, dtype=like.dtype)
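A hypothetical usage sketch of the relocated helpers; the function names and module path come from this commit, while the shapes and values are purely illustrative:

    import torch
    from torch_staintools.functional.optimization.sparse_util import collate_params

    # illustrative shapes: 4 samples, 32 features, 8 dictionary atoms
    weight = torch.randn(32, 8)
    x = torch.randn(4, 32)
    z0 = torch.zeros(4, 8)

    # lr='auto' resolves to 1 / (||weight||_2^2 + eps); tol is scaled by z0.numel()
    lr, alpha, tol = collate_params(z0, x, 'auto', weight, alpha=0.01, tol=1e-5)
    assert lr.shape == () and alpha.shape == ()   # as_scalar yields 0-dim tensors

Grouping collate_params with lipschitz_constant and as_scalar keeps the solver module free of parameter plumbing; note that as of this commit the collate_params call in ista and fista is commented out, so an lr of 'auto' would need to be resolved by the caller until it is re-enabled.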

torch_staintools/functional/utility/implementation.py

Lines changed: 0 additions & 7 deletions
@@ -92,10 +92,3 @@ def nanstd(data: torch.Tensor, dim: Optional[int | tuple] = None,
     sum_dev2 = ((data - mean) ** 2).nansum(dim=dim, keepdim=True)
     # sqrt and normalize by corrected degrees of freedom
     return torch.sqrt(sum_dev2 / (non_nan_count - correction))
-
-
-def as_scalar(v: float | torch.Tensor, like: torch.Tensor) -> torch.Tensor:
-    if isinstance(v, torch.Tensor):
-        # will except on non-scalar
-        return v.to(device=like.device, dtype=like.dtype).reshape(())
-    return torch.tensor(v, device=like.device, dtype=like.dtype)
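For reference, the surviving context above is the tail of nanstd, which computes a NaN-aware standard deviation. The sketch below re-implements the visible lines for illustration and is not the repo function:

    import torch

    # NaN-aware std: count non-NaN entries, take a NaN-skipping mean, then
    # normalize the squared deviations by the corrected degrees of freedom
    data = torch.tensor([[1.0, 2.0, float('nan')],
                         [3.0, 4.0, 5.0],
                         [5.0, float('nan'), 7.0]])
    dim, correction = 0, 1
    non_nan_count = (~torch.isnan(data)).sum(dim=dim, keepdim=True)
    mean = data.nansum(dim=dim, keepdim=True) / non_nan_count
    sum_dev2 = ((data - mean) ** 2).nansum(dim=dim, keepdim=True)   # nansum skips NaNs
    std = torch.sqrt(sum_dev2 / (non_nan_count - correction))       # per-column std

The deletion here moves as_scalar out of the utility module and into sparse_util.py, matching the import change in solver.py above.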
