Skip to content

Commit c165942

Browse files
committed
Remove matindex
1 parent c7a985f commit c165942

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,6 @@ def remove_all_stripe(
201201
Corrected 3D tomographic data as a CuPy or NumPy array.
202202
203203
"""
204-
matindex = _create_matindex(data.shape[2], data.shape[0])
205204
for m in range(data.shape[1]):
206205
sino = data[:, m, :]
207206
sino = _rs_dead(sino, snr, la_size, matindex)
@@ -252,7 +251,7 @@ def _detect_stripe(listdata, snr):
252251
return listmask
253252

254253

255-
def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
254+
def _rs_large(sinogram, snr, size, drop_ratio=0.1, norm=True):
256255
"""
257256
Remove large stripes.
258257
"""
@@ -264,35 +263,35 @@ def _rs_large(sinogram, snr, size, matindex, drop_ratio=0.1, norm=True):
264263
list1 = cp.mean(sinosort[ndrop : nrow - ndrop], axis=0)
265264
list2 = cp.mean(sinosmooth[ndrop : nrow - ndrop], axis=0)
266265
listfact = list1 / list2
267-
268266
# Locate stripes
269267
listmask = _detect_stripe(listfact, snr)
270268
listmask = binary_dilation(listmask, iterations=1).astype(listmask.dtype)
271-
matfact = cp.tile(listfact, (nrow, 1))
269+
272270
# Normalize
273-
if norm is True:
274-
sinogram = sinogram / matfact
275-
sinogram1 = cp.transpose(sinogram)
276-
matcombine = cp.asarray(cp.dstack((matindex, sinogram1)))
277-
278-
ids = cp.argsort(matcombine[:, :, 1], axis=1)
279-
matsort = matcombine.copy()
280-
matsort[:, :, 0] = cp.take_along_axis(matsort[:, :, 0], ids, axis=1)
281-
matsort[:, :, 1] = cp.take_along_axis(matsort[:, :, 1], ids, axis=1)
282-
283-
matsort[:, :, 1] = cp.transpose(sinosmooth)
284-
ids = cp.argsort(matsort[:, :, 0], axis=1)
285-
matsortback = matsort.copy()
286-
matsortback[:, :, 0] = cp.take_along_axis(matsortback[:, :, 0], ids, axis=1)
287-
matsortback[:, :, 1] = cp.take_along_axis(matsortback[:, :, 1], ids, axis=1)
288-
289-
sino_corrected = cp.transpose(matsortback[:, :, 1])
271+
if norm:
272+
sinogram /= cp.tile(listfact, (nrow, 1))
273+
274+
# Transpose for sorting along columns
275+
sino_transposed = sinogram.T
276+
ids_sort = cp.argsort(sino_transposed, axis=1)
277+
278+
# Apply sorting without explicit matindex
279+
sino_sorted = cp.take_along_axis(sino_transposed, ids_sort, axis=1)
280+
281+
# Smoothen sorted sinogram
282+
sino_sorted[:, :] = cp.transpose(sinosmooth)
283+
284+
ids_restore = cp.argsort(ids_sort, axis=1)
285+
sino_corrected = cp.take_along_axis(sino_sorted, ids_restore, axis=1).T
286+
287+
# Apply corrections only to affected columns
290288
listxmiss = cp.where(listmask > 0.0)[0]
291289
sinogram[:, listxmiss] = sino_corrected[:, listxmiss]
290+
292291
return sinogram
293292

294293

295-
def _rs_dead(sinogram, snr, size, matindex, norm=True):
294+
def _rs_dead(sinogram, snr, size, norm=True):
296295
"""remove unresponsive and fluctuating stripes"""
297296
sinogram = cp.copy(sinogram) # Make it mutable
298297
(nrow, _) = sinogram.shape
@@ -323,7 +322,7 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):
323322

324323
# Remove residual stripes
325324
if norm is True:
326-
sinogram = _rs_large(sinogram, snr, size, matindex)
325+
sinogram = _rs_large(sinogram, snr, size)
327326
return sinogram
328327

329328

0 commit comments

Comments
 (0)