@@ -90,28 +90,37 @@ def bspline_convolve(image, scale):
9090 spacing between adjacent pixels with the spline.
9191
9292 """
93- # Filter for the scarlet transform. Here bspline
9493 h1d = jnp .array ([1.0 / 16 , 1.0 / 4 , 3.0 / 8 , 1.0 / 4 , 1.0 / 16 ])
95- j = scale
96-
97- slice0 = slice (None , - (2 ** (j + 1 )))
98- slice1 = slice (None , - (2 ** j ))
99- slice3 = slice (2 ** j , None )
100- slice4 = slice (2 ** (j + 1 ), None )
101-
102- # row
103- col = image * h1d [2 ]
104- col = col .at [slice4 ].add (image [slice0 ] * h1d [0 ])
105- col = col .at [slice3 ].add (image [slice1 ] * h1d [1 ])
106- col = col .at [slice1 ].add (image [slice3 ] * h1d [3 ])
107- col = col .at [slice0 ].add (image [slice4 ] * h1d [4 ])
108-
109- # column
110- result = col * h1d [2 ]
111- result = result .at [:, slice4 ].add (col [:, slice0 ] * h1d [0 ])
112- result = result .at [:, slice3 ].add (col [:, slice1 ] * h1d [1 ])
113- result = result .at [:, slice1 ].add (col [:, slice3 ] * h1d [3 ])
114- result = result .at [:, slice0 ].add (col [:, slice4 ] * h1d [4 ])
94+ step = 2 ** scale
95+ ny , nx = image .shape
96+
97+ row_idx = jnp .arange (ny )
98+ col_idx = jnp .arange (nx )
99+
100+ def reflect (idx , size ):
101+ # reflect indices at boundaries into [0, size-1]
102+ idx = jnp .abs (idx )
103+ # Map into [0, 2*size - 2] period, then fold back
104+ idx = idx % (2 * size - 2 )
105+ return jnp .where (idx >= size , 2 * size - 2 - idx , idx )
106+
107+ # a simpler version: clamp pixels beyond the edge to the edge pixels
108+ # return jnp.clip(idx, 0, size - 1) # clamp — or use true reflection below
109+
110+ # Row convolution
111+ col = jnp .zeros_like (image )
112+ for k , offset in enumerate ([- 2 * step , - step , 0 , step , 2 * step ]):
113+ reflected = reflect (row_idx + offset , ny )
114+ shifted = jnp .take (image , reflected , axis = 0 )
115+ col += shifted * h1d [k ]
116+
117+ # Column convolution
118+ result = jnp .zeros_like (col )
119+ for k , offset in enumerate ([- 2 * step , - step , 0 , step , 2 * step ]):
120+ reflected = reflect (col_idx + offset , nx )
121+ shifted = jnp .take (col , reflected , axis = 1 )
122+ result += shifted * h1d [k ]
123+
115124 return result
116125
117126
@@ -164,7 +173,7 @@ def starlet_transform(image, scales=None, generation=2, convolve2d=None):
164173 return starlet
165174
166175
167- def starlet_reconstruction (starlets , generation = 2 , convolve2d = None ):
176+ def starlet_reconstruction (starlets , generation = 2 , convolve2d = None , scales = None ):
168177 """Reconstruct an image from a dictionary of starlets
169178
170179 Parameters
@@ -177,6 +186,8 @@ def starlet_reconstruction(starlets, generation=2, convolve2d=None):
177186 convolve2d: function
178187 The filter function to use to convolve the image
179188 with starlets in 2D.
189+ scales: list of int
190+ The scales to include in the reconstruction (0 being the smallest)
180191
181192 Returns
182193 -------
@@ -187,11 +198,17 @@ def starlet_reconstruction(starlets, generation=2, convolve2d=None):
187198 return jnp .sum (starlets , axis = 0 )
188199 if convolve2d is None :
189200 convolve2d = bspline_convolve
190- scales = len (starlets ) - 1
191201
192- c = starlets [- 1 ]
193- for i in range (1 , scales + 1 ):
194- j = scales - i
202+ # scales sorted in reverse order: from largest to smallest
203+ max_scale = len (starlets ) - 1
204+ if scales is None :
205+ scales = tuple (max_scale - i for i in range (1 , max_scale + 1 ))
206+ else :
207+ scales = sorted (tuple (scale for scale in scales if scale <= max_scale ), reverse = True )
208+
209+ # reconstruct: initialize from largest, go to smallest
210+ c = starlets [scales [0 ]]
211+ for j in scales [1 :]:
195212 cj = convolve2d (c , j )
196213 c = cj + starlets [j ]
197214 return c
0 commit comments