@@ -207,7 +207,7 @@ def _reflect(x: np.ndarray, minx: float, maxx: float) -> np.ndarray:
207207 return np .array (out , dtype = x .dtype )
208208
209209
210- class DeviceMemStack :
210+ class _DeviceMemStack :
211211 def __init__ (self ) -> None :
212212 self .allocations = []
213213 self .current = 0
@@ -231,7 +231,7 @@ def _round_up(self, size):
231231 return size * ALLOCATION_UNIT_SIZE
232232
233233
234- def _mypad (x : cp .ndarray , pad : Tuple [int , int , int , int ], mem_stack : Optional [DeviceMemStack ]) -> cp .ndarray :
234+ def _mypad (x : cp .ndarray , pad : Tuple [int , int , int , int ], mem_stack : Optional [_DeviceMemStack ]) -> cp .ndarray :
235235 """ Function to do numpy like padding on Arrays. Only works for 2-D
236236 padding.
237237
@@ -261,7 +261,7 @@ def _mypad(x: cp.ndarray, pad: Tuple[int, int, int, int], mem_stack: Optional[De
261261 return x [:, :, :, xe ]
262262
263263
264- def _conv2d (x : cp .ndarray , w : np .ndarray , stride : Tuple [int , int ], groups : int , mem_stack : Optional [DeviceMemStack ]) -> cp .ndarray :
264+ def _conv2d (x : cp .ndarray , w : np .ndarray , stride : Tuple [int , int ], groups : int , mem_stack : Optional [_DeviceMemStack ]) -> cp .ndarray :
265265 """ Convolution (equivalent pytorch.conv2d)
266266 """
267267 b , ci , hi , wi = x .shape if not mem_stack else x
@@ -296,7 +296,7 @@ def _conv2d(x: cp.ndarray, w: np.ndarray, stride: Tuple[int, int], groups: int,
296296 return out
297297
298298
299- def _conv_transpose2d (x : cp .ndarray , w : np .ndarray , stride : Tuple [int , int ], pad : Tuple [int , int ], groups : int , mem_stack : Optional [DeviceMemStack ]) -> cp .ndarray :
299+ def _conv_transpose2d (x : cp .ndarray , w : np .ndarray , stride : Tuple [int , int ], pad : Tuple [int , int ], groups : int , mem_stack : Optional [_DeviceMemStack ]) -> cp .ndarray :
300300 """ Transposed convolution (equivalent pytorch.conv_transpose2d)
301301 """
302302 b , co , ho , wo = x .shape if not mem_stack else x
@@ -331,7 +331,7 @@ def _conv_transpose2d(x: cp.ndarray, w: np.ndarray, stride: Tuple[int, int], pad
331331 return out , None
332332
333333
334- def afb1d (x : cp .ndarray , h0 : np .ndarray , h1 : np .ndarray , dim : int , mem_stack : Optional [DeviceMemStack ]) -> cp .ndarray :
334+ def _afb1d (x : cp .ndarray , h0 : np .ndarray , h1 : np .ndarray , dim : int , mem_stack : Optional [_DeviceMemStack ]) -> cp .ndarray :
335335 """ 1D analysis filter bank (along one dimension only) of an image
336336
337337 Parameters
@@ -372,7 +372,7 @@ def afb1d(x: cp.ndarray, h0: np.ndarray, h1: np.ndarray, dim: int, mem_stack: Op
372372 return lohi
373373
374374
375- def sfb1d (lo : cp .ndarray , hi : cp .ndarray , g0 : np .ndarray , g1 : np .ndarray , dim : int , mem_stack : Optional [DeviceMemStack ]) -> cp .ndarray :
375+ def _sfb1d (lo : cp .ndarray , hi : cp .ndarray , g0 : np .ndarray , g1 : np .ndarray , dim : int , mem_stack : Optional [_DeviceMemStack ]) -> cp .ndarray :
376376 """ 1D synthesis filter bank of an image Array
377377 """
378378
@@ -396,7 +396,7 @@ def sfb1d(lo: cp.ndarray, hi: cp.ndarray, g0: np.ndarray, g1: np.ndarray, dim: i
396396 return y_lo + y_hi
397397
398398
399- class DWTForward ():
399+ class _DWTForward ():
400400 """ Performs a 2d DWT Forward decomposition of an image
401401
402402 Args:
@@ -419,7 +419,7 @@ def __init__(self, wave: str):
419419 self .h1_row = np .array (h1_row ).astype ('float32' )[
420420 ::- 1 ].reshape ((1 , 1 , 1 , - 1 ))
421421
422- def apply (self , x : cp .ndarray , mem_stack : Optional [DeviceMemStack ] = None ) -> Tuple [cp .ndarray , cp .ndarray ]:
422+ def apply (self , x : cp .ndarray , mem_stack : Optional [_DeviceMemStack ] = None ) -> Tuple [cp .ndarray , cp .ndarray ]:
423423 """ Forward pass of the DWT.
424424
425425 Args:
@@ -439,8 +439,8 @@ def apply(self, x: cp.ndarray, mem_stack: Optional[DeviceMemStack] = None) -> Tu
439439 """
440440 # Do a multilevel transform
441441 # Do 1 level of the transform
442- lohi = afb1d (x , self .h0_row , self .h1_row , dim = 3 , mem_stack = mem_stack )
443- y = afb1d (lohi , self .h0_col , self .h1_col , dim = 2 , mem_stack = mem_stack )
442+ lohi = _afb1d (x , self .h0_row , self .h1_row , dim = 3 , mem_stack = mem_stack )
443+ y = _afb1d (lohi , self .h0_col , self .h1_col , dim = 2 , mem_stack = mem_stack )
444444 if mem_stack :
445445 y_shape = [y [0 ], np .prod (y ) // y [0 ] // 4 // y [- 2 ] // y [- 1 ], 4 , y [- 2 ], y [- 1 ]]
446446 x_shape = [y_shape [0 ], y_shape [1 ], y_shape [3 ], y_shape [4 ]]
@@ -459,7 +459,7 @@ def apply(self, x: cp.ndarray, mem_stack: Optional[DeviceMemStack] = None) -> Tu
459459 return (x , yh )
460460
461461
462- class DWTInverse ():
462+ class _DWTInverse ():
463463 """ Performs a 2d DWT Inverse reconstruction of an image
464464
465465 Args:
@@ -477,7 +477,7 @@ def __init__(self, wave: str):
477477 self .g0_row = np .array (g0_row ).astype ('float32' ).reshape ((1 , 1 , 1 , - 1 ))
478478 self .g1_row = np .array (g1_row ).astype ('float32' ).reshape ((1 , 1 , 1 , - 1 ))
479479
480- def apply (self , coeffs : Tuple [cp .ndarray , cp .ndarray ], mem_stack : Optional [DeviceMemStack ] = None ) -> cp .ndarray :
480+ def apply (self , coeffs : Tuple [cp .ndarray , cp .ndarray ], mem_stack : Optional [_DeviceMemStack ] = None ) -> cp .ndarray :
481481 """
482482 Args:
483483 coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
@@ -498,9 +498,9 @@ def apply(self, coeffs: Tuple[cp.ndarray, cp.ndarray], mem_stack: Optional[Devic
498498 lh = yh [:, :, 0 , :, :] if not mem_stack else [yh [0 ], yh [1 ], yh [3 ], yh [4 ]]
499499 hl = yh [:, :, 1 , :, :] if not mem_stack else [yh [0 ], yh [1 ], yh [3 ], yh [4 ]]
500500 hh = yh [:, :, 2 , :, :] if not mem_stack else [yh [0 ], yh [1 ], yh [3 ], yh [4 ]]
501- lo = sfb1d (yl , lh , self .g0_col , self .g1_col , dim = 2 , mem_stack = mem_stack )
502- hi = sfb1d (hl , hh , self .g0_col , self .g1_col , dim = 2 , mem_stack = mem_stack )
503- yl = sfb1d (lo , hi , self .g0_row , self .g1_row , dim = 3 , mem_stack = mem_stack )
501+ lo = _sfb1d (yl , lh , self .g0_col , self .g1_col , dim = 2 , mem_stack = mem_stack )
502+ hi = _sfb1d (hl , hh , self .g0_col , self .g1_col , dim = 2 , mem_stack = mem_stack )
503+ yl = _sfb1d (lo , hi , self .g0_row , self .g1_row , dim = 3 , mem_stack = mem_stack )
504504 if mem_stack :
505505 mem_stack .free (np .prod (lo ) * np .float32 ().itemsize )
506506 mem_stack .free (np .prod (hi ) * np .float32 ().itemsize )
@@ -509,22 +509,23 @@ def apply(self, coeffs: Tuple[cp.ndarray, cp.ndarray], mem_stack: Optional[Devic
509509 return yl
510510
511511
512- def remove_stripe_fw (data : cp . ndarray , sigma : float = 1 , wname : str = 'sym16' , level : int = 7 , mem_stack : Optional [ DeviceMemStack ] = None ) -> cp .ndarray :
512+ def remove_stripe_fw (data , sigma : float = 1 , wname : str = 'sym16' , level : int = 7 , calc_peak_gpu_mem : bool = False ) -> cp .ndarray :
513513 """Remove stripes with wavelet filtering"""
514514
515- [nproj , nz , ni ] = data .shape if not mem_stack else data
515+ [nproj , nz , ni ] = data .shape if not calc_peak_gpu_mem else data
516516
517517 nproj_pad = nproj + nproj // 8
518518
519519 # Accepts all wave types available to PyWavelets
520- xfm = DWTForward (wave = wname )
521- ifm = DWTInverse (wave = wname )
520+ xfm = _DWTForward (wave = wname )
521+ ifm = _DWTInverse (wave = wname )
522522
523523 # Wavelet decomposition.
524524 cc = []
525525 sli_shape = [nz , 1 , nproj_pad , ni ]
526526
527- if mem_stack :
527+ if calc_peak_gpu_mem :
528+ mem_stack = _DeviceMemStack ()
528529 # A data copy is assumed when invoking the function
529530 mem_stack .malloc (np .prod (data ) * np .float32 ().itemsize )
530531 mem_stack .malloc (np .prod (sli_shape ) * np .float32 ().itemsize )
@@ -561,7 +562,7 @@ def remove_stripe_fw(data: cp.ndarray, sigma: float=1, wname: str='sym16', level
561562 mem_stack .free (np .prod (c ) * np .float32 ().itemsize )
562563 mem_stack .malloc (np .prod (data ) * np .float32 ().itemsize )
563564 mem_stack .free (np .prod (sli_shape ) * np .float32 ().itemsize )
564- return
565+ return mem_stack . highwater
565566
566567 sli = cp .zeros (sli_shape , dtype = 'float32' )
567568 sli [:, 0 , (nproj_pad - nproj )// 2 :(nproj_pad + nproj ) // 2 ] = data .swapaxes (0 , 1 )
0 commit comments