Skip to content

Commit b2c091d

Browse files
authored
Stable starlets (#249)
* remove proximal thresholds and add proper edge treatment in bspline_convolve * added scales to starelet_reconstruction; docs call for constraint/prior
1 parent e320124 commit b2c091d

File tree

2 files changed

+60
-41
lines changed

2 files changed

+60
-41
lines changed

src/scarlet2/morphology.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -211,51 +211,53 @@ def f(self, r2):
211211
return jnp.exp(-bn * (r2 ** (0.5 / n) - 1))
212212

213213

214-
prox_plus = lambda x: jnp.maximum(x, 0) # noqa: E731
215-
prox_soft = lambda x, thresh: jnp.sign(x) * prox_plus(jnp.abs(x) - thresh) # noqa: E731
216-
prox_soft_plus = lambda x, thresh: prox_plus(prox_soft(x, thresh)) # noqa: E731
217-
218-
219214
class StarletMorphology(Morphology):
220215
"""Morphology in the starlet basis
221216
217+
Notes
218+
-----
219+
The starlet basis is overcomplete, which means it can exactly represent the same image in multiple ways.
220+
If used without constraints or priors on the starlet coefficients, this morphology model is functionally
221+
indistinguishable from a 2D pixel array, while using more operations.
222+
222223
See Also
223224
--------
224225
scarlet2.wavelets.Starlet
225226
"""
226227

227228
coeffs: jnp.ndarray
228229
"""Starlet coefficients"""
229-
l1_thresh: float = eqx.field(default=0)
230-
"""L1 threshold for coefficient to create sparse representation"""
231-
positive: bool = eqx.field(default=True)
232-
"""Whether the coefficients are restricted to non-negative values"""
233230

234231
def __call__(self, **kwargs):
235232
"""Evaluate the model"""
236-
f = prox_soft_plus if self.positive else prox_soft
237-
return starlet_reconstruction(f(self.coeffs, self.l1_thresh))
233+
return starlet_reconstruction(self.coeffs)
238234

239235
@property
240236
def shape(self):
241237
"""Shape (2D) of the morphology model"""
242238
return self.coeffs.shape[-2:] # wavelet coeffs: scales x n1 x n2
243239

244240
@staticmethod
245-
def from_image(image, **kwargs):
241+
def from_image(image, min_value=None, max_value=None):
246242
"""Create starlet morphology from `image`
247243
248244
Parameters
249245
----------
250246
image: array
251247
2D image array to determine coefficients from.
252-
kwargs: dict
253-
Additional arguments for `__init__`
248+
min_value: (float, None):
249+
Minimum value threshold for coefficients
250+
max_value: (float, None):
251+
Minimum value threshold for coefficients
254252
255253
Returns
256254
-------
257255
StarletMorphology
258256
"""
259257
# Starlet transform of image (n1,n2) into coefficient with 3 dimensions: (scales+1,n1,n2)
260258
coeffs = starlet_transform(image)
261-
return StarletMorphology(coeffs, **kwargs)
259+
if min_value is not None:
260+
coeffs = coeffs.at[coeffs < min_value].set(min_value)
261+
if max_value is not None:
262+
coeffs = coeffs.at[coeffs > max_value].set(max_value)
263+
return StarletMorphology(coeffs)

src/scarlet2/wavelets.py

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -90,28 +90,37 @@ def bspline_convolve(image, scale):
9090
spacing between adjacent pixels with the spline.
9191
9292
"""
93-
# Filter for the scarlet transform. Here bspline
9493
h1d = jnp.array([1.0 / 16, 1.0 / 4, 3.0 / 8, 1.0 / 4, 1.0 / 16])
95-
j = scale
96-
97-
slice0 = slice(None, -(2 ** (j + 1)))
98-
slice1 = slice(None, -(2**j))
99-
slice3 = slice(2**j, None)
100-
slice4 = slice(2 ** (j + 1), None)
101-
102-
# row
103-
col = image * h1d[2]
104-
col = col.at[slice4].add(image[slice0] * h1d[0])
105-
col = col.at[slice3].add(image[slice1] * h1d[1])
106-
col = col.at[slice1].add(image[slice3] * h1d[3])
107-
col = col.at[slice0].add(image[slice4] * h1d[4])
108-
109-
# column
110-
result = col * h1d[2]
111-
result = result.at[:, slice4].add(col[:, slice0] * h1d[0])
112-
result = result.at[:, slice3].add(col[:, slice1] * h1d[1])
113-
result = result.at[:, slice1].add(col[:, slice3] * h1d[3])
114-
result = result.at[:, slice0].add(col[:, slice4] * h1d[4])
94+
step = 2**scale
95+
ny, nx = image.shape
96+
97+
row_idx = jnp.arange(ny)
98+
col_idx = jnp.arange(nx)
99+
100+
def reflect(idx, size):
101+
# reflect indices at boundaries into [0, size-1]
102+
idx = jnp.abs(idx)
103+
# Map into [0, 2*size - 2] period, then fold back
104+
idx = idx % (2 * size - 2)
105+
return jnp.where(idx >= size, 2 * size - 2 - idx, idx)
106+
107+
# a simpler version: clamp pixels beyond the edge to the edge pixels
108+
# return jnp.clip(idx, 0, size - 1) # clamp — or use true reflection below
109+
110+
# Row convolution
111+
col = jnp.zeros_like(image)
112+
for k, offset in enumerate([-2 * step, -step, 0, step, 2 * step]):
113+
reflected = reflect(row_idx + offset, ny)
114+
shifted = jnp.take(image, reflected, axis=0)
115+
col += shifted * h1d[k]
116+
117+
# Column convolution
118+
result = jnp.zeros_like(col)
119+
for k, offset in enumerate([-2 * step, -step, 0, step, 2 * step]):
120+
reflected = reflect(col_idx + offset, nx)
121+
shifted = jnp.take(col, reflected, axis=1)
122+
result += shifted * h1d[k]
123+
115124
return result
116125

117126

@@ -164,7 +173,7 @@ def starlet_transform(image, scales=None, generation=2, convolve2d=None):
164173
return starlet
165174

166175

167-
def starlet_reconstruction(starlets, generation=2, convolve2d=None):
176+
def starlet_reconstruction(starlets, generation=2, convolve2d=None, scales=None):
168177
"""Reconstruct an image from a dictionary of starlets
169178
170179
Parameters
@@ -177,6 +186,8 @@ def starlet_reconstruction(starlets, generation=2, convolve2d=None):
177186
convolve2d: function
178187
The filter function to use to convolve the image
179188
with starlets in 2D.
189+
scales: list of int
190+
The scales to include in the reconstruction (0 being the smallest)
180191
181192
Returns
182193
-------
@@ -187,11 +198,17 @@ def starlet_reconstruction(starlets, generation=2, convolve2d=None):
187198
return jnp.sum(starlets, axis=0)
188199
if convolve2d is None:
189200
convolve2d = bspline_convolve
190-
scales = len(starlets) - 1
191201

192-
c = starlets[-1]
193-
for i in range(1, scales + 1):
194-
j = scales - i
202+
# scales sorted in reverse order: from largest to smallest
203+
max_scale = len(starlets) - 1
204+
if scales is None:
205+
scales = tuple(max_scale - i for i in range(1, max_scale + 1))
206+
else:
207+
scales = sorted(tuple(scale for scale in scales if scale <= max_scale), reverse=True)
208+
209+
# reconstruct: initialize from largest, go to smallest
210+
c = starlets[scales[0]]
211+
for j in scales[1:]:
195212
cj = convolve2d(c, j)
196213
c = cj + starlets[j]
197214
return c

0 commit comments

Comments
 (0)