@@ -201,7 +201,6 @@ def remove_all_stripe(
201201 Corrected 3D tomographic data as a CuPy or NumPy array.
202202
203203 """
204- matindex = _create_matindex (data .shape [2 ], data .shape [0 ])
205204 for m in range (data .shape [1 ]):
206205 sino = data [:, m , :]
207206 sino = _rs_dead (sino , snr , la_size , matindex )
@@ -252,7 +251,7 @@ def _detect_stripe(listdata, snr):
252251 return listmask
253252
254253
255- def _rs_large (sinogram , snr , size , matindex , drop_ratio = 0.1 , norm = True ):
254+ def _rs_large (sinogram , snr , size , drop_ratio = 0.1 , norm = True ):
256255 """
257256 Remove large stripes.
258257 """
@@ -264,35 +263,35 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
264263 list1 = cp .mean (sinosort [ndrop : nrow - ndrop ], axis = 0 )
265264 list2 = cp .mean (sinosmooth [ndrop : nrow - ndrop ], axis = 0 )
266265 listfact = list1 / list2
267-
268266 # Locate stripes
269267 listmask = _detect_stripe (listfact , snr )
270268 listmask = binary_dilation (listmask , iterations = 1 ).astype (listmask .dtype )
271- matfact = cp . tile ( listfact , ( nrow , 1 ))
269+
272270 # Normalize
273- if norm is True :
274- sinogram = sinogram / matfact
275- sinogram1 = cp . transpose ( sinogram )
276- matcombine = cp . asarray ( cp . dstack (( matindex , sinogram1 )))
277-
278- ids = cp .argsort (matcombine [:, :, 1 ] , axis = 1 )
279- matsort = matcombine . copy ()
280- matsort [:, :, 0 ] = cp . take_along_axis ( matsort [:, :, 0 ], ids , axis = 1 )
281- matsort [:, :, 1 ] = cp .take_along_axis (matsort [:, :, 1 ], ids , axis = 1 )
282-
283- matsort [:, :, 1 ] = cp . transpose ( sinosmooth )
284- ids = cp . argsort ( matsort [:, :, 0 ], axis = 1 )
285- matsortback = matsort . copy ()
286- matsortback [:, :, 0 ] = cp .take_along_axis ( matsortback [:, :, 0 ], ids , axis = 1 )
287- matsortback [:, :, 1 ] = cp .take_along_axis (matsortback [:, :, 1 ], ids , axis = 1 )
288-
289- sino_corrected = cp . transpose ( matsortback [:, :, 1 ])
271+ if norm :
272+ sinogram /= cp . tile ( listfact , ( nrow , 1 ))
273+
274+ # Transpose for sorting along columns
275+ sino_transposed = sinogram . T
276+ ids_sort = cp .argsort (sino_transposed , axis = 1 )
277+
278+ # Apply sorting without explicit matindex
279+ sino_sorted = cp .take_along_axis (sino_transposed , ids_sort , axis = 1 )
280+
281+ # Smoothen sorted sinogram
282+ sino_sorted [:, :] = cp . transpose ( sinosmooth )
283+
284+ ids_restore = cp .argsort ( ids_sort , axis = 1 )
285+ sino_corrected = cp .take_along_axis (sino_sorted , ids_restore , axis = 1 ). T
286+
287+ # Apply corrections only to affected columns
290288 listxmiss = cp .where (listmask > 0.0 )[0 ]
291289 sinogram [:, listxmiss ] = sino_corrected [:, listxmiss ]
290+
292291 return sinogram
293292
294293
295- def _rs_dead (sinogram , snr , size , matindex , norm = True ):
294+ def _rs_dead (sinogram , snr , size , norm = True ):
296295 """remove unresponsive and fluctuating stripes"""
297296 sinogram = cp .copy (sinogram ) # Make it mutable
298297 (nrow , _ ) = sinogram .shape
@@ -323,7 +322,7 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):
323322
324323 # Remove residual stripes
325324 if norm is True :
326- sinogram = _rs_large (sinogram , snr , size , matindex )
325+ sinogram = _rs_large (sinogram , snr , size )
327326 return sinogram
328327
329328
0 commit comments