-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtps.py
More file actions
246 lines (203 loc) · 8.29 KB
/
tps.py
File metadata and controls
246 lines (203 loc) · 8.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
class ThinPlateSplineTransform:
    """Thin plate spline (TPS) warp fitted from control-point pairs.

    ``estimate`` fits a spline ``f`` anchored at the destination control points
    so that ``f(dst[k]) == src[k]`` (up to solver precision), evaluates it on
    every pixel of a reference grid and caches the result in ``params`` as a
    ``(2, height, width)`` array. ``__call__`` reads values from that table.

    Parameters
    ----------
    affine_only : bool, optional
        If True, keep only the affine part of the fitted spline and skip the
        chunked evaluation of the non-linear bending term.
    chunk_size : int or None, optional
        Grid pixels processed per chunk when evaluating the bending term. If
        None, ``estimate`` derives it from its ``available_memory_gb`` argument.
    dtype : numpy dtype, optional
        Floating type of the per-chunk kernel matrices and of the bending array.
    """

    def __init__(self, affine_only=False, chunk_size=None, dtype=np.float32):
        self._estimated = False
        self.params = None  # (2, height, width) lookup table set by estimate()
        self.size = None  # size argument of the last successful estimate()
        self.affine_only = affine_only
        self.chunk_size = chunk_size
        self.dtype = dtype

    def __call__(self, coords):
        """Look up the fitted mapping at integer pixel positions.

        The spline is solved so that destination control points map onto their
        source counterparts, so this maps reference (destination) pixel
        positions to source coordinates.

        Parameters
        ----------
        coords : (N, 2) array_like
            ``(x, y)`` pixel positions. They are cast with ``astype(int)``, so
            fractional parts are truncated toward zero.

        Returns
        -------
        (N, 2) ndarray
            Mapped ``(x, y)`` coordinates.

        Raises
        ------
        ValueError
            If the transform has not been estimated yet.
        IndexError
            If any position lies outside the estimated grid. Negative positions
            are rejected explicitly instead of silently wrapping around.
        """
        if not self._estimated:
            raise ValueError("Transformation not estimated.")
        coords = np.asarray(coords).astype(int)
        cols, rows = coords[:, 0], coords[:, 1]
        height, width = self.params.shape[1:]
        outside = (cols < 0) | (cols >= width) | (rows < 0) | (rows >= height)
        if np.any(outside):
            raise IndexError("Coordinates fall outside the estimated grid.")
        # NOTE(review): estimate() samples the grid at 1-based positions, so
        # params[:, r, c] holds the mapping of the point (c + 1, r + 1).
        # Confirm that callers account for this offset.
        return np.moveaxis(self.params, 0, -1)[rows, cols]

    def _estimate_chunk_size(self, n_pixels, n_control_points, available_memory_gb=2.0):
        """Estimate how many grid pixels to process per bending chunk.

        Parameters
        ----------
        n_pixels : int
            Total number of pixels to process.
        n_control_points : int
            Number of control points.
        available_memory_gb : float
            Memory budget in GB for the per-chunk arrays.

        Returns
        -------
        chunk_size : int
            Pixels per chunk, clamped to at least 1000 and at most ``n_pixels``
            (so it is 1000 whenever ``n_pixels`` is below 1000).
        """
        bytes_per_element = np.dtype(self.dtype).itemsize
        # Each pixel in a chunk needs one row of n_control_points values in the
        # distance and kernel matrices; the factor 4 is a safety margin that
        # also covers temporaries.
        memory_per_pixel = n_control_points * bytes_per_element * 4
        available_bytes = available_memory_gb * 1024**3
        chunk_size = int(available_bytes / memory_per_pixel)
        return max(1000, min(chunk_size, n_pixels))

    def _check_valid_points(self, src, dst):
        """Validate a pair of control-point arrays.

        Parameters
        ----------
        src, dst : (N, 2) array_like
            Source and destination control points.

        Returns
        -------
        bool
            Always True; problems are reported by raising.

        Raises
        ------
        ValueError
            On mismatched shapes, a shape other than (N, 2), fewer than 3
            points, non-finite values, or duplicate points in either set.
        """
        src = np.asarray(src)
        dst = np.asarray(dst)
        if src.shape != dst.shape:
            raise ValueError("Source and destination points must have the same shape.")
        if src.shape == (2,):
            raise ValueError(
                "Incorrect shape for control points; expected (N, 2), received (2,)."
            )
        # Checking ndim first turns other 1-D/3-D inputs into a ValueError
        # instead of an IndexError on shape[1].
        if src.ndim != 2 or src.shape[1] != 2:
            raise ValueError("Control points must be 2D coordinates.")
        if src.shape[0] < 3:
            raise ValueError("At least 3 control points are required.")
        if not (np.all(np.isfinite(src)) and np.all(np.isfinite(dst))):
            raise ValueError("Control points must be finite.")
        # Duplicate points would make the system matrix singular (dst) or are
        # rejected for symmetry (src).
        src_duplicates = np.unique(src, axis=0).shape[0] != src.shape[0]
        dst_duplicates = np.unique(dst, axis=0).shape[0] != dst.shape[0]
        if src_duplicates or dst_duplicates:
            raise ValueError("Control points contain duplicates.")
        return True

    def estimate(self, src, dst, size, available_memory_gb=2.0):
        """Estimate optimal spline mappings between source and destination points.

        The spline f is anchored at ``dst`` and solved so that
        ``f(dst[k]) == src[k]``; it is then evaluated on a ``size`` grid whose
        sample positions are x = 1..width and y = 1..height.

        Parameters
        ----------
        src : (N, 2) array_like
            Control points at source coordinates.
        dst : (N, 2) array_like
            Control points at destination coordinates.
        size : tuple
            Size of the reference image (height, width).
        available_memory_gb : float, optional
            Memory budget used to pick the chunk size when ``chunk_size`` is
            None. Ignored when ``affine_only`` is True.

        Returns
        -------
        success: bool
            True indicates that the estimation was successful.

        Raises
        ------
        ValueError
            On invalid control points (see ``_check_valid_points``) or a
            ``chunk_size`` smaller than 1.
        numpy.linalg.LinAlgError
            Raised by ``np.linalg.solve`` if the TPS system is singular.

        Notes
        -----
        The number N of source and destination points must match. When the
        bending term is computed, progress is printed to stdout and the chunk
        loop is wrapped in ``tqdm``.
        """
        self._check_valid_points(src, dst)
        src = np.asarray(src, dtype=float)
        dst = np.asarray(dst, dtype=float)

        # The spline is anchored at the destination points (control points).
        cps = dst
        n = cps.shape[0]

        L = self._TPS_makeL(cps)
        # Right-hand side: source coordinates, padded with zeros for the three
        # affine side conditions of the TPS system.
        Y = np.vstack([src, np.zeros((3, 2))])
        params = np.linalg.solve(L, Y)
        wi = params[:n, :]  # (n, 2) kernel weights
        a1 = params[n, :]  # (2,) constant term
        ax = params[n + 1, :]  # (2,) coefficients of x
        ay = params[n + 2, :]  # (2,) coefficients of y

        # At grid point (x, y): f = a1 + ax * x + ay * y + sum_i wi * U_i.
        height, width = size[0], size[1]
        n_pixels = width * height
        x = np.linspace(1, width, width)
        y = np.linspace(1, height, height)

        # Affine part via broadcasting of the 1-D axes (no meshgrid needed).
        affine = ax[:, None, None] * x[None, None, :] + ay[:, None, None] * y[None, :, None]
        affine += a1[:, None, None]

        if self.affine_only:
            self.params = affine
        else:
            if self.chunk_size is None:
                chunk_size = self._estimate_chunk_size(n_pixels, n, available_memory_gb)
            else:
                chunk_size = int(self.chunk_size)
                if chunk_size < 1:
                    raise ValueError("chunk_size must be a positive integer.")
            n_chunks = int(np.ceil(n_pixels / chunk_size))
            print(
                f"Processing {n_pixels:,} pixels in {n_chunks} chunk(s) of ~{chunk_size:,} pixels each"
            )
            print("Computing bending transformation...")
            bend = np.zeros((2, height, width), dtype=self.dtype)
            # Row-major flat view of the fresh array: flat index i corresponds
            # to row i // width and column i % width.
            bend_flat = bend.reshape(2, -1)
            for chunk_idx in tqdm(range(n_chunks)):
                start = chunk_idx * chunk_size
                stop = min(start + chunk_size, n_pixels)
                # Build this chunk's grid coordinates on the fly instead of
                # materialising coordinates for the whole image.
                flat = np.arange(start, stop)
                chunk_pixels = np.column_stack([x[flat % width], y[flat // width]])
                rsq = cdist(chunk_pixels, cps, "sqeuclidean").astype(self.dtype, copy=False)
                rsq[rsq == 0] = 1  # log(1) == 0, so U = 0 where r = 0
                kernel = np.log(rsq)
                kernel *= rsq  # U = r^2 * log(r^2), same kernel as _TPS_makeL
                bend_flat[:, start:stop] = (kernel @ wi).T
            self.params = affine + bend
        self.size = size
        self._estimated = True
        return True

    def _TPS_makeL(self, cp):
        """Build the (K+3, K+3) system matrix of the thin plate spline.

        Layout::

            [ U    P ]      U_ij = r_ij^2 * log(r_ij^2)  (0 where r_ij == 0)
            [ P^T  0 ]      P_i  = [1, x_i, y_i]

        Parameters
        ----------
        cp : (K, 2) array_like
            Control points.

        Returns
        -------
        L : (K+3, K+3) ndarray
            The float64 L matrix for thin plate spline calculation.
        """
        cp = np.asarray(cp, dtype=float)
        K = cp.shape[0]
        L = np.zeros((K + 3, K + 3))
        # P block and its transpose.
        L[:K, K] = 1
        L[:K, K + 1 : K + 3] = cp
        L[K, :K] = 1
        L[K + 1 :, :K] = cp.T
        # Kernel block, kept in float64 (the dtype of L) so the linear solve
        # does not inherit single-precision rounding; the matrix is only K x K.
        rsq = cdist(cp, cp, "sqeuclidean")
        rsq[rsq == 0] = 1  # log(1) == 0, so coincident points (diagonal) get 0
        L[:K, :K] = rsq * np.log(rsq)
        return L