55from .sparse_util import METHOD_SPARSE , validate_code , initialize_dict , collate_params
66import torch
77import torch .nn .functional as F
8- from typing import Optional , cast
8+ from typing import Optional , cast , Tuple
99from ..eps import get_eps
1010from torch_staintools .constants import CONFIG
1111
12+
13+ @torch .compile
1214def update_dict_cd (dictionary : torch .Tensor , x : torch .Tensor , code : torch .Tensor ,
1315 positive : bool = True ,
14- dead_thresh = 1e-7 , rng : torch .Generator = None ):
16+ dead_thresh = 1e-7 ,
17+ rng : torch .Generator = None ) -> Tuple [torch .Tensor , torch .Tensor ]:
1518 """Update the dictionary (stain matrix) using Block Coordinate Descent algorithm.
1619
1720 Can satisfy the positive constraint of dictionaries if specified.
18-
21+ Side effects: code is updated inplace.
1922
2023 Args:
21- dictionary: Tensor of shape (n_features, n_components) Value of the dictionary at the previous iteration.
24+ dictionary: Tensor of shape (n_features, n_components).
25+ Value of the dictionary at the previous iteration.
2226 x: Tensor of shape (n_samples, n_components)
2327 Sparse coding of the data against which to optimize the dictionary.
2428 code: Tensor of shape (n_samples, n_components)
@@ -28,7 +32,7 @@ def update_dict_cd(dictionary: torch.Tensor, x: torch.Tensor, code: torch.Tensor
2832 rng: torch.Generator for initialization of dictionary and code.
2933
3034 Returns:
31-
35+ torch.Tensor, torch.Tensor, corresponding to the weight and the updated code.
3236 """
3337 n_components = dictionary .size (1 )
3438
@@ -38,10 +42,18 @@ def update_dict_cd(dictionary: torch.Tensor, x: torch.Tensor, code: torch.Tensor
3842 for k in range (n_components ):
3943 d_k = dictionary [:, k ]
4044 z_k = code [:, k ]
41- update_term = torch .outer (z_k , d_k )
42- # Update k'th atom
43- R += update_term
44- new_d_k = torch .mv (R .T , z_k )
45+
46+ # vanilla. new_d = (R + z*d^T)^T * z
47+ # new_d = R^T*z + (d*z^T)*z = R^T*z + d*(z^T*z)
48+ # update_term = torch.outer(z_k, d_k)
49+ # R += update_term
50+ # new_d_k = torch.mv(R.T, z_k) # target
51+
52+ # R^T*z
53+ rtz = torch .mv (R .T , z_k )
54+ ztz = torch .dot (z_k , z_k )
55+ new_d_k = rtz + (d_k * ztz )
56+
4557 if positive :
4658 new_d_k = torch .clamp (new_d_k , min = 0 )
4759
@@ -60,14 +72,22 @@ def update_dict_cd(dictionary: torch.Tensor, x: torch.Tensor, code: torch.Tensor
6072 d_k_standard = new_d_k / (d_norm + get_eps (dictionary ))
6173 d_k_final = torch .where (is_dead , d_k_random , d_k_standard )
6274 z_k_final = torch .where (is_dead , torch .zeros_like (z_k ), z_k )
75+
76+ # fused
77+ # must be done before updating the dict
78+ r_delta = torch .outer (z_k , d_k ) - torch .outer (z_k_final , d_k_final )
79+
6380 dictionary [:, k ] = d_k_final
6481 code [:, k ] = z_k_final
65- R -= torch .outer (z_k_final , d_k_final )
6682
67- return dictionary
83+ #R -= torch.outer(z_k_final, d_k_final)
84+ R += r_delta
85+
86+ return dictionary , code
6887
6988
70- def update_dict_ridge (x , code , lambd = 1e-4 ):
89+ @torch .compile
90+ def update_dict_ridge (x : torch .Tensor , code : torch .Tensor , lambd : float ) -> Tuple [torch .Tensor , torch .Tensor ]:
7191 """Update an (unconstrained) dictionary with ridge regression
7292
7393 This is equivalent to a Newton step with the (L2-regularized) squared
@@ -80,17 +100,17 @@ def update_dict_ridge(x, code, lambd=1e-4):
80100 lambd: weight decay parameter
81101
82102 Returns:
83-
103+ torch.Tensor, torch.Tensor, corresponding to the weight and the unmodified code.
84104 """
85105
86106 rhs = torch .mm (code .T , x )
87107 M = torch .mm (code .T , code )
88108 M .diagonal ().add_ (lambd * x .size (0 ))
89109 L = torch .linalg .cholesky (M )
90- V = torch .cholesky_solve (rhs , L ).T
110+ weight = torch .cholesky_solve (rhs , L ).T
91111
92- V = F .normalize (V , dim = 0 , eps = 1e-12 )
93- return V
112+ weight = F .normalize (weight , dim = 0 , eps = 1e-12 )
113+ return weight , code
94114
95115
96116def sparse_code (x : torch .Tensor ,
@@ -118,7 +138,6 @@ def sparse_code(x: torch.Tensor,
118138 raise ValueError ("invalid algorithm parameter '{}'." .format (algorithm ))
119139 return z
120140
121-
122141def dict_learning_loop (x : torch .Tensor ,
123142 z0 : torch .Tensor ,
124143 weight : torch .Tensor ,
@@ -135,7 +154,6 @@ def dict_learning_loop(x: torch.Tensor,
135154
136155 for _ in range (steps ):
137156 # infer sparse coefficients and compute loss
138-
139157 z = sparse_code (x , weight , alpha , z0 , algorithm = cast (METHOD_SPARSE , algorithm ),
140158 lr = lr , maxiter = maxiter , tol = tol ,
141159 positive_code = CONFIG .DICT_POSITIVE_CODE ).contiguous ()
@@ -145,36 +163,37 @@ def dict_learning_loop(x: torch.Tensor,
145163 if CONFIG .DICT_PERSIST_CODE :
146164 z0 = z
147165 else :
148- z0 = validate_code (algorithm , init , None , weight , x , rng )
166+ z0 = validate_code (algorithm , init , z0 = None , x = x , weight = weight , rng = rng )
149167
150168 # update dictionary
151169 if CONFIG .DICT_POSITIVE_DICTIONARY :
152- weight = update_dict_cd (weight , x , z , positive = True , rng = rng )
170+ weight , z = update_dict_cd (weight , x , z , positive = True , rng = rng )
153171 else :
154- weight = update_dict_ridge (x , z , lambd = lambd_ridge )
172+ weight , z = update_dict_ridge (x , z , lambd = lambd_ridge )
155173
156174 return weight
157175
158176
159177def dict_learning (x : torch .Tensor ,
160178 n_components : int ,
161179 algorithm : METHOD_SPARSE ,
162- * , alpha : float = 1e-1 ,
163- lambd_ridge : float = 1e-2 ,
164- steps : int = 60 ,
165- rng : torch .Generator = None ,
166- init : Optional [str ] = 'zero' ,
167- lr : Optional [float ] = None ,
168- maxiter : int = 50 ,
169- tol : float = 1e-5 , ):
180+ * , alpha : float ,
181+ lambd_ridge : float ,
182+ steps : int ,
183+ rng : Optional [ torch .Generator ] ,
184+ init : Optional [str ],
185+ lr : Optional [float ],
186+ maxiter : int ,
187+ tol : float , ):
170188 n_samples , n_features = x .shape
189+ # pixel x c
171190 x = x .contiguous ()
172-
191+ # c x stain
173192 weight = initialize_dict (n_features = n_features , n_components = n_components , device = x .device ,
174193 rng = rng , positive_dict = CONFIG .DICT_POSITIVE_DICTIONARY )
175194
176195 # initialize
177- z0 = validate_code (algorithm , init , None , weight , x , rng )
196+ z0 = validate_code (algorithm , init , z0 = None , x = x , weight = weight , rng = rng )
178197 assert z0 is not None
179198 lr , alpha , tol = collate_params (z0 , x , lr , weight , alpha , tol )
180199 return dict_learning_loop (x , z0 , weight , alpha , algorithm , lambd_ridge = lambd_ridge ,
0 commit comments