Skip to content

Commit 33b0f02

Browse files
committed
Update mem estimator and tests for very large sizes
1 parent 3d9d1f1 commit 33b0f02

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

httomolibgpu/prep/stripe.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -337,8 +337,7 @@ def _conv_transpose2d(
337337
out_shape = [b, ci, hi, wi]
338338
if mem_stack:
339339
# The trouble here is that we allocate more than the returned size
340-
out_actual_bytes = np.prod(out_shape) * np.float32().itemsize
341-
mem_stack.malloc(out_actual_bytes)
340+
mem_stack.malloc((np.prod(out_shape) + w.size) * np.float32().itemsize)
342341
if pad != 0:
343342
new_out_shape = [
344343
out_shape[0],
@@ -347,7 +346,7 @@ def _conv_transpose2d(
347346
out_shape[3] - 2 * pad[1],
348347
]
349348
mem_stack.malloc(np.prod(new_out_shape) * np.float32().itemsize)
350-
mem_stack.free(np.prod(out_shape) * np.float32().itemsize)
349+
mem_stack.free((np.prod(out_shape) + w.size) * np.float32().itemsize)
351350
out_shape = new_out_shape
352351
return out_shape
353352

@@ -673,7 +672,7 @@ def remove_stripe_fw(
673672
mem_stack.free(np.prod(c) * np.float32().itemsize)
674673
mem_stack.malloc(np.prod(data) * np.float32().itemsize)
675674
mem_stack.free(np.prod(sli_shape) * np.float32().itemsize)
676-
return mem_stack.highwater
675+
return int(mem_stack.highwater * 1.1)
677676

678677
sli = cp.zeros(sli_shape, dtype="float32")
679678
sli[:, 0, (nproj_pad - nproj) // 2 : (nproj_pad + nproj) // 2] = data.swapaxes(0, 1)

tests/test_prep/test_stripe.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def ensure_clean_memory():
8686

8787

8888
@pytest.mark.parametrize("wname", ["haar", "db4", "sym5", "sym16", "bior4.4"])
89-
@pytest.mark.parametrize("slices", [55, 80])
90-
@pytest.mark.parametrize("level", [None, 1, 3, 7, 11])
89+
@pytest.mark.parametrize("slices", [3, 7, 32, 61, 109, 120, 150])
90+
@pytest.mark.parametrize("level", [None, 1, 3, 11])
9191
@pytest.mark.parametrize("dim_x", [128, 140])
9292
def test_remove_stripe_fw_calc_mem(slices, level, dim_x, wname, ensure_clean_memory):
9393
dim_y = 159
@@ -105,31 +105,32 @@ def test_remove_stripe_fw_calc_mem(slices, level, dim_x, wname, ensure_clean_mem
105105
assert hook.max_mem == 0
106106

107107
assert actual_mem_peak * 0.99 <= estimated_mem_peak
108-
assert estimated_mem_peak <= actual_mem_peak * 1.2
108+
assert estimated_mem_peak <= actual_mem_peak * 1.3
109109

110110

111-
@pytest.mark.parametrize("wname", ['db4', 'sym16'])
111+
@pytest.mark.parametrize("wname", ["haar", "db4", "sym5", "sym16", "bior4.4"])
112112
@pytest.mark.parametrize("slices", [177, 239, 320, 490, 607, 803, 859, 902, 951, 1019, 1074, 1105])
113-
def test_remove_stripe_fw_calc_mem_big(wname, slices, ensure_clean_memory):
113+
@pytest.mark.parametrize("level", [None, 7, 11])
114+
def test_remove_stripe_fw_calc_mem_big(wname, slices, level, ensure_clean_memory):
114115
dim_y = 901
115116
dim_x = 1200
116117
data_shape = (slices, dim_x, dim_y)
117118
hook = MaxMemoryHook()
118119
with hook:
119-
estimated_mem_peak = remove_stripe_fw(data_shape, wname=wname, calc_peak_gpu_mem=True)
120+
estimated_mem_peak = remove_stripe_fw(data_shape, wname=wname, level=level, calc_peak_gpu_mem=True)
120121
assert hook.max_mem == 0
121122
av_mem = cp.cuda.Device().mem_info[0]
122-
if av_mem < estimated_mem_peak * 1.1:
123+
if av_mem < estimated_mem_peak:
123124
pytest.skip("Not enough GPU memory to run this test")
124125

125126
hook = MaxMemoryHook()
126127
with hook:
127128
data = cp.random.random_sample(data_shape, dtype=np.float32)
128-
remove_stripe_fw(data, wname=wname)
129+
remove_stripe_fw(data, wname=wname, level=level)
129130
actual_mem_peak = hook.max_mem
130131

131132
assert actual_mem_peak * 0.99 <= estimated_mem_peak
132-
assert estimated_mem_peak <= actual_mem_peak * 1.2
133+
assert estimated_mem_peak <= actual_mem_peak * 1.3
133134

134135

135136
@pytest.mark.parametrize("angles", [180, 181])

0 commit comments

Comments
 (0)