Skip to content

Commit 9cafe29

Browse files
author
Yihan Wang
authored
[SYCLomatic #719] Add test for 10 CUB APIs (#274)
Signed-off-by: Wang, Yihan <[email protected]>
1 parent b1bfa0e commit 9cafe29

File tree

7 files changed

+1038
-32
lines changed

7 files changed

+1038
-32
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#include <cub/cub.cuh>
2+
3+
template <typename T> T *init(std::initializer_list<T> list) {
4+
T *p = nullptr;
5+
cudaMalloc<T>(&p, sizeof(T) * list.size());
6+
cudaMemcpy(p, list.begin(), sizeof(T) * list.size(), cudaMemcpyHostToDevice);
7+
return p;
8+
}
9+
10+
bool test_arg_max() {
11+
int num_items = 7;
12+
int *d_in = init({8, 6, 7, 5, 3, 0, 9});
13+
cub::KeyValuePair<int, int> *d_out =
14+
init<cub::KeyValuePair<int, int>>({{-1, -1}}),
15+
out;
16+
void *d_temp_storage = NULL;
17+
size_t temp_storage_bytes = 0;
18+
cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_out,
19+
num_items);
20+
cudaMalloc(&d_temp_storage, temp_storage_bytes);
21+
cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_out,
22+
num_items);
23+
cudaFree(d_temp_storage);
24+
cudaMemcpy(&out, d_out, sizeof(out), cudaMemcpyDeviceToHost);
25+
return out.key == 6 && out.value == 9;
26+
}
27+
28+
bool test_arg_max_non_defaule_stream() {
29+
int num_items = 7;
30+
int *d_in = init({8, 6, 7, 5, 3, 0, 9});
31+
cub::KeyValuePair<int, int> *d_out =
32+
init<cub::KeyValuePair<int, int>>({{-1, -1}}),
33+
out;
34+
cudaStream_t s = nullptr;
35+
cudaStreamCreate(&s);
36+
void *d_temp_storage = NULL;
37+
size_t temp_storage_bytes = 0;
38+
cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_out,
39+
num_items, s);
40+
cudaMalloc(&d_temp_storage, temp_storage_bytes);
41+
cub::DeviceReduce::ArgMax(d_temp_storage, temp_storage_bytes, d_in, d_out,
42+
num_items, s);
43+
cudaFree(d_temp_storage);
44+
cudaMemcpy(&out, d_out, sizeof(out), cudaMemcpyDeviceToHost);
45+
cudaStreamDestroy(s);
46+
return out.key == 6 && out.value == 9;
47+
}
48+
49+
bool test_arg_min() {
50+
int num_items = 7;
51+
int *d_in = init({8, 6, 7, 5, 3, 0, 9});
52+
cub::KeyValuePair<int, int> *d_out =
53+
init<cub::KeyValuePair<int, int>>({{-1, -1}}),
54+
out;
55+
void *d_temp_storage = NULL;
56+
size_t temp_storage_bytes = 0;
57+
cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_out,
58+
num_items);
59+
cudaMalloc(&d_temp_storage, temp_storage_bytes);
60+
cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_out,
61+
num_items);
62+
cudaFree(d_temp_storage);
63+
cudaMemcpy(&out, d_out, sizeof(out), cudaMemcpyDeviceToHost);
64+
return out.key == 5 && out.value == 0;
65+
}
66+
67+
bool test_arg_min_non_default_stream() {
68+
int num_items = 7;
69+
int *d_in = init({8, 6, 7, 5, 3, 0, 9});
70+
cub::KeyValuePair<int, int> *d_out =
71+
init<cub::KeyValuePair<int, int>>({{-1, -1}}),
72+
out;
73+
cudaStream_t s = nullptr;
74+
cudaStreamCreate(&s);
75+
void *d_temp_storage = NULL;
76+
size_t temp_storage_bytes = 0;
77+
cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_out,
78+
num_items, s);
79+
cudaMalloc(&d_temp_storage, temp_storage_bytes);
80+
cub::DeviceReduce::ArgMin(d_temp_storage, temp_storage_bytes, d_in, d_out,
81+
num_items, s);
82+
cudaFree(d_temp_storage);
83+
cudaMemcpy(&out, d_out, sizeof(out), cudaMemcpyDeviceToHost);
84+
cudaStreamDestroy(s);
85+
return out.key == 5 && out.value == 0;
86+
}
87+
88+
int main() {
89+
int res = 0;
90+
if (!test_arg_max()) {
91+
res = 1;
92+
std::cout << "cub::DeviceReduce::ArgMax test failed\n";
93+
}
94+
95+
if (!test_arg_max_non_defaule_stream()) {
96+
res = 1;
97+
std::cout << "cub::DeviceReduce::ArgMax(Non default stream) test failed\n";
98+
}
99+
100+
if (!test_arg_min()) {
101+
res = 1;
102+
std::cout << "cub::DeviceReduce::ArgMin test failed\n";
103+
}
104+
105+
if (!test_arg_min_non_default_stream()) {
106+
res = 1;
107+
std::cout << "cub::DeviceReduce::ArgMin(Non default stream) test failed\n";
108+
}
109+
110+
return res;
111+
}

0 commit comments

Comments
 (0)