Skip to content

Commit c849133

Browse files
dkazancmfep
authored andcommitted
minor correction to wname description
1 parent 33b0f02 commit c849133

File tree

2 files changed

+59
-24
lines changed

2 files changed

+59
-24
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -305,17 +305,45 @@ def _conv2d(
305305

306306
if groups == 1:
307307
grouped_convolution_kernel_x = module.get_function(symbol_names[0])
308-
grouped_convolution_kernel_x(grid_dim, block_dim,
309-
(dim_x, dim_y, dim_z, x, in_stride_x, in_stride_y,
310-
in_stride_z, out, out_stride_z, out_stride_group, w))
308+
grouped_convolution_kernel_x(
309+
grid_dim,
310+
block_dim,
311+
(
312+
dim_x,
313+
dim_y,
314+
dim_z,
315+
x,
316+
in_stride_x,
317+
in_stride_y,
318+
in_stride_z,
319+
out,
320+
out_stride_z,
321+
out_stride_group,
322+
w,
323+
),
324+
)
311325
return out
312326

313327
grouped_convolution_kernel_y = module.get_function(symbol_names[1])
314328
in_stride_group = x.strides[2] // x.dtype.itemsize
315-
grouped_convolution_kernel_y(grid_dim, block_dim,
316-
(dim_x, dim_y, dim_z, x, in_stride_x, in_stride_y,
317-
in_stride_z, in_stride_group, out, out_stride_z,
318-
out_stride_group, w))
329+
grouped_convolution_kernel_y(
330+
grid_dim,
331+
block_dim,
332+
(
333+
dim_x,
334+
dim_y,
335+
dim_z,
336+
x,
337+
in_stride_x,
338+
in_stride_y,
339+
in_stride_z,
340+
in_stride_group,
341+
out,
342+
out_stride_z,
343+
out_stride_group,
344+
w,
345+
),
346+
)
319347
del w
320348
return out
321349

@@ -353,7 +381,10 @@ def _conv_transpose2d(
353381
out = cp.zeros(out_shape, dtype="float32")
354382
w = cp.asarray(w)
355383

356-
symbol_names = [f"transposed_convolution_x<{wk}>", f"transposed_convolution_y<{hk}>"]
384+
symbol_names = [
385+
f"transposed_convolution_x<{wk}>",
386+
f"transposed_convolution_y<{hk}>",
387+
]
357388
module = load_cuda_module("remove_stripe_fw", name_expressions=symbol_names)
358389
dim_x = out.shape[-1]
359390
dim_y = out.shape[-2]
@@ -370,16 +401,20 @@ def _conv_transpose2d(
370401

371402
if wk > 1:
372403
transposed_convolution_kernel_x = module.get_function(symbol_names[0])
373-
transposed_convolution_kernel_x(grid_dim, block_dim,
374-
(dim_x, dim_y, dim_z, x,
375-
in_dim_x, in_stride_y, in_stride_z, w, out))
404+
transposed_convolution_kernel_x(
405+
grid_dim,
406+
block_dim,
407+
(dim_x, dim_y, dim_z, x, in_dim_x, in_stride_y, in_stride_z, w, out),
408+
)
376409
elif hk > 1:
377410
transposed_convolution_kernel_y = module.get_function(symbol_names[1])
378-
transposed_convolution_kernel_y(grid_dim, block_dim,
379-
(dim_x, dim_y, dim_z, x,
380-
in_dim_y, in_stride_y, in_stride_z, w, out))
411+
transposed_convolution_kernel_y(
412+
grid_dim,
413+
block_dim,
414+
(dim_x, dim_y, dim_z, x, in_dim_y, in_stride_y, in_stride_z, w, out),
415+
)
381416
else:
382-
assert(False)
417+
assert False
383418

384419
if pad != 0:
385420
out = out[:, :, pad[0] : out.shape[2] - pad[0], pad[1] : out.shape[3] - pad[1]]
@@ -452,12 +487,8 @@ def _sfb1d(
452487
g0 = np.concatenate([g0.reshape(*shape)] * C, axis=0)
453488
g1 = np.concatenate([g1.reshape(*shape)] * C, axis=0)
454489
pad = (L - 2, 0) if d == 2 else (0, L - 2)
455-
y_lo = _conv_transpose2d(
456-
lo, g0, stride=s, pad=pad, groups=C, mem_stack=mem_stack
457-
)
458-
y_hi = _conv_transpose2d(
459-
hi, g1, stride=s, pad=pad, groups=C, mem_stack=mem_stack
460-
)
490+
y_lo = _conv_transpose2d(lo, g0, stride=s, pad=pad, groups=C, mem_stack=mem_stack)
491+
y_hi = _conv_transpose2d(hi, g1, stride=s, pad=pad, groups=C, mem_stack=mem_stack)
461492
if mem_stack:
462493
# Allocation of the sum
463494
mem_stack.malloc(np.prod(y_hi) * np.float32().itemsize)
@@ -600,7 +631,7 @@ def remove_stripe_fw(
600631
sigma : float
601632
Damping parameter in Fourier space.
602633
wname : str
603-
Type of the wavelet filter. 'haar', 'db5', sym5', 'bior4.4', etc.
634+
Type of the wavelet filter: select from 'haar', 'db4', 'sym5', 'sym16' 'bior4.4'.
604635
level : int, optional
605636
Number of discrete wavelet transform levels.
606637
calc_peak_gpu_mem: str:

tests/test_prep/test_stripe.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,19 @@ def test_remove_stripe_fw_calc_mem(slices, level, dim_x, wname, ensure_clean_mem
109109

110110

111111
@pytest.mark.parametrize("wname", ["haar", "db4", "sym5", "sym16", "bior4.4"])
112-
@pytest.mark.parametrize("slices", [177, 239, 320, 490, 607, 803, 859, 902, 951, 1019, 1074, 1105])
112+
@pytest.mark.parametrize(
113+
"slices", [177, 239, 320, 490, 607, 803, 859, 902, 951, 1019, 1074, 1105]
114+
)
113115
@pytest.mark.parametrize("level", [None, 7, 11])
114116
def test_remove_stripe_fw_calc_mem_big(wname, slices, level, ensure_clean_memory):
115117
dim_y = 901
116118
dim_x = 1200
117119
data_shape = (slices, dim_x, dim_y)
118120
hook = MaxMemoryHook()
119121
with hook:
120-
estimated_mem_peak = remove_stripe_fw(data_shape, wname=wname, level=level, calc_peak_gpu_mem=True)
122+
estimated_mem_peak = remove_stripe_fw(
123+
data_shape, wname=wname, level=level, calc_peak_gpu_mem=True
124+
)
121125
assert hook.max_mem == 0
122126
av_mem = cp.cuda.Device().mem_info[0]
123127
if av_mem < estimated_mem_peak:

0 commit comments

Comments
 (0)