
Commit 32572d0

Refactoring: parameter preprocessing is now performed inside the optimization routines rather than outside the function calls; cleaning up the CD implementation for torch.compile.
1 parent 435499d commit 32572d0
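To illustrate the direction of the refactor: tensor preprocessing (parameter collation, contiguity, tolerance scaling) now happens inside the solver via the new _preprocess_input helper, so coord_descent can be called directly on raw tensors. A minimal usage sketch, with made-up shapes and values that are not taken from this commit:

# Sketch only: shapes/values are assumptions for illustration.
import torch
from torch_staintools.functional.optimization.solver import coord_descent

x = torch.rand(1000, 3)      # e.g. flattened OD pixels, [N, D]
w = torch.rand(3, 2)         # dictionary / transposed stain matrix, [D, K]
z0 = torch.zeros(1000, 2)    # initial sparse code, [N, K]
alpha = torch.tensor(0.1)    # sparsity weight

# collate_params / .contiguous() / tol scaling now run inside the solver.
# Note: cd_loop is wrapped in lazy_compile, so the first call will likely trigger torch.compile.
z = coord_descent(x, z0, w, alpha, maxiter=200, tol=1e-4, positive_code=True)
print(z.shape)               # torch.Size([1000, 2])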


3 files changed: +131 -73 lines changed

torch_staintools/functional/concentration/implementation.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def get_concentrations_single(od_flatten: torch.Tensor,
         computed concentration: num_stains x num_pixel_in_tissue_mask
     """
     z0 = initialize_code(od_flatten, stain_matrix.T, 'zero', rng=rng)
-    lr, regularizer, tol = collate_params(z0, od_flatten, lr, stain_matrix.T, regularizer, tol)
+    lr, regularizer, tol = collate_params(od_flatten, lr, stain_matrix.T, regularizer, tol)
     match algorithm:
         case 'cd':
             return coord_descent(od_flatten, z0, stain_matrix.T,
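
The only caller-side change: collate_params no longer receives z0, because the z0.numel() * tol scaling it used to apply is now done inside the solver's _preprocess_input (see the solver.py and sparse_util.py hunks below). A tiny sketch of that scaling, with an assumed code shape:

# Assumed shape for illustration; z0 stands for the initial code tensor.
import torch

z0 = torch.zeros(1000, 2)        # [num_pixels_in_tissue_mask, num_stains]
tol = 1e-4
scaled_tol = z0.numel() * tol    # the aggregate tolerance _preprocess_input now derives itself
print(scaled_tol)                # ~0.2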

torch_staintools/functional/optimization/solver.py

Lines changed: 127 additions & 69 deletions
@@ -1,60 +1,126 @@
 """
 code directly adapted from https://github.com/rfeinman/pytorch-lasso
 """
+from typing import Optional
+
 import torch
-import torch.nn.functional as F

 from torch_staintools.functional.compile import lazy_compile
+from torch_staintools.functional.optimization.sparse_util import collate_params
+
+def _preprocess_input(z0: torch.Tensor,
+                      x: torch.Tensor,
+                      lr: Optional[float | torch.Tensor],
+                      weight: torch.Tensor,
+                      alpha: float | torch.Tensor,
+                      tol: float):
+    lr, alpha, tol = collate_params(x, lr, weight, alpha, tol)
+    z0 = z0.contiguous()
+    x = x.contiguous()
+    weight = weight.contiguous()
+    tol = z0.numel() * tol
+    return z0, x, weight, lr, alpha, tol
+
+
+def _grad_precompute(x: torch.Tensor, weight: torch.Tensor):
+    # return Hessian and bias
+    return torch.mm(weight.T, weight), torch.mm(x, weight)
+
+def _softshrink(x: torch.Tensor, lambd: torch.Tensor) -> torch.Tensor:
+    lambd = lambd.clamp_min(0)
+    return x.sign() * (x.abs() - lambd).clamp_min(0)
+
+def softshrink(x: torch.Tensor, lambd: torch.Tensor, positive: bool) -> torch.Tensor:
+    if positive:
+        return (x - lambd).clamp_min(0)
+    return _softshrink(x, lambd)
+
+def cd_step(
+        z: torch.Tensor,
+        b: torch.Tensor,
+        s: torch.Tensor,
+        alpha: torch.Tensor,
+        positive_code: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    z = torch.nan_to_num(z, nan=0.0, posinf=0.0, neginf=0.0)
+    b = torch.nan_to_num(b, nan=0.0, posinf=0.0, neginf=0.0)
+    s = torch.nan_to_num(s, nan=0.0, posinf=0.0, neginf=0.0)
+    alpha = torch.nan_to_num(alpha, nan=0.0, posinf=0.0, neginf=0.0)
+
+    z_proposal = softshrink(b, alpha, positive_code)
+
+    z_diff = z_proposal - z
+
+    k = z_diff.abs().argmax(dim=1)
+    kk = k.unsqueeze(1)
+
+    z_diff_selected = z_diff.gather(1, kk)
+
+    one_hot = torch.nn.functional.one_hot(
+        k, num_classes=z.size(1)
+    ).to(dtype=z.dtype)
+    s_col_vec = torch.mm(one_hot, s.T)
+
+    b_next = b + s_col_vec * z_diff_selected
+
+    z_next_selected = z_proposal.gather(1, kk)
+    z_next = z.scatter(1, kk, z_next_selected)
+
+    finite_row = (
+        torch.isfinite(z).all(dim=1) &
+        torch.isfinite(b).all(dim=1) &
+        torch.isfinite(z_next).all(dim=1) &
+        torch.isfinite(b_next).all(dim=1)
+    ).unsqueeze(1)
+    z_next = torch.where(finite_row, z_next, z)
+    b_next = torch.where(finite_row, b_next, b)
+
+    return z_next, b_next
+
+
+@lazy_compile
+def cd_loop(
+        z: torch.Tensor,
+        b: torch.Tensor,
+        s: torch.Tensor,
+        alpha: torch.Tensor,
+        tol: float,
+        maxiter: int,
+        positive_code: bool,
+) -> torch.Tensor:

+    is_converged = torch.zeros_like(z[:, 0], dtype=torch.bool)
+    for _ in range(maxiter):
+        z_next, b_next = cd_step(z, b, s, alpha, positive_code)
+
+        update = (z_next - z).abs().sum(dim=1)  # [N]
+        just_finished = update <= tol

-def coord_descent(x: torch.Tensor, z0: torch.Tensor, weight: torch.Tensor,
+        # freeze if converged. can't early break here.
+        cvf_2d = is_converged.unsqueeze(1)
+        z = torch.where(cvf_2d, z, z_next)
+        b = torch.where(cvf_2d, b, b_next)
+
+        is_converged = is_converged | just_finished
+
+    return softshrink(b, alpha, positive=positive_code)
+
+
+def coord_descent(x: torch.Tensor,
+                  z0: torch.Tensor,
+                  weight: torch.Tensor,
                   alpha: torch.Tensor,
                   maxiter: int, tol: float,
                   positive_code: bool):
     """ modified coord_descent"""
-    if isinstance(alpha, torch.Tensor):
-        assert alpha.numel() == 1
-        alpha = alpha.item()
-    input_dim, code_dim = weight.shape  # [D,K]
-    batch_size, input_dim1 = x.shape  # [N,D]
-    assert input_dim1 == input_dim
-    tol = tol * code_dim
-    if z0 is None:
-        z = x.new_zeros(batch_size, code_dim)  # [N,K]
-    else:
-        assert z0.shape == (batch_size, code_dim)
-        z = z0
-
-    b = torch.mm(x, weight)  # [N,K]
-
-    # precompute S = I - W^T @ W
-    S = - torch.mm(weight.T, weight)  # [K,K]
-    S.diagonal().add_(1.)
-
-
-    def cd_update(z, b):
-        if positive_code:
-            z_next = (b - alpha).clamp_min(0)
-        else:
-            z_next = F.softshrink(b, alpha)  # [N,K]
-        z_diff = z_next - z  # [N,K]
-        k = z_diff.abs().argmax(1)  # [N]
-        kk = k.unsqueeze(1)  # [N,1]
-        b = b + S[:, k].T * z_diff.gather(1, kk)  # [N,K] += [N,K] * [N,1]
-        z = z.scatter(1, kk, z_next.gather(1, kk))
-        return z, b
-
-    active = torch.arange(batch_size, device=weight.device)
-    for i in range(maxiter):
-        if len(active) == 0:
-            break
-        z_old = z[active]
-        z_new, b[active] = cd_update(z_old, b[active])
-        update = (z_new - z_old).abs().sum(1)
-        z[active] = z_new
-        active = active[update > tol]
-
-    z = F.softshrink(b, alpha)
+    # lr set to one to avoid L computation. Lr is not used in CD
+    z0, x, weight, lr, alpha, tol = _preprocess_input(z0, x, 1, weight, alpha, tol)
+
+    hessian, b = _grad_precompute(x, weight)
+    code_dim = weight.size(1)
+    # S = I - H
+    s = torch.eye(code_dim, device=x.device, dtype=x.dtype) - hessian
+    z = cd_loop(z0, b, s, alpha, tol=tol, maxiter=maxiter, positive_code=positive_code)
     return z

 def rss_grad(z_k: torch.Tensor, x: torch.Tensor, weight: torch.Tensor):
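
The main torch.compile-oriented change in cd_step is replacing the data-dependent fancy indexing of the old cd_update (S[:, k].T) with a one-hot matmul. Both select the same rows of S; a quick numerical check with assumed sizes:

# Assumed sizes; s plays the role of S = I - W^T W, k the per-row coordinate index.
import torch

N, K = 5, 3
s = torch.rand(K, K)
k = torch.randint(0, K, (N,))

old = s[:, k].T                                                      # advanced indexing, [N, K]
one_hot = torch.nn.functional.one_hot(k, num_classes=K).to(s.dtype)
new = torch.mm(one_hot, s.T)                                         # dense matmul, [N, K]
print(torch.allclose(old, new))                                      # True
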
@@ -65,15 +131,6 @@ def rss_grad(z_k: torch.Tensor, x: torch.Tensor, weight: torch.Tensor):
 def rss_grad_fast(z_k: torch.Tensor, hessian: torch.Tensor, b: torch.Tensor):
     return torch.mm(z_k, hessian) - b

-def _grad_precompute(x: torch.Tensor, weight: torch.Tensor):
-    # return Hessian and bias
-    return torch.mm(weight.T, weight), torch.mm(x, weight)
-
-def softshrink(x: torch.Tensor, lambd: torch.Tensor) -> torch.Tensor:
-    lambd = lambd.clamp_min(0)
-    return x.sign() * (x.abs() - lambd).clamp_min(0)
-
-
 def ista_step(
         z: torch.Tensor,
         hessian: torch.Tensor,
@@ -105,13 +162,10 @@ def ista_step(

     # guard lr
     lr_safe = torch.nan_to_num(lr, nan=0.0, posinf=0.0, neginf=0.0)
-    z_proposal = z - lr * g_safe
-    threshold = alpha * lr
-    if positive:
-        z_next = (z_proposal - threshold).clamp_min(0)
-    else:
-        # z_next = F.softshrink(z_prev - lr * rss_grad(z_prev, x, weight), alpha * lr)
-        z_next = softshrink(z_k_safe - lr_safe * g_safe, alpha * lr_safe)
+    z_proposal = z - lr_safe * g_safe
+    threshold = alpha * lr_safe
+
+    z_next = softshrink(z_proposal, threshold, positive)
     finite_mask = torch.isfinite(z) & torch.isfinite(g) & torch.isfinite(lr)
     return torch.where(finite_mask, z_next, z)
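
Besides switching the update from the raw lr to the guarded lr_safe, this hunk folds ista_step's positive/non-positive branches into the shared softshrink(x, lambd, positive) helper. For a scalar threshold the hand-rolled soft-thresholding formula agrees with torch's built-in, as this small check with arbitrary values shows:

# Scalar-threshold sanity check; the values here are arbitrary.
import torch
import torch.nn.functional as F

x = torch.randn(4, 3)
lambd = 0.2
ours = x.sign() * (x.abs() - lambd).clamp_min(0)       # the _softshrink formula from this commit
print(torch.allclose(ours, F.softshrink(x, lambd)))    # True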

@@ -226,10 +280,12 @@ def ista(x: torch.Tensor, z0: torch.Tensor,
     Returns:

     """
-    # lr, alpha, tol = collate_params(z0, x, lr, weight, alpha, tol)
-    z0 = z0.contiguous()
-    x = x.contiguous()
-    weight = weight.contiguous()
+    # lr, alpha, tol = collate_params(x, lr, weight, alpha, tol)
+    # z0 = z0.contiguous()
+    # x = x.contiguous()
+    # weight = weight.contiguous()
+    # tol = z0.numel() * tol
+    z0, x, weight, lr, alpha, tol = _preprocess_input(z0, x, lr, weight, alpha, tol)
     hessian, b = _grad_precompute(x, weight)
     # hessian = hessian.contiguous()
     # b = b.contiguous()
@@ -256,10 +312,12 @@ def fista(x: torch.Tensor, z0: torch.Tensor,
     Returns:

     """
-    # lr, alpha, tol = collate_params(z0, x, lr, weight, alpha, tol)
-    z0 = z0.contiguous()
-    x = x.contiguous()
-    weight = weight.contiguous()
+    # lr, alpha, tol = collate_params(x, lr, weight, alpha, tol)
+    # z0 = z0.contiguous()
+    # x = x.contiguous()
+    # weight = weight.contiguous()
+    # tol = z0.numel() * tol
+    z0, x, weight, lr, alpha, tol = _preprocess_input(z0, x, lr, weight, alpha, tol)
     hessian, b = _grad_precompute(x, weight)
     # hessian = hessian.contiguous()
     # b = b.contiguous()
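
ista and fista now delegate their inline preprocessing to the shared _preprocess_input, which also forces contiguous memory layouts before the matmuls. A standalone reminder of what .contiguous() does (example tensor assumed):

# A transposed view is not contiguous; .contiguous() copies it into dense row-major memory.
import torch

a = torch.rand(4, 3).T
print(a.is_contiguous())                 # False
print(a.contiguous().is_contiguous())    # True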

torch_staintools/functional/optimization/sparse_util.py

Lines changed: 3 additions & 3 deletions
@@ -114,8 +114,7 @@ def lipschitz_constant(w: torch.Tensor):
     return L + torch.finfo(L.dtype).eps


-def collate_params(z0: torch.Tensor,
-                   x: torch.Tensor,
+def collate_params(x: torch.Tensor,
                    lr: Optional[float | torch.Tensor],
                    weight: torch.Tensor,
                    alpha: float | torch.Tensor,
@@ -126,7 +125,8 @@ def collate_params(z0: torch.Tensor,

     # if tol is None:
     #     tol = PARAM.OPTIM_DEFAULT_TOL
-    tol = z0.numel() * tol
+    # handle it inside optimization.
+    # tol = z0.numel() * tol

     # if alpha is None:
     #     alpha = PARAM.OPTIM_DEFAULT_SPARSE_ISTA_LAMBDA
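
With z0 dropped and the tolerance scaling deferred to the solver, collate_params is now only responsible for collating lr, alpha and tol against the data and dictionary. A sketch of the narrowed call; argument values are assumptions (lr=1 mirrors what coord_descent passes to skip the, presumably Lipschitz-constant based, L computation):

# Sketch only: values are assumptions, not taken from the repository.
import torch
from torch_staintools.functional.optimization.sparse_util import collate_params

x = torch.rand(1000, 3)   # data, [N, D]
w = torch.rand(3, 2)      # dictionary, [D, K]
lr, alpha, tol = collate_params(x, 1, w, 0.1, 1e-4)   # z0 is no longer an argument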
