Skip to content

Commit ecab0d7

Browse files
author
Yihan Wang
authored
[SYCLomatic #991] Add test for cub::LaneId and cub::WarpId (#369)
Signed-off-by: Wang, Yihan <[email protected]>
1 parent de04571 commit ecab0d7

File tree

1 file changed

+51
-3
lines changed

1 file changed

+51
-3
lines changed

features/feature_case/cub/cub_intrinsic.cu

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,57 @@ bool test_iadd3() {
3131
iadd3(0, 1, 0);
3232
}
3333

34-
int main() {
35-
if (!test_iadd3()) {
36-
return 1;
34+
// Records, for every launched thread, the values returned by cub::LaneId()
// and cub::WarpId().
//
// Launch layout: the flat output index is built from the x/y components of
// the grid and block only (z components are not used).
// Preconditions: `laneids` and `warpids` must each hold at least
// gridDim.x * gridDim.y * blockDim.x * blockDim.y ints; the kernel performs
// no bounds check of its own.
__global__ void laneid_and_warpid(int *laneids, int *warpids) {
  // Rank of this block inside the (x, y) grid, row-major in y.
  const unsigned block_rank = blockIdx.y * gridDim.x + blockIdx.x;
  // Number of threads per block over the x/y extents.
  const unsigned block_size = blockDim.x * blockDim.y;
  // Rank of this thread inside its block, row-major in y.
  const unsigned thread_rank = threadIdx.y * blockDim.x + threadIdx.x;
  // Unique slot for this thread in the flat output arrays.
  const unsigned slot = block_rank * block_size + thread_rank;

  laneids[slot] = cub::LaneId();
  warpids[slot] = cub::WarpId();
}
41+
42+
// Host-side test for cub::LaneId() and cub::WarpId().
//
// Launches laneid_and_warpid with 2 blocks of 33 threads (66 threads total),
// copies the recorded ids back and validates their distribution:
//   * warp ids: only the number of recorded samples is checked. The sum of
//     the histogram counts always equals the sample count, so this is a
//     sanity check on the bookkeeping rather than on the warp-id values.
//   * lane ids: the histogram of "how many times each lane id occurred" must
//     contain exactly two distinct occurrence counts that differ by 2.
//     NOTE(review): presumably this encodes 32-lane warps, where each
//     33-thread block has one full warp plus one single-thread warp, so lane 0
//     is seen 4 times and every other lane twice -- confirm for the target
//     device.
//
// Returns true when every CUDA call succeeds and both checks pass, false
// otherwise. Device buffers are released on every path.
bool test_laneid_warpid() {
  constexpr int kBlocks = 2;
  constexpr int kThreadsPerBlock = 33;
  constexpr int kTotalThreads = kBlocks * kThreadsPerBlock;
  constexpr size_t kBytes = sizeof(int) * kTotalThreads;

  int *d_laneids = nullptr;
  int *d_warpids = nullptr;
  int laneids[kTotalThreads] = {0};
  int warpids[kTotalThreads] = {0};

  // Short-circuit evaluation stops at the first failing runtime call.
  bool ok = cudaMalloc(&d_laneids, kBytes) == cudaSuccess &&
            cudaMalloc(&d_warpids, kBytes) == cudaSuccess;
  if (ok) {
    laneid_and_warpid<<<kBlocks, kThreadsPerBlock>>>(d_laneids, d_warpids);
    // cudaGetLastError reports launch-configuration errors; the synchronize
    // reports faults raised while the kernel was executing. The memcpy calls
    // are blocking, so no further synchronization is required after them.
    ok = cudaGetLastError() == cudaSuccess &&
         cudaDeviceSynchronize() == cudaSuccess &&
         cudaMemcpy(laneids, d_laneids, kBytes, cudaMemcpyDeviceToHost) ==
             cudaSuccess &&
         cudaMemcpy(warpids, d_warpids, kBytes, cudaMemcpyDeviceToHost) ==
             cudaSuccess;
  }

  // cudaFree accepts a null pointer, so this is safe after a failed malloc.
  cudaFree(d_laneids);
  cudaFree(d_warpids);
  if (!ok)
    return false;

  // Histogram of the observed warp ids and lane ids.
  std::map<int, int> cnt_laneid, cnt_warpid, cnt_laneid_num;
  for (int i = 0; i < kTotalThreads; ++i) {
    cnt_warpid[warpids[i]]++;
    cnt_laneid[laneids[i]]++;
  }

  // Total number of warp-id samples that were recorded.
  int total_warpid = 0;
  for (const auto &entry : cnt_warpid)
    total_warpid += entry.second;
  if (total_warpid != kTotalThreads)
    return false;

  // Maps "occurrence count of a lane id" -> "number of lane ids with that
  // count".
  for (const auto &entry : cnt_laneid)
    cnt_laneid_num[entry.second]++;

  if (cnt_laneid_num.size() != 2)
    return false;
  // std::map iterates in ascending key order: smaller occurrence count first.
  const auto fewer = cnt_laneid_num.begin();
  const auto more = std::next(fewer);
  return fewer->first + 2 == more->first;
}
76+
77+
// Runs the zero-argument test FUNC. When FUNC() evaluates to false, prints
// "<FUNC> failed" and returns 1 from the enclosing function, so the macro
// may only be used inside a function whose return type accepts `1`.
//
// The do { ... } while (0) wrapper makes the expansion a single statement
// that requires the trailing semicolon, so `TEST(f);` is safe inside
// unbraced if/else bodies.
#define TEST(FUNC)                                                             \
  do {                                                                         \
    if (!FUNC()) {                                                             \
      printf(#FUNC " failed\n");                                               \
      return 1;                                                                \
    }                                                                          \
  } while (0)
82+
83+
// Test driver: runs each test case in order. TEST returns 1 from main as soon
// as a case reports failure; reaching the end means every case passed.
int main() {
  TEST(test_iadd3);
  TEST(test_laneid_warpid);

  return 0;
}

0 commit comments

Comments
 (0)