11"""
22code directly adapted from https://github.com/rfeinman/pytorch-lasso
33"""
4+ from typing import Optional
5+
46import torch
5- import torch .nn .functional as F
67
78from torch_staintools .functional .compile import lazy_compile
9+ from torch_staintools .functional .optimization .sparse_util import collate_params
10+
+ def _preprocess_input(z0: torch.Tensor,
+                       x: torch.Tensor,
+                       lr: Optional[float | torch.Tensor],
+                       weight: torch.Tensor,
+                       alpha: float | torch.Tensor,
+                       tol: float):
+     lr, alpha, tol = collate_params(x, lr, weight, alpha, tol)
+     z0 = z0.contiguous()
+     x = x.contiguous()
+     weight = weight.contiguous()
+     tol = z0.numel() * tol
+     return z0, x, weight, lr, alpha, tol
+
+
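+ # The smooth part of the objective is 0.5 * ||x - z @ weight.T||^2; its gradient in z
+ # is z @ (weight.T @ weight) - x @ weight, so the Hessian (weight.T @ weight) and the
+ # linear "bias" term (x @ weight) can be precomputed once per solve.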
+ def _grad_precompute(x: torch.Tensor, weight: torch.Tensor):
+     # return Hessian and bias
+     return torch.mm(weight.T, weight), torch.mm(x, weight)
+
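+ # Soft-thresholding is the proximal operator of the L1 penalty: entries with
+ # magnitude below lambd are zeroed, the rest are shrunk toward zero by lambd.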
+ def _softshrink(x: torch.Tensor, lambd: torch.Tensor) -> torch.Tensor:
+     lambd = lambd.clamp_min(0)
+     return x.sign() * (x.abs() - lambd).clamp_min(0)
+
+ def softshrink(x: torch.Tensor, lambd: torch.Tensor, positive: bool) -> torch.Tensor:
+     if positive:
+         return (x - lambd).clamp_min(0)
+     return _softshrink(x, lambd)
+
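+ # One greedy coordinate-descent step: propose the shrinkage update for every
+ # coordinate, pick the coordinate with the largest proposed change in each row,
+ # commit only that coordinate, and correct b with the matching column of S
+ # scaled by that coordinate's change.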
+ def cd_step(
+         z: torch.Tensor,
+         b: torch.Tensor,
+         s: torch.Tensor,
+         alpha: torch.Tensor,
+         positive_code: bool,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+     z = torch.nan_to_num(z, nan=0.0, posinf=0.0, neginf=0.0)
+     b = torch.nan_to_num(b, nan=0.0, posinf=0.0, neginf=0.0)
+     s = torch.nan_to_num(s, nan=0.0, posinf=0.0, neginf=0.0)
+     alpha = torch.nan_to_num(alpha, nan=0.0, posinf=0.0, neginf=0.0)
+
+     z_proposal = softshrink(b, alpha, positive_code)
+
+     z_diff = z_proposal - z
+
+     k = z_diff.abs().argmax(dim=1)
+     kk = k.unsqueeze(1)
+
+     z_diff_selected = z_diff.gather(1, kk)
+
+     one_hot = torch.nn.functional.one_hot(
+         k, num_classes=z.size(1)
+     ).to(dtype=z.dtype)
+     s_col_vec = torch.mm(one_hot, s.T)
+
+     b_next = b + s_col_vec * z_diff_selected
+
+     z_next_selected = z_proposal.gather(1, kk)
+     z_next = z.scatter(1, kk, z_next_selected)
+
+     finite_row = (
+         torch.isfinite(z).all(dim=1) &
+         torch.isfinite(b).all(dim=1) &
+         torch.isfinite(z_next).all(dim=1) &
+         torch.isfinite(b_next).all(dim=1)
+     ).unsqueeze(1)
+     z_next = torch.where(finite_row, z_next, z)
+     b_next = torch.where(finite_row, b_next, b)
+
+     return z_next, b_next
+
+
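+ # Batched CD loop: each iteration applies cd_step to all rows; once a row's total
+ # update drops to tol or below it is frozen for the remaining iterations instead
+ # of breaking out of the (compiled) loop.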
+ @lazy_compile
+ def cd_loop(
+         z: torch.Tensor,
+         b: torch.Tensor,
+         s: torch.Tensor,
+         alpha: torch.Tensor,
+         tol: float,
+         maxiter: int,
+         positive_code: bool,
+ ) -> torch.Tensor:

+     is_converged = torch.zeros_like(z[:, 0], dtype=torch.bool)
+     for _ in range(maxiter):
+         z_next, b_next = cd_step(z, b, s, alpha, positive_code)
+
+         update = (z_next - z).abs().sum(dim=1)  # [N]
+         just_finished = update <= tol

- def coord_descent(x: torch.Tensor, z0: torch.Tensor, weight: torch.Tensor,
+         # freeze if converged. can't early break here.
+         cvf_2d = is_converged.unsqueeze(1)
+         z = torch.where(cvf_2d, z, z_next)
+         b = torch.where(cvf_2d, b, b_next)
+
+         is_converged = is_converged | just_finished
+
+     return softshrink(b, alpha, positive=positive_code)
+
+
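+ # coord_descent estimates sparse codes z for the lasso-style objective
+ #     0.5 * ||x - z @ weight.T||^2 + alpha * ||z||_1
+ # with x: [N, D], weight: [D, K], z: [N, K]. Illustrative call (values are
+ # examples, not defaults from this PR):
+ #     z = coord_descent(x, z0, weight, alpha=torch.tensor(0.01),
+ #                       maxiter=100, tol=1e-6, positive_code=True)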
+ def coord_descent(x: torch.Tensor,
+                   z0: torch.Tensor,
+                   weight: torch.Tensor,
                  alpha: torch.Tensor,
                  maxiter: int, tol: float,
                  positive_code: bool):
    """ modified coord_descent"""
-     if isinstance(alpha, torch.Tensor):
-         assert alpha.numel() == 1
-         alpha = alpha.item()
-     input_dim, code_dim = weight.shape  # [D,K]
-     batch_size, input_dim1 = x.shape  # [N,D]
-     assert input_dim1 == input_dim
-     tol = tol * code_dim
-     if z0 is None:
-         z = x.new_zeros(batch_size, code_dim)  # [N,K]
-     else:
-         assert z0.shape == (batch_size, code_dim)
-         z = z0
-
-     b = torch.mm(x, weight)  # [N,K]
-
-     # precompute S = I - W^T @ W
-     S = -torch.mm(weight.T, weight)  # [K,K]
-     S.diagonal().add_(1.)
-
-
-     def cd_update(z, b):
-         if positive_code:
-             z_next = (b - alpha).clamp_min(0)
-         else:
-             z_next = F.softshrink(b, alpha)  # [N,K]
-         z_diff = z_next - z  # [N,K]
-         k = z_diff.abs().argmax(1)  # [N]
-         kk = k.unsqueeze(1)  # [N,1]
-         b = b + S[:, k].T * z_diff.gather(1, kk)  # [N,K] += [N,K] * [N,1]
-         z = z.scatter(1, kk, z_next.gather(1, kk))
-         return z, b
-
-     active = torch.arange(batch_size, device=weight.device)
-     for i in range(maxiter):
-         if len(active) == 0:
-             break
-         z_old = z[active]
-         z_new, b[active] = cd_update(z_old, b[active])
-         update = (z_new - z_old).abs().sum(1)
-         z[active] = z_new
-         active = active[update > tol]
-
-     z = F.softshrink(b, alpha)
+     # lr is set to 1 to avoid computing L (the Lipschitz constant); lr is not used in CD.
+     z0, x, weight, lr, alpha, tol = _preprocess_input(z0, x, 1, weight, alpha, tol)
+
+     hessian, b = _grad_precompute(x, weight)
+     code_dim = weight.size(1)
+     # S = I - H
+     s = torch.eye(code_dim, device=x.device, dtype=x.dtype) - hessian
+     z = cd_loop(z0, b, s, alpha, tol=tol, maxiter=maxiter, positive_code=positive_code)
    return z

def rss_grad(z_k: torch.Tensor, x: torch.Tensor, weight: torch.Tensor):
@@ -65,15 +131,6 @@ def rss_grad(z_k: torch.Tensor, x: torch.Tensor, weight: torch.Tensor):
def rss_grad_fast(z_k: torch.Tensor, hessian: torch.Tensor, b: torch.Tensor):
    return torch.mm(z_k, hessian) - b

- def _grad_precompute(x: torch.Tensor, weight: torch.Tensor):
-     # return Hessian and bias
-     return torch.mm(weight.T, weight), torch.mm(x, weight)
-
- def softshrink(x: torch.Tensor, lambd: torch.Tensor) -> torch.Tensor:
-     lambd = lambd.clamp_min(0)
-     return x.sign() * (x.abs() - lambd).clamp_min(0)
-
-
def ista_step(
        z: torch.Tensor,
        hessian: torch.Tensor,
@@ -105,13 +162,10 @@ def ista_step(

    # guard lr
    lr_safe = torch.nan_to_num(lr, nan=0.0, posinf=0.0, neginf=0.0)
-     z_proposal = z - lr * g_safe
-     threshold = alpha * lr
-     if positive:
-         z_next = (z_proposal - threshold).clamp_min(0)
-     else:
-         # z_next = F.softshrink(z_prev - lr * rss_grad(z_prev, x, weight), alpha * lr)
-         z_next = softshrink(z_k_safe - lr_safe * g_safe, alpha * lr_safe)
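+     # Proximal-gradient (ISTA) update: gradient step on the smooth term, then
+     # soft-threshold with the step-size-scaled penalty alpha * lr.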
+     z_proposal = z - lr_safe * g_safe
+     threshold = alpha * lr_safe
+
+     z_next = softshrink(z_proposal, threshold, positive)
    finite_mask = torch.isfinite(z) & torch.isfinite(g) & torch.isfinite(lr)
    return torch.where(finite_mask, z_next, z)

@@ -226,10 +280,12 @@ def ista(x: torch.Tensor, z0: torch.Tensor,
    Returns:

    """
-     # lr, alpha, tol = collate_params(z0, x, lr, weight, alpha, tol)
-     z0 = z0.contiguous()
-     x = x.contiguous()
-     weight = weight.contiguous()
+     # lr, alpha, tol = collate_params(x, lr, weight, alpha, tol)
+     # z0 = z0.contiguous()
+     # x = x.contiguous()
+     # weight = weight.contiguous()
+     # tol = z0.numel() * tol
+     z0, x, weight, lr, alpha, tol = _preprocess_input(z0, x, lr, weight, alpha, tol)
    hessian, b = _grad_precompute(x, weight)
    # hessian = hessian.contiguous()
    # b = b.contiguous()
@@ -256,10 +312,12 @@ def fista(x: torch.Tensor, z0: torch.Tensor,
    Returns:

    """
-     # lr, alpha, tol = collate_params(z0, x, lr, weight, alpha, tol)
-     z0 = z0.contiguous()
-     x = x.contiguous()
-     weight = weight.contiguous()
+     # lr, alpha, tol = collate_params(x, lr, weight, alpha, tol)
+     # z0 = z0.contiguous()
+     # x = x.contiguous()
+     # weight = weight.contiguous()
+     # tol = z0.numel() * tol
+     z0, x, weight, lr, alpha, tol = _preprocess_input(z0, x, lr, weight, alpha, tol)
    hessian, b = _grad_precompute(x, weight)
    # hessian = hessian.contiguous()
    # b = b.contiguous()