@@ -305,17 +305,45 @@ def _conv2d(
305305
306306 if groups == 1 :
307307 grouped_convolution_kernel_x = module .get_function (symbol_names [0 ])
308- grouped_convolution_kernel_x (grid_dim , block_dim ,
309- (dim_x , dim_y , dim_z , x , in_stride_x , in_stride_y ,
310- in_stride_z , out , out_stride_z , out_stride_group , w ))
308+ grouped_convolution_kernel_x (
309+ grid_dim ,
310+ block_dim ,
311+ (
312+ dim_x ,
313+ dim_y ,
314+ dim_z ,
315+ x ,
316+ in_stride_x ,
317+ in_stride_y ,
318+ in_stride_z ,
319+ out ,
320+ out_stride_z ,
321+ out_stride_group ,
322+ w ,
323+ ),
324+ )
311325 return out
312326
313327 grouped_convolution_kernel_y = module .get_function (symbol_names [1 ])
314328 in_stride_group = x .strides [2 ] // x .dtype .itemsize
315- grouped_convolution_kernel_y (grid_dim , block_dim ,
316- (dim_x , dim_y , dim_z , x , in_stride_x , in_stride_y ,
317- in_stride_z , in_stride_group , out , out_stride_z ,
318- out_stride_group , w ))
329+ grouped_convolution_kernel_y (
330+ grid_dim ,
331+ block_dim ,
332+ (
333+ dim_x ,
334+ dim_y ,
335+ dim_z ,
336+ x ,
337+ in_stride_x ,
338+ in_stride_y ,
339+ in_stride_z ,
340+ in_stride_group ,
341+ out ,
342+ out_stride_z ,
343+ out_stride_group ,
344+ w ,
345+ ),
346+ )
319347 del w
320348 return out
321349
@@ -353,7 +381,10 @@ def _conv_transpose2d(
353381 out = cp .zeros (out_shape , dtype = "float32" )
354382 w = cp .asarray (w )
355383
356- symbol_names = [f"transposed_convolution_x<{ wk } >" , f"transposed_convolution_y<{ hk } >" ]
384+ symbol_names = [
385+ f"transposed_convolution_x<{ wk } >" ,
386+ f"transposed_convolution_y<{ hk } >" ,
387+ ]
357388 module = load_cuda_module ("remove_stripe_fw" , name_expressions = symbol_names )
358389 dim_x = out .shape [- 1 ]
359390 dim_y = out .shape [- 2 ]
@@ -370,16 +401,20 @@ def _conv_transpose2d(
370401
371402 if wk > 1 :
372403 transposed_convolution_kernel_x = module .get_function (symbol_names [0 ])
373- transposed_convolution_kernel_x (grid_dim , block_dim ,
374- (dim_x , dim_y , dim_z , x ,
375- in_dim_x , in_stride_y , in_stride_z , w , out ))
404+ transposed_convolution_kernel_x (
405+ grid_dim ,
406+ block_dim ,
407+ (dim_x , dim_y , dim_z , x , in_dim_x , in_stride_y , in_stride_z , w , out ),
408+ )
376409 elif hk > 1 :
377410 transposed_convolution_kernel_y = module .get_function (symbol_names [1 ])
378- transposed_convolution_kernel_y (grid_dim , block_dim ,
379- (dim_x , dim_y , dim_z , x ,
380- in_dim_y , in_stride_y , in_stride_z , w , out ))
411+ transposed_convolution_kernel_y (
412+ grid_dim ,
413+ block_dim ,
414+ (dim_x , dim_y , dim_z , x , in_dim_y , in_stride_y , in_stride_z , w , out ),
415+ )
381416 else :
382- assert ( False )
417+ assert False
383418
384419 if pad != 0 :
385420 out = out [:, :, pad [0 ] : out .shape [2 ] - pad [0 ], pad [1 ] : out .shape [3 ] - pad [1 ]]
@@ -452,12 +487,8 @@ def _sfb1d(
452487 g0 = np .concatenate ([g0 .reshape (* shape )] * C , axis = 0 )
453488 g1 = np .concatenate ([g1 .reshape (* shape )] * C , axis = 0 )
454489 pad = (L - 2 , 0 ) if d == 2 else (0 , L - 2 )
455- y_lo = _conv_transpose2d (
456- lo , g0 , stride = s , pad = pad , groups = C , mem_stack = mem_stack
457- )
458- y_hi = _conv_transpose2d (
459- hi , g1 , stride = s , pad = pad , groups = C , mem_stack = mem_stack
460- )
490+ y_lo = _conv_transpose2d (lo , g0 , stride = s , pad = pad , groups = C , mem_stack = mem_stack )
491+ y_hi = _conv_transpose2d (hi , g1 , stride = s , pad = pad , groups = C , mem_stack = mem_stack )
461492 if mem_stack :
462493 # Allocation of the sum
463494 mem_stack .malloc (np .prod (y_hi ) * np .float32 ().itemsize )
@@ -600,7 +631,7 @@ def remove_stripe_fw(
600631 sigma : float
601632 Damping parameter in Fourier space.
602633 wname : str
603- Type of the wavelet filter. 'haar', 'db5 ', sym5', 'bior4.4', etc .
634+ Type of the wavelet filter: select from 'haar', 'db4 ', ' sym5', 'sym16' ' bior4.4'.
604635 level : int, optional
605636 Number of discrete wavelet transform levels.
606637 calc_peak_gpu_mem: str:
0 commit comments