77from ..eps import get_eps
88import torch .nn .functional as F
99
10- from .. utility import as_scalar
10+ from .sparse_util import as_scalar
1111
1212
1313def coord_descent (x : torch .Tensor , z0 : torch .Tensor , weight : torch .Tensor ,
@@ -57,25 +57,6 @@ def cd_update(z, b):
5757 z = F .softshrink (b , alpha )
5858 return z
5959
60- def _lipschitz_constant (w : torch .Tensor ):
61- """find the Lipscitz constant to compute the learning rate in ISTA
62-
63- Args:
64- w: weights w in f(z) = ||Wz - x||^2
65-
66- Returns:
67-
68- """
69- # L = torch.linalg.norm(W, ord=2) ** 2
70- # W has nan
71- WtW = torch .matmul (w .t (), w )
72- WtW += torch .eye (WtW .size (0 )).to (w .device ) * get_eps (WtW )
73- L = torch .linalg .eigvalsh (WtW )[- 1 ].squeeze ()
74- L_is_finite = torch .isfinite (L ).all ()
75- L = torch .where (L_is_finite , L , torch .linalg .norm (w , ord = 2 ) ** 2 )
76- L = L .abs ()
77- return L + torch .finfo (L .dtype ).eps
78-
7960def rss_grad (z_k : torch .Tensor , x : torch .Tensor , weight : torch .Tensor ):
8061 resid = torch .matmul (z_k , weight .T ) - x
8162 return torch .matmul (resid , weight )
@@ -208,19 +189,6 @@ def fista_loop(
208189
209190 return z
210191
211- def __collate_params (z0 : torch .Tensor ,
212- x : torch .Tensor ,
213- lr : str | float ,
214- weight : torch .Tensor ,
215- alpha : float | torch .Tensor ,
216- tol : float ) -> Tuple [torch .Tensor , torch .Tensor , float ]:
217- if lr == 'auto' :
218- L = _lipschitz_constant (weight )
219- lr = 1 / L
220- tol = z0 .numel () * tol
221- alpha = as_scalar (alpha , x )
222- lr = as_scalar (lr , x )
223- return lr , alpha , tol
224192
225193def ista (x , z0 , weight , alpha = 0.01 , lr : str | float = 'auto' ,
226194 maxiter : int = 50 ,
@@ -240,15 +208,17 @@ def ista(x, z0, weight, alpha=0.01, lr: str | float = 'auto',
240208 Returns:
241209
242210 """
243- lr , alpha , tol = __collate_params (z0 , x , lr , weight , alpha , tol )
211+ # lr, alpha, tol = collate_params (z0, x, lr, weight, alpha, tol)
244212 z0 = z0 .contiguous ()
245213 x = x .contiguous ()
246214 weight = weight .contiguous ()
247215
248216 return ista_loop (z0 , x , weight , alpha , lr , tol , maxiter , positive_code )
249217
250218
251- def fista (x , z0 , weight , alpha = 0.01 , lr : str | float = 'auto' ,
219+ def fista (x : torch .Tensor , z0 : torch .Tensor ,
220+ weight : torch .Tensor ,
221+ alpha : torch .Tensor , lr : str | float = 'auto' ,
252222 maxiter : int = 50 ,
253223 tol : float = 1e-5 , positive_code : bool = False ):
254224 """Fast ISTA solver
@@ -266,7 +236,7 @@ def fista(x, z0, weight, alpha=0.01, lr: str | float = 'auto',
266236 Returns:
267237
268238 """
269- lr , alpha , tol = __collate_params (z0 , x , lr , weight , alpha , tol )
239+ # lr, alpha, tol = collate_params (z0, x, lr, weight, alpha, tol)
270240 z0 = z0 .contiguous ()
271241 x = x .contiguous ()
272242 weight = weight .contiguous ()
0 commit comments