Skip to content

Commit 6531dab

Browse files
committed
Add test_remove_stripe_fw_performance
1 parent 8bd38ef commit 6531dab

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/test_prep/test_stripe.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,31 @@ def test_remove_stripe_fw_calc_mem_big(wname, slices, level, dims, ensure_clean_
137137
assert estimated_mem_peak <= actual_mem_peak * 1.3
138138

139139

140+
@pytest.mark.perf
def test_remove_stripe_fw_performance(ensure_clean_memory):
    """Benchmark ``remove_stripe_fw`` on a random float32 volume.

    One untimed warm-up call is made first, then the average wall-clock time
    of ``n_runs`` further calls is measured in milliseconds. The device is
    synchronized before the timer starts and again before it stops.
    """
    n_runs = 10

    # Random float32 input of shape (1801, 5, 2560) with values in [0.001, 2.001).
    host_volume = (
        np.random.random_sample(size=(1801, 5, 2560)).astype(np.float32) * 2.0 + 0.001
    )
    device_volume = cp.asarray(host_volume, dtype=np.float32)

    # Warm-up (cold) call, excluded from the measurement.
    remove_stripe_fw(cp.copy(device_volume))

    gpu = cp.cuda.Device()
    gpu.synchronize()

    t0 = time.perf_counter_ns()
    nvtx.RangePush("Core")
    for _ in range(n_runs):
        # A fresh copy is passed each time because the input is modified in-place.
        remove_stripe_fw(cp.copy(device_volume))
    nvtx.RangePop()
    gpu.synchronize()
    elapsed_ns = time.perf_counter_ns() - t0
    duration_ms = float(elapsed_ns) * 1e-6 / n_runs

    # NOTE(review): a str never equals a float, so this assertion always fails.
    # Presumably this is deliberate, so that pytest's failure report shows the
    # measured time for perf-marked tests -- confirm before changing it.
    assert "performance in ms" == duration_ms
163+
164+
140165
@pytest.mark.parametrize("angles", [180, 181])
141166
@pytest.mark.parametrize("det_x", [11, 18])
142167
@pytest.mark.parametrize("det_y", [5, 7, 8])

0 commit comments

Comments
 (0)