Skip to content

Commit 1ca5fb3

Browse files
committed
More ergonomic peak GPU memory estimation in remove_stripe_fw
1 parent f5a8b0a commit 1ca5fb3

File tree

2 files changed

+25
-27
lines changed

2 files changed

+25
-27
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _reflect(x: np.ndarray, minx: float, maxx: float) -> np.ndarray:
207207
return np.array(out, dtype=x.dtype)
208208

209209

210-
class DeviceMemStack:
210+
class _DeviceMemStack:
211211
def __init__(self) -> None:
212212
self.allocations = []
213213
self.current = 0
@@ -231,7 +231,7 @@ def _round_up(self, size):
231231
return size * ALLOCATION_UNIT_SIZE
232232

233233

234-
def _mypad(x: cp.ndarray, pad: Tuple[int, int, int, int], mem_stack: Optional[DeviceMemStack]) -> cp.ndarray:
234+
def _mypad(x: cp.ndarray, pad: Tuple[int, int, int, int], mem_stack: Optional[_DeviceMemStack]) -> cp.ndarray:
235235
""" Function to do NumPy-like padding on arrays. Only works for 2-D
236236
padding.
237237
@@ -261,7 +261,7 @@ def _mypad(x: cp.ndarray, pad: Tuple[int, int, int, int], mem_stack: Optional[De
261261
return x[:, :, :, xe]
262262

263263

264-
def _conv2d(x: cp.ndarray, w: np.ndarray, stride: Tuple[int, int], groups: int, mem_stack: Optional[DeviceMemStack]) -> cp.ndarray:
264+
def _conv2d(x: cp.ndarray, w: np.ndarray, stride: Tuple[int, int], groups: int, mem_stack: Optional[_DeviceMemStack]) -> cp.ndarray:
265265
""" Convolution (equivalent to PyTorch's conv2d)
266266
"""
267267
b, ci, hi, wi = x.shape if not mem_stack else x
@@ -296,7 +296,7 @@ def _conv2d(x: cp.ndarray, w: np.ndarray, stride: Tuple[int, int], groups: int,
296296
return out
297297

298298

299-
def _conv_transpose2d(x: cp.ndarray, w: np.ndarray, stride: Tuple[int, int], pad: Tuple[int, int], groups: int, mem_stack: Optional[DeviceMemStack]) -> cp.ndarray:
299+
def _conv_transpose2d(x: cp.ndarray, w: np.ndarray, stride: Tuple[int, int], pad: Tuple[int, int], groups: int, mem_stack: Optional[_DeviceMemStack]) -> cp.ndarray:
300300
""" Transposed convolution (equivalent to PyTorch's conv_transpose2d)
301301
"""
302302
b, co, ho, wo = x.shape if not mem_stack else x
@@ -331,7 +331,7 @@ def _conv_transpose2d(x: cp.ndarray, w: np.ndarray, stride: Tuple[int, int], pad
331331
return out, None
332332

333333

334-
def afb1d(x: cp.ndarray, h0: np.ndarray, h1: np.ndarray, dim: int, mem_stack: Optional[DeviceMemStack]) -> cp.ndarray:
334+
def _afb1d(x: cp.ndarray, h0: np.ndarray, h1: np.ndarray, dim: int, mem_stack: Optional[_DeviceMemStack]) -> cp.ndarray:
335335
""" 1D analysis filter bank (along one dimension only) of an image
336336
337337
Parameters
@@ -372,7 +372,7 @@ def afb1d(x: cp.ndarray, h0: np.ndarray, h1: np.ndarray, dim: int, mem_stack: Op
372372
return lohi
373373

374374

375-
def sfb1d(lo: cp.ndarray, hi: cp.ndarray, g0: np.ndarray, g1: np.ndarray, dim: int, mem_stack: Optional[DeviceMemStack]) -> cp.ndarray:
375+
def _sfb1d(lo: cp.ndarray, hi: cp.ndarray, g0: np.ndarray, g1: np.ndarray, dim: int, mem_stack: Optional[_DeviceMemStack]) -> cp.ndarray:
376376
""" 1D synthesis filter bank of an image Array
377377
"""
378378

@@ -396,7 +396,7 @@ def sfb1d(lo: cp.ndarray, hi: cp.ndarray, g0: np.ndarray, g1: np.ndarray, dim: i
396396
return y_lo + y_hi
397397

398398

399-
class DWTForward():
399+
class _DWTForward():
400400
""" Performs a 2d DWT Forward decomposition of an image
401401
402402
Args:
@@ -419,7 +419,7 @@ def __init__(self, wave: str):
419419
self.h1_row = np.array(h1_row).astype('float32')[
420420
::-1].reshape((1, 1, 1, -1))
421421

422-
def apply(self, x: cp.ndarray, mem_stack: Optional[DeviceMemStack] = None) -> Tuple[cp.ndarray, cp.ndarray]:
422+
def apply(self, x: cp.ndarray, mem_stack: Optional[_DeviceMemStack] = None) -> Tuple[cp.ndarray, cp.ndarray]:
423423
""" Forward pass of the DWT.
424424
425425
Args:
@@ -439,8 +439,8 @@ def apply(self, x: cp.ndarray, mem_stack: Optional[DeviceMemStack] = None) -> Tu
439439
"""
440440
# Do a multilevel transform
441441
# Do 1 level of the transform
442-
lohi = afb1d(x, self.h0_row, self.h1_row, dim=3, mem_stack=mem_stack)
443-
y = afb1d(lohi, self.h0_col, self.h1_col, dim=2, mem_stack=mem_stack)
442+
lohi = _afb1d(x, self.h0_row, self.h1_row, dim=3, mem_stack=mem_stack)
443+
y = _afb1d(lohi, self.h0_col, self.h1_col, dim=2, mem_stack=mem_stack)
444444
if mem_stack:
445445
y_shape = [y[0], np.prod(y) // y[0] // 4 // y[-2] // y[-1], 4, y[-2], y[-1]]
446446
x_shape = [y_shape[0], y_shape[1], y_shape[3], y_shape[4]]
@@ -459,7 +459,7 @@ def apply(self, x: cp.ndarray, mem_stack: Optional[DeviceMemStack] = None) -> Tu
459459
return (x, yh)
460460

461461

462-
class DWTInverse():
462+
class _DWTInverse():
463463
""" Performs a 2d DWT Inverse reconstruction of an image
464464
465465
Args:
@@ -477,7 +477,7 @@ def __init__(self, wave: str):
477477
self.g0_row = np.array(g0_row).astype('float32').reshape((1, 1, 1, -1))
478478
self.g1_row = np.array(g1_row).astype('float32').reshape((1, 1, 1, -1))
479479

480-
def apply(self, coeffs: Tuple[cp.ndarray, cp.ndarray], mem_stack: Optional[DeviceMemStack] = None) -> cp.ndarray:
480+
def apply(self, coeffs: Tuple[cp.ndarray, cp.ndarray], mem_stack: Optional[_DeviceMemStack] = None) -> cp.ndarray:
481481
"""
482482
Args:
483483
coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
@@ -498,9 +498,9 @@ def apply(self, coeffs: Tuple[cp.ndarray, cp.ndarray], mem_stack: Optional[Devic
498498
lh = yh[:, :, 0, :, :] if not mem_stack else [yh[0], yh[1], yh[3], yh[4]]
499499
hl = yh[:, :, 1, :, :] if not mem_stack else [yh[0], yh[1], yh[3], yh[4]]
500500
hh = yh[:, :, 2, :, :] if not mem_stack else [yh[0], yh[1], yh[3], yh[4]]
501-
lo = sfb1d(yl, lh, self.g0_col, self.g1_col, dim=2, mem_stack=mem_stack)
502-
hi = sfb1d(hl, hh, self.g0_col, self.g1_col, dim=2, mem_stack=mem_stack)
503-
yl = sfb1d(lo, hi, self.g0_row, self.g1_row, dim=3, mem_stack=mem_stack)
501+
lo = _sfb1d(yl, lh, self.g0_col, self.g1_col, dim=2, mem_stack=mem_stack)
502+
hi = _sfb1d(hl, hh, self.g0_col, self.g1_col, dim=2, mem_stack=mem_stack)
503+
yl = _sfb1d(lo, hi, self.g0_row, self.g1_row, dim=3, mem_stack=mem_stack)
504504
if mem_stack:
505505
mem_stack.free(np.prod(lo) * np.float32().itemsize)
506506
mem_stack.free(np.prod(hi) * np.float32().itemsize)
@@ -509,22 +509,23 @@ def apply(self, coeffs: Tuple[cp.ndarray, cp.ndarray], mem_stack: Optional[Devic
509509
return yl
510510

511511

512-
def remove_stripe_fw(data: cp.ndarray, sigma: float=1, wname: str='sym16', level: int=7, mem_stack: Optional[DeviceMemStack] = None) -> cp.ndarray:
512+
def remove_stripe_fw(data, sigma: float=1, wname: str='sym16', level: int=7, calc_peak_gpu_mem: bool = False) -> cp.ndarray:
513513
"""Remove stripes with wavelet filtering"""
514514

515-
[nproj, nz, ni] = data.shape if not mem_stack else data
515+
[nproj, nz, ni] = data.shape if not calc_peak_gpu_mem else data
516516

517517
nproj_pad = nproj + nproj // 8
518518

519519
# Accepts all wave types available to PyWavelets
520-
xfm = DWTForward(wave=wname)
521-
ifm = DWTInverse(wave=wname)
520+
xfm = _DWTForward(wave=wname)
521+
ifm = _DWTInverse(wave=wname)
522522

523523
# Wavelet decomposition.
524524
cc = []
525525
sli_shape = [nz, 1, nproj_pad, ni]
526526

527-
if mem_stack:
527+
if calc_peak_gpu_mem:
528+
mem_stack = _DeviceMemStack()
528529
# A data copy is assumed when invoking the function
529530
mem_stack.malloc(np.prod(data) * np.float32().itemsize)
530531
mem_stack.malloc(np.prod(sli_shape) * np.float32().itemsize)
@@ -561,7 +562,7 @@ def remove_stripe_fw(data: cp.ndarray, sigma: float=1, wname: str='sym16', level
561562
mem_stack.free(np.prod(c) * np.float32().itemsize)
562563
mem_stack.malloc(np.prod(data) * np.float32().itemsize)
563564
mem_stack.free(np.prod(sli_shape) * np.float32().itemsize)
564-
return
565+
return mem_stack.highwater
565566

566567
sli = cp.zeros(sli_shape, dtype='float32')
567568
sli[:, 0, (nproj_pad - nproj)//2:(nproj_pad + nproj) // 2] = data.swapaxes(0, 1)

tests/test_prep/test_stripe.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from httomolibgpu.prep.normalize import normalize
88
from httomolibgpu.prep.stripe import (
9-
DeviceMemStack,
109
remove_stripe_based_sorting,
1110
remove_stripe_ti,
1211
remove_stripe_fw,
@@ -88,7 +87,7 @@ def ensure_clean_memory():
8887
@pytest.mark.parametrize("slices", [55, 80])
8988
@pytest.mark.parametrize("level", [1, 3, 7, 11])
9089
@pytest.mark.parametrize("dim_x", [128, 140])
91-
def test_remove_stripe_fw_mem_stack(slices, level, dim_x, ensure_clean_memory):
90+
def test_remove_stripe_fw_calc_mem(slices, level, dim_x, ensure_clean_memory):
9291
dim_y = 159
9392
data = cp.random.random_sample((slices, dim_x, dim_y), dtype=np.float32)
9493
hook = MaxMemoryHook()
@@ -97,13 +96,11 @@ def test_remove_stripe_fw_mem_stack(slices, level, dim_x, ensure_clean_memory):
9796
actual_mem_peak = hook.max_mem
9897

9998
hook = MaxMemoryHook()
100-
mem_stack = DeviceMemStack()
10199
with hook:
102-
remove_stripe_fw(data.shape, level=level, mem_stack=mem_stack)
100+
estimated_mem_peak = remove_stripe_fw(data.shape, level=level, calc_peak_gpu_mem=True)
103101
assert hook.max_mem == 0
104-
estimated_mem_peak = mem_stack.highwater
105102

106-
# assert actual_mem_peak == estimated_mem_peak
103+
assert actual_mem_peak == estimated_mem_peak
107104
assert actual_mem_peak * 0.99 <= estimated_mem_peak
108105
assert estimated_mem_peak <= actual_mem_peak * 1.01
109106

0 commit comments

Comments
 (0)