Commit 88a890c

committed
polishing dict learning. kernelizing rss_grad. optimizing bcd.
1 parent fb33d83 commit 88a890c

8 files changed: +102 −64 lines changed


torch_staintools/augmentor/factory.py

Lines changed: 8 additions & 4 deletions
@@ -58,12 +58,16 @@ def build(method: AUG_TYPE_SUPPORTED,
     aug_method: Callable
     match method:
         case 'macenko' | 'vahadane':
-            return Augmentor.build(method=method, concentration_method=concentration_method,
-                                   rng=rng, target_stain_idx=target_stain_idx,
+            return Augmentor.build(method=method,
+                                   concentration_method=concentration_method,
+                                   rng=rng,
+                                   target_stain_idx=target_stain_idx,
                                    sigma_alpha=sigma_alpha,
-                                   sigma_beta=sigma_beta, luminosity_threshold=luminosity_threshold,
+                                   sigma_beta=sigma_beta,
+                                   luminosity_threshold=luminosity_threshold,
                                    use_cache=use_cache,
                                    regularizer=regularizer,
-                                   cache_size_limit=cache_size_limit, device=device, load_path=load_path)
+                                   cache_size_limit=cache_size_limit,
+                                   device=device, load_path=load_path)
         case _:
             raise NotImplementedError(f"{method} not implemented.")

torch_staintools/constants/config.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ class _Config:
     DICT_POSITIVE_CODE: bool = True


-CONFIG = _Config()
+CONFIG: _Config = _Config()

torch_staintools/constants/param.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ class _Param:
    OPTIM_SPARSE_DEFAULT_MAX_ITER: float = 50


-PARAM = _Param()
+PARAM: _Param = _Param()

torch_staintools/functional/concentration/implementation.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
 @dataclass(frozen=True)
 class ConcentCfg:
     algorithm: METHOD_FACTORIZE = 'fista'
-    regularizer: float = CONFIG.OPTIM_DEFAULT_SPARSE_LAMBDA
+    regularizer: float = PARAM.OPTIM_DEFAULT_SPARSE_LAMBDA
     rng: Optional[torch.Generator] = None
     maxiter: int = PARAM.OPTIM_SPARSE_DEFAULT_MAX_ITER
     lr: Optional[float] = None
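
The one-line fix above matters beyond style: `regularizer` is a dataclass field default, so the lookup runs when the module is imported. Assuming `OPTIM_DEFAULT_SPARSE_LAMBDA` is defined only on `_Param` (the two constants files above show the split between `CONFIG` and `PARAM`), reading it through `CONFIG` would raise immediately. A minimal sketch with a hypothetical value:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class _Config:
        DICT_POSITIVE_CODE: bool = True

    @dataclass(frozen=True)
    class _Param:
        OPTIM_DEFAULT_SPARSE_LAMBDA: float = 0.1  # hypothetical value for illustration

    CONFIG: _Config = _Config()
    PARAM: _Param = _Param()

    print(PARAM.OPTIM_DEFAULT_SPARSE_LAMBDA)   # 0.1
    print(CONFIG.OPTIM_DEFAULT_SPARSE_LAMBDA)  # AttributeError: _Config has no such attribute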

torch_staintools/functional/optimization/dict_learning.py

Lines changed: 50 additions & 31 deletions
@@ -5,20 +5,24 @@
 from .sparse_util import METHOD_SPARSE, validate_code, initialize_dict, collate_params
 import torch
 import torch.nn.functional as F
-from typing import Optional, cast
+from typing import Optional, cast, Tuple
 from ..eps import get_eps
 from torch_staintools.constants import CONFIG
 
+
+@torch.compile
 def update_dict_cd(dictionary: torch.Tensor, x: torch.Tensor, code: torch.Tensor,
                    positive: bool = True,
-                   dead_thresh=1e-7, rng: torch.Generator = None):
+                   dead_thresh=1e-7,
+                   rng: torch.Generator = None) -> Tuple[torch.Tensor, torch.Tensor]:
     """Update the dictionary (stain matrix) using Block Coordinate Descent algorithm.
 
     Can satisfy the positive constraint of dictionaries if specified.
-
+    Side effects: code is updated inplace.
 
     Args:
-        dictionary: Tensor of shape (n_features, n_components) Value of the dictionary at the previous iteration.
+        dictionary: Tensor of shape (n_features, n_components).
+            Value of the dictionary at the previous iteration.
         x: Tensor of shape (n_samples, n_components)
             Sparse coding of the data against which to optimize the dictionary.
         code: Tensor of shape (n_samples, n_components)
@@ -28,7 +32,7 @@ def update_dict_cd(dictionary: torch.Tensor, x: torch.Tensor, code: torch.Tensor
         rng: torch.Generator for initialization of dictionary and code.
 
     Returns:
-
+        torch.Tensor, torch.Tensor, corresponding to the weight and the updated code.
     """
     n_components = dictionary.size(1)
 
@@ -38,10 +42,18 @@
     for k in range(n_components):
         d_k = dictionary[:, k]
         z_k = code[:, k]
-        update_term = torch.outer(z_k, d_k)
-        # Update k'th atom
-        R += update_term
-        new_d_k = torch.mv(R.T, z_k)
+
+        # vanilla. new_d = (R + z*d^T)^T * z
+        # new_d = R^T*z + (d*z^T)*z = R^T*z + d*(z^T*z)
+        # update_term = torch.outer(z_k, d_k)
+        # R += update_term
+        # new_d_k = torch.mv(R.T, z_k)  # target
+
+        # R^T*z
+        rtz = torch.mv(R.T, z_k)
+        ztz = torch.dot(z_k, z_k)
+        new_d_k = rtz + (d_k * ztz)
+
         if positive:
             new_d_k = torch.clamp(new_d_k, min=0)
 
@@ -60,14 +72,22 @@ def update_dict_cd(dictionary: torch.Tensor, x: torch.Tensor, code: torch.Tensor
         d_k_standard = new_d_k / (d_norm + get_eps(dictionary))
         d_k_final = torch.where(is_dead, d_k_random, d_k_standard)
         z_k_final = torch.where(is_dead, torch.zeros_like(z_k), z_k)
+
+        # fused
+        # must be done before updating the dict
+        r_delta = torch.outer(z_k, d_k) - torch.outer(z_k_final, d_k_final)
+
         dictionary[:, k] = d_k_final
         code[:, k] = z_k_final
-        R -= torch.outer(z_k_final, d_k_final)
 
-    return dictionary
+        # R -= torch.outer(z_k_final, d_k_final)
+        R += r_delta
+
+    return dictionary, code
 
 
-def update_dict_ridge(x, code, lambd=1e-4):
+@torch.compile
+def update_dict_ridge(x: torch.Tensor, code: torch.Tensor, lambd: float) -> Tuple[torch.Tensor, torch.Tensor]:
     """Update an (unconstrained) dictionary with ridge regression
 
     This is equivalent to a Newton step with the (L2-regularized) squared
@@ -80,17 +100,17 @@ def update_dict_ridge(x, code, lambd=1e-4):
         lambd: weight decay parameter
 
     Returns:
-
+        torch.Tensor, torch.Tensor, corresponding to the weight and the unmodified code.
     """
 
     rhs = torch.mm(code.T, x)
     M = torch.mm(code.T, code)
     M.diagonal().add_(lambd * x.size(0))
     L = torch.linalg.cholesky(M)
-    V = torch.cholesky_solve(rhs, L).T
+    weight = torch.cholesky_solve(rhs, L).T
 
-    V = F.normalize(V, dim=0, eps=1e-12)
-    return V
+    weight = F.normalize(weight, dim=0, eps=1e-12)
+    return weight, code
 
 
 def sparse_code(x: torch.Tensor,
@@ -118,7 +138,6 @@ def sparse_code(x: torch.Tensor,
         raise ValueError("invalid algorithm parameter '{}'.".format(algorithm))
     return z
 
-
 def dict_learning_loop(x: torch.Tensor,
                        z0: torch.Tensor,
                        weight: torch.Tensor,
@@ -135,7 +154,6 @@ def dict_learning_loop(x: torch.Tensor,
 
     for _ in range(steps):
         # infer sparse coefficients and compute loss
-
         z = sparse_code(x, weight, alpha, z0, algorithm=cast(METHOD_SPARSE, algorithm),
                         lr=lr, maxiter=maxiter, tol=tol,
                         positive_code=CONFIG.DICT_POSITIVE_CODE).contiguous()
@@ -145,36 +163,37 @@
         if CONFIG.DICT_PERSIST_CODE:
             z0 = z
         else:
-            z0 = validate_code(algorithm, init, None, weight, x, rng)
+            z0 = validate_code(algorithm, init, z0=None, x=x, weight=weight, rng=rng)
 
         # update dictionary
         if CONFIG.DICT_POSITIVE_DICTIONARY:
-            weight = update_dict_cd(weight, x, z, positive=True, rng=rng)
+            weight, z = update_dict_cd(weight, x, z, positive=True, rng=rng)
         else:
-            weight = update_dict_ridge(x, z, lambd=lambd_ridge)
+            weight, z = update_dict_ridge(x, z, lambd=lambd_ridge)
 
     return weight
 
 
 def dict_learning(x: torch.Tensor,
                   n_components: int,
                   algorithm: METHOD_SPARSE,
-                  *, alpha: float = 1e-1,
-                  lambd_ridge: float = 1e-2,
-                  steps: int = 60,
-                  rng: torch.Generator = None,
-                  init: Optional[str] = 'zero',
-                  lr: Optional[float] = None,
-                  maxiter: int = 50,
-                  tol: float = 1e-5, ):
+                  *, alpha: float,
+                  lambd_ridge: float,
+                  steps: int,
+                  rng: Optional[torch.Generator],
+                  init: Optional[str],
+                  lr: Optional[float],
+                  maxiter: int,
+                  tol: float, ):
     n_samples, n_features = x.shape
+    # pixel x c
     x = x.contiguous()
-
+    # c x stain
     weight = initialize_dict(n_features=n_features, n_components=n_components, device=x.device,
                              rng=rng, positive_dict=CONFIG.DICT_POSITIVE_DICTIONARY)
 
     # initialize
-    z0 = validate_code(algorithm, init, None, weight, x, rng)
+    z0 = validate_code(algorithm, init, z0=None, x=x, weight=weight, rng=rng)
     assert z0 is not None
     lr, alpha, tol = collate_params(z0, x, lr, weight, alpha, tol)
     return dict_learning_loop(x, z0, weight, alpha, algorithm, lambd_ridge=lambd_ridge,
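
Two things are going on in the `update_dict_cd` rewrite above: the atom update is kernelized via the identity (R + z dᵀ)ᵀ z = Rᵀ z + d (zᵀ z), so no n × c outer product or extra pass over `R` is needed to read off the new atom, and the two rank-1 patches of `R` are fused into a single net delta that must be formed from the pre-update columns, since `d_k`/`z_k` are views. Both updaters also now return a `(weight, code)` pair so `dict_learning_loop` can unpack either branch uniformly. A standalone check of the identity and the view-aliasing point (sizes are illustrative, not the library API):

    import torch

    n, c = 512, 3                                   # pixels x channels (illustrative)
    R = torch.randn(n, c)                           # residual matrix
    d = torch.randn(c)                              # one dictionary atom
    z = torch.randn(n)                              # matching code column

    # vanilla: materialize R + z d^T, then read the new atom off it
    vanilla = torch.mv((R + torch.outer(z, d)).T, z)
    # kernelized: R^T z + d (z^T z), no n x c outer product
    fast = torch.mv(R.T, z) + d * torch.dot(z, z)
    assert torch.allclose(vanilla, fast, atol=1e-5)

    # why r_delta is computed before the writes: columns are views, not copies
    M = torch.eye(3)
    col = M[:, 0]                                   # view into M
    M[:, 0] = torch.full((3,), 5.0)
    print(col)                                      # tensor([5., 5., 5.]) -- the view follows the write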

torch_staintools/functional/optimization/solver.py

Lines changed: 35 additions & 19 deletions
@@ -56,9 +56,16 @@ def cd_update(z, b):
     return z
 
 def rss_grad(z_k: torch.Tensor, x: torch.Tensor, weight: torch.Tensor):
+    # kernelize it?
     resid = torch.matmul(z_k, weight.T) - x
     return torch.matmul(resid, weight)
 
+def rss_grad_fast(z_k: torch.Tensor, hessian: torch.Tensor, b: torch.Tensor):
+    return torch.mm(z_k, hessian) - b
+
+def _grad_precompute(x: torch.Tensor, weight: torch.Tensor):
+    # return Hessian and bias
+    return torch.mm(weight.T, weight), torch.mm(x, weight)
 
 def softshrink(x: torch.Tensor, lambd: torch.Tensor) -> torch.Tensor:
     lambd = lambd.clamp_min(0)
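
This answers the `# kernelize it?` note: for the residual sum of squares ½‖z Wᵀ − x‖², the gradient (z Wᵀ − x) W expands to z (Wᵀ W) − x W, so the k × k Gram matrix (`hessian`) and the n × k bias (`b`) can be formed once and reused every iteration instead of touching the n × c data twice per step. A standalone equivalence check (illustrative sizes):

    import torch

    n, c, k = 1024, 3, 2                        # pixels, channels, stains (illustrative)
    x = torch.randn(n, c)
    weight = torch.randn(c, k)
    z = torch.randn(n, k)

    g_direct = (z @ weight.T - x) @ weight      # rss_grad: touches the n x c data each call
    hessian, b = weight.T @ weight, x @ weight  # _grad_precompute: computed once
    g_fast = z @ hessian - b                    # rss_grad_fast
    assert torch.allclose(g_direct, g_fast, atol=1e-5)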
@@ -67,8 +74,8 @@ def softshrink(x: torch.Tensor, lambd: torch.Tensor) -> torch.Tensor:
 
 def ista_step(
         z: torch.Tensor,
-        x: torch.Tensor,
-        weight: torch.Tensor,
+        hessian: torch.Tensor,
+        b: torch.Tensor,
         alpha: torch.Tensor,
         lr: torch.Tensor,
         positive: bool,
@@ -77,8 +84,10 @@ def ista_step(
 
     Args:
         z: code. num_pixels x num_stain
-        x: OD space. num_pixels x num_channel
-        weight: init from stain matrix --> num_channel x num_stain
+        # x: OD space. num_pixels x num_channel
+        # weight: init from stain matrix --> num_channel x num_stain
+        hessian: precomputed wtw
+        b: precomputed xw
         alpha: tensor form of the ista penalizer
         lr: tensor form of step size
         positive: if force z to be positive
@@ -88,7 +97,8 @@ def ista_step(
 
 
     z_k_safe = torch.nan_to_num(z, nan=0.0, posinf=0.0, neginf=0.0)
-    g = rss_grad(z_k_safe, x, weight)  # same shape as z
+    # g = rss_grad(z_k_safe, x, weight)  # same shape as z
+    g = rss_grad_fast(z_k_safe, hessian, b)
     g_safe = torch.nan_to_num(g, nan=0.0, posinf=0.0, neginf=0.0)
 
     # guard lr
@@ -108,15 +118,15 @@ def fista_step(
         z: torch.Tensor,
         y: torch.Tensor,
         t: torch.Tensor,
-        x: torch.Tensor,
-        weight: torch.Tensor,
+        hessian: torch.Tensor,
+        b: torch.Tensor,
         alpha: torch.Tensor,
         lr: torch.Tensor,
         positive_code: bool,
         tol: float
 ):
 
-    z_next = ista_step(y, x, weight, alpha, lr, positive_code)
+    z_next = ista_step(y, hessian, b, alpha, lr, positive_code)
     delta = z_next - z
     diff = delta.abs().sum()
     just_finished = diff <= tol
@@ -128,12 +138,12 @@
 
 
 @torch.compile
-def ista_loop(z: torch.Tensor, x: torch.Tensor, weight: torch.Tensor,
+def ista_loop(z: torch.Tensor, hessian: torch.Tensor, b: torch.Tensor,
               alpha: torch.Tensor, lr: torch.Tensor,
               tol: float, maxiter: int, positive_code: bool):
     is_converged = torch.tensor(False, device=z.device, dtype=torch.bool)
     for _ in range(maxiter):
-        z_next = ista_step(z, x, weight, alpha, lr, positive_code)
+        z_next = ista_step(z, hessian, b, alpha, lr, positive_code)
         # check convergence
         diff = (z - z_next).abs().sum()
         just_finished = diff <= tol
@@ -146,8 +156,8 @@
 @torch.compile
 def fista_loop(
         z: torch.Tensor,
-        x: torch.Tensor,
-        weight: torch.Tensor,
+        hessian: torch.Tensor,
+        b: torch.Tensor,
         alpha: torch.Tensor,
         lr: torch.Tensor,
         tol: float,
@@ -158,8 +168,10 @@ def fista_loop(
 
     Args:
         z: Initial guess
-        x: Data input (OD space)
-        weight: Dictionary matrix
+        # x: Data input (OD space)
+        # weight: Dictionary matrix
+        hessian: precomputed wtw
+        b: precomputed xw
         alpha: Regularization strength
         lr: Learning rate
         maxiter: Maximum iterations
@@ -176,7 +188,7 @@ def fista_loop(
     for i in range(maxiter):
 
         z_next, y_next, t_next, just_finished = fista_step(z, y, t,
-                                                           x, weight,
+                                                           hessian, b,
                                                            alpha, lr,
                                                            positive_code, tol)
 
@@ -212,8 +224,10 @@ def ista(x: torch.Tensor, z0: torch.Tensor,
     z0 = z0.contiguous()
     x = x.contiguous()
     weight = weight.contiguous()
-
-    return ista_loop(z0, x, weight, alpha, lr, tol, maxiter, positive_code)
+    hessian, b = _grad_precompute(x, weight)
+    # hessian = hessian.contiguous()
+    # b = b.contiguous()
+    return ista_loop(z0, hessian, b, alpha, lr, tol, maxiter, positive_code)
 
 
 def fista(x: torch.Tensor, z0: torch.Tensor,
@@ -240,5 +254,7 @@ def fista(x: torch.Tensor, z0: torch.Tensor,
     z0 = z0.contiguous()
     x = x.contiguous()
     weight = weight.contiguous()
-
-    return fista_loop(z0, x, weight, alpha, lr, tol, maxiter, positive_code)
+    hessian, b = _grad_precompute(x, weight)
+    # hessian = hessian.contiguous()
+    # b = b.contiguous()
+    return fista_loop(z0, hessian, b, alpha, lr, tol, maxiter, positive_code)
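
With the precompute hoisted into `ista()`/`fista()`, the compiled loops never revisit the pixel matrix: each step multiplies by the small k × k Gram matrix instead of running two data-sized products. A self-contained sketch of the same pattern — not the library API, and the step-size/penalty choices are illustrative:

    import torch
    import torch.nn.functional as F

    def ista_sketch(x: torch.Tensor, weight: torch.Tensor,
                    alpha: float = 0.1, maxiter: int = 50) -> torch.Tensor:
        # solve min_z 0.5 * ||z @ weight.T - x||^2 + alpha * ||z||_1
        hessian = weight.T @ weight                    # (k, k), computed once
        b = x @ weight                                 # (n, k), computed once
        lr = 1.0 / torch.linalg.eigvalsh(hessian)[-1]  # 1/L, L = Lipschitz const. of the gradient
        z = torch.zeros_like(b)
        for _ in range(maxiter):
            g = z @ hessian - b                        # kernelized gradient, no n x c product here
            z = F.softshrink(z - lr * g, lambd=float(lr * alpha))
        return z

    x = torch.rand(1024, 3)      # OD pixels (illustrative)
    weight = torch.rand(3, 2)    # stain matrix (illustrative)
    print(ista_sketch(x, weight).shape)   # torch.Size([1024, 2])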

torch_staintools/functional/optimization/sparse_util.py

Lines changed: 4 additions & 4 deletions
@@ -1,4 +1,4 @@
-from typing import Optional, Literal, get_args, Tuple
+from typing import Optional, Literal, get_args, Tuple, cast
 import torch
 from torch.nn import functional as F
 from torch_staintools.constants import PARAM
 
@@ -81,14 +81,14 @@
 
 
 def validate_code(algorithm: METHOD_SPARSE,
-                  init: str, z0: Optional[torch.Tensor],
-                  x: torch.Tensor, weight, rng):
+                  init: Optional[MODE_INIT], z0: Optional[torch.Tensor],
+                  x: torch.Tensor, weight: torch.Tensor, rng):
     # initialize code variable
     n_samples = x.size(0)
     n_components = weight.size(1)
     init = _init_defaults.get(algorithm, 'zero') if init is None else init
     if z0 is None:
-        z0 = initialize_code(x, weight, mode=init, rng=rng)
+        z0 = initialize_code(x, weight, mode=cast(MODE_INIT, init), rng=rng)
     assert z0.shape == (n_samples, n_components)
     return z0
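
One more thing the keyword-argument call sites in dict_learning.py above buy: `validate_code` declares `z0, x, weight` in that order, while the old positional calls passed `None, weight, x`, which appears to have silently swapped data and dictionary (both are plain tensors, so nothing crashes at the call). A minimal illustration of the hazard, using a hypothetical helper rather than the library's:

    import torch

    def shape_check(x: torch.Tensor, weight: torch.Tensor) -> tuple:
        # expects x: (n_samples, n_features), weight: (n_features, n_components)
        return x.size(0), weight.size(1)

    x = torch.rand(1024, 3)      # data
    weight = torch.rand(3, 2)    # dictionary

    print(shape_check(weight, x))           # (3, 3) -- positional swap goes unnoticed
    print(shape_check(x=x, weight=weight))  # (1024, 2) -- keywords make the intent explicit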
