Skip to content

Commit 85dcd33

Browse files
Add int overflow assert to PrefixSum (#4794)
The assert was already implemented for some backends, but not for the CUB and rocPRIM ones. Additionally, `nelms_per_block * virtual_block_id` could overflow before this change.
1 parent 35bdc36 commit 85dcd33

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

Src/Base/AMReX_Scan.H

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,8 @@ T PrefixSum_mp (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum)
197197
constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
198198
constexpr int nchunks = 12;
199199
constexpr int nelms_per_block = nthreads * nchunks;
200-
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
200+
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(
201+
std::numeric_limits<int>::max())*nelms_per_block);
201202
int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
202203
std::size_t sm = sizeof(T) * (Gpu::Device::warp_size + nwarps_per_block);
203204
auto stream = Gpu::gpuStream();
@@ -228,7 +229,7 @@ T PrefixSum_mp (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum)
228229
T* shared2 = shared + Gpu::Device::warp_size;
229230

230231
// Each block processes [ibegin,iend).
231-
N ibegin = nelms_per_block * blockIdxx;
232+
N ibegin = static_cast<N>(nelms_per_block) * blockIdxx;
232233
N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
233234

234235
// Each block is responsible for nchunks chunks of data,
@@ -366,7 +367,7 @@ T PrefixSum_mp (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum)
366367
int blockDimx = gh.item->get_local_range(0);
367368

368369
// Each block processes [ibegin,iend).
369-
N ibegin = nelms_per_block * blockIdxx;
370+
N ibegin = static_cast<N>(nelms_per_block) * blockIdxx;
370371
N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
371372
T prev_sum = (blockIdxx == 0) ? 0 : blocksum_p[blockIdxx-1];
372373
for (N offset = ibegin + threadIdxx; offset < iend; offset += blockDimx) {
@@ -398,7 +399,8 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE type, RetSum a_ret_sum = retSum
398399
constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
399400
constexpr int nchunks = 12;
400401
constexpr int nelms_per_block = nthreads * nchunks;
401-
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
402+
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(
403+
std::numeric_limits<int>::max())*nelms_per_block);
402404
int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
403405

404406
#ifndef AMREX_SYCL_NO_MULTIPASS_SCAN
@@ -462,7 +464,7 @@ T PrefixSum (N n, FIN && fin, FOUT && fout, TYPE type, RetSum a_ret_sum = retSum
462464
}
463465

464466
// Each block processes [ibegin,iend).
465-
N ibegin = nelms_per_block * virtual_block_id;
467+
N ibegin = static_cast<N>(nelms_per_block) * virtual_block_id;
466468
N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
467469
BlockStatusT& block_status = block_status_p[virtual_block_id];
468470

@@ -637,6 +639,8 @@ T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = ret
637639
constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size; // # of threads per block
638640
constexpr int nelms_per_thread = sizeof(T) >= 8 ? 8 : 16;
639641
constexpr int nelms_per_block = nthreads * nelms_per_thread;
642+
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(
643+
std::numeric_limits<int>::max())*nelms_per_block);
640644
int nblocks = (n + nelms_per_block - 1) / nelms_per_block;
641645
std::size_t sm = 0;
642646
auto stream = Gpu::gpuStream();
@@ -713,7 +717,7 @@ T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = ret
713717
auto const virtual_block_id = scan_bid.get(threadIdx.x, temp_storage.ordered_bid);
714718

715719
// Each block processes [ibegin,iend).
716-
N ibegin = nelms_per_block * virtual_block_id;
720+
N ibegin = static_cast<N>(nelms_per_block) * virtual_block_id;
717721
N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
718722

719723
auto input_begin = rocprim::make_transform_iterator(
@@ -800,6 +804,8 @@ T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = ret
800804
constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size; // # of threads per block
801805
constexpr int nelms_per_thread = sizeof(T) >= 8 ? 4 : 8;
802806
constexpr int nelms_per_block = nthreads * nelms_per_thread;
807+
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(
808+
std::numeric_limits<int>::max())*nelms_per_block);
803809
int nblocks = (n + nelms_per_block - 1) / nelms_per_block;
804810
std::size_t sm = 0;
805811
auto stream = Gpu::gpuStream();
@@ -854,7 +860,7 @@ T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = ret
854860
int virtual_block_id = blockIdx.x;
855861

856862
// Each block processes [ibegin,iend).
857-
N ibegin = nelms_per_block * virtual_block_id;
863+
N ibegin = static_cast<N>(nelms_per_block) * virtual_block_id;
858864
N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
859865

860866
auto input_lambda = [&] (N i) -> T { return fin(i+ibegin); };
@@ -944,7 +950,8 @@ T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = ret
944950
constexpr int nthreads = nwarps_per_block*Gpu::Device::warp_size;
945951
constexpr int nchunks = 12;
946952
constexpr int nelms_per_block = nthreads * nchunks;
947-
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(std::numeric_limits<int>::max())*nelms_per_block);
953+
AMREX_ALWAYS_ASSERT(static_cast<Long>(n) < static_cast<Long>(
954+
std::numeric_limits<int>::max())*nelms_per_block);
948955
int nblocks = (static_cast<Long>(n) + nelms_per_block - 1) / nelms_per_block;
949956
std::size_t sm = sizeof(T) * (Gpu::Device::warp_size + nwarps_per_block) + sizeof(int);
950957
auto stream = Gpu::gpuStream();
@@ -997,7 +1004,7 @@ T PrefixSum (N n, FIN const& fin, FOUT const& fout, TYPE, RetSum a_ret_sum = ret
9971004
}
9981005

9991006
// Each block processes [ibegin,iend).
1000-
N ibegin = nelms_per_block * virtual_block_id;
1007+
N ibegin = static_cast<N>(nelms_per_block) * virtual_block_id;
10011008
N iend = amrex::min(static_cast<N>(ibegin+nelms_per_block), n);
10021009
BlockStatusT& block_status = block_status_p[virtual_block_id];
10031010

0 commit comments

Comments
 (0)