@@ -205,40 +205,12 @@ def remove_all_stripe(
205205 for m in range (data .shape [1 ]):
206206 sino = data [:, m , :]
207207 sino = _rs_dead (sino , snr , la_size , matindex )
208- sino = _rs_sort2 (sino , sm_size , matindex , dim )
208+ sino = _rs_sort (sino , sm_size , dim )
209209 sino = cp .nan_to_num (sino )
210210 data [:, m , :] = sino
211211 return data
212212
213213
214- def _rs_sort2 (sinogram , size , matindex , dim ):
215- """
216- Remove stripes using the sorting technique.
217- """
218- sinogram = cp .transpose (sinogram )
219- matcomb = cp .asarray (cp .dstack ((matindex , sinogram )))
220-
221- # matsort = cp.asarray([row[row[:, 1].argsort()] for row in matcomb])
222- ids = cp .argsort (matcomb [:, :, 1 ], axis = 1 )
223- matsort = matcomb .copy ()
224- matsort [:, :, 0 ] = cp .take_along_axis (matsort [:, :, 0 ], ids , axis = 1 )
225- matsort [:, :, 1 ] = cp .take_along_axis (matsort [:, :, 1 ], ids , axis = 1 )
226- if dim == 1 :
227- matsort [:, :, 1 ] = median_filter (matsort [:, :, 1 ], (size , 1 ))
228- else :
229- matsort [:, :, 1 ] = median_filter (matsort [:, :, 1 ], (size , size ))
230-
231- # matsortback = cp.asarray([row[row[:, 0].argsort()] for row in matsort])
232-
233- ids = cp .argsort (matsort [:, :, 0 ], axis = 1 )
234- matsortback = matsort .copy ()
235- matsortback [:, :, 0 ] = cp .take_along_axis (matsortback [:, :, 0 ], ids , axis = 1 )
236- matsortback [:, :, 1 ] = cp .take_along_axis (matsortback [:, :, 1 ], ids , axis = 1 )
237-
238- sino_corrected = matsortback [:, :, 1 ]
239- return cp .transpose (sino_corrected )
240-
241-
242214def _mpolyfit (x , y ):
243215 n = len (x )
244216 x_mean = cp .mean (x )
@@ -261,8 +233,6 @@ def _detect_stripe(listdata, snr):
261233 listsorted = cp .sort (listdata )[::- 1 ]
262234 xlist = cp .arange (0 , numdata , 1.0 )
263235 ndrop = cp .int16 (0.25 * numdata )
264- # (_slope, _intercept) = cp.polyfit(xlist[ndrop:-ndrop - 1],
265- # listsorted[ndrop:-ndrop - 1], 1)
266236 (_slope , _intercept ) = _mpolyfit (
267237 xlist [ndrop : - ndrop - 1 ], listsorted [ndrop : - ndrop - 1 ]
268238 )
@@ -293,11 +263,6 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
293263 sinosmooth = median_filter (sinosort , (1 , size ))
294264 list1 = cp .mean (sinosort [ndrop : nrow - ndrop ], axis = 0 )
295265 list2 = cp .mean (sinosmooth [ndrop : nrow - ndrop ], axis = 0 )
296- # listfact = cp.divide(list1,
297- # list2,
298- # out=cp.ones_like(list1),
299- # where=list2 != 0)
300-
301266 listfact = list1 / list2
302267
303268 # Locate stripes
@@ -310,14 +275,12 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
310275 sinogram1 = cp .transpose (sinogram )
311276 matcombine = cp .asarray (cp .dstack ((matindex , sinogram1 )))
312277
313- # matsort = cp.asarray([row[row[:, 1].argsort()] for row in matcombine])
314278 ids = cp .argsort (matcombine [:, :, 1 ], axis = 1 )
315279 matsort = matcombine .copy ()
316280 matsort [:, :, 0 ] = cp .take_along_axis (matsort [:, :, 0 ], ids , axis = 1 )
317281 matsort [:, :, 1 ] = cp .take_along_axis (matsort [:, :, 1 ], ids , axis = 1 )
318282
319283 matsort [:, :, 1 ] = cp .transpose (sinosmooth )
320- # matsortback = cp.asarray([row[row[:, 0].argsort()] for row in matsort])
321284 ids = cp .argsort (matsort [:, :, 0 ], axis = 1 )
322285 matsortback = matsort .copy ()
323286 matsortback [:, :, 0 ] = cp .take_along_axis (matsortback [:, :, 0 ], ids , axis = 1 )
@@ -330,12 +293,9 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
330293
331294
332295def _rs_dead (sinogram , snr , size , matindex , norm = True ):
333- """
334- Remove unresponsive and fluctuating stripes.
335- """
296+ """remove unresponsive and fluctuating stripes"""
336297 sinogram = cp .copy (sinogram ) # Make it mutable
337298 (nrow , _ ) = sinogram .shape
338- # sinosmooth = cp.apply_along_axis(uniform_filter1d, 0, sinogram, 10)
339299 sinosmooth = uniform_filter1d (sinogram , 10 , axis = 0 )
340300
341301 listdiff = cp .sum (cp .abs (sinogram - sinosmooth ), axis = 0 )
@@ -344,22 +304,22 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):
344304 listfact = listdiff / listdiffbck
345305
346306 listmask = _detect_stripe (listfact , snr )
307+ del listfact
347308 listmask = binary_dilation (listmask , iterations = 1 ).astype (listmask .dtype )
348309 listmask [0 :2 ] = 0.0
349310 listmask [- 2 :] = 0.0
350- listx = cp .where (listmask < 1.0 )[0 ]
351- listy = cp .arange (nrow )
352- matz = sinogram [:, listx ]
353311
312+ listx = cp .where (listmask < 1.0 )[0 ]
354313 listxmiss = cp .where (listmask > 0.0 )[0 ]
314+ del listmask
355315
356- # finter = interpolate.interp2d(listx.get(), listy.get(), matz.get(), kind='linear')
357316 if len (listxmiss ) > 0 :
358- # sinogram_c[:, listxmiss.get()] = finter(listxmiss.get(), listy.get())
359317 ids = cp .searchsorted (listx , listxmiss )
360- sinogram [:, listxmiss ] = matz [:, ids - 1 ] + (listxmiss - listx [ids - 1 ]) * (
361- matz [:, ids ] - matz [:, ids - 1 ]
362- ) / (listx [ids ] - listx [ids - 1 ])
318+ weights = (listxmiss - listx [ids - 1 ]) / (listx [ids ] - listx [ids - 1 ])
319+ # direct interpolation without making an extra copy
320+ sinogram [:, listxmiss ] = sinogram [:, listx [ids - 1 ]] + weights * (
321+ sinogram [:, listx [ids ]] - sinogram [:, listx [ids - 1 ]]
322+ )
363323
364324 # Remove residual stripes
365325 if norm is True :
@@ -455,7 +415,7 @@ def raven_filter(
455415 # Removing padding
456416 data = data [pad_y : height - pad_y , :, pad_x : width - pad_x ].real
457417
458- return data
418+ return cp . require ( data , requirements = "C" )
459419
460420
461421def _create_matindex (nrow , ncol ):
0 commit comments