|
7 | 7 | //
|
8 | 8 | // ===----------------------------------------------------------------------===//
|
9 | 9 |
|
| 10 | +#include <algorithm> |
10 | 11 | #include <iostream>
|
11 | 12 | #include <vector>
|
12 | 13 |
|
|
15 | 16 |
|
16 | 17 | #define DATA_NUM 100
|
17 | 18 |
|
// Allocates a CUDA managed buffer with cudaMallocManaged and fills it with a
// byte-wise copy of the elements of `init`.
//
// Returns a pointer to the new buffer, or nullptr when `init` is empty or the
// allocation fails (the failure is reported on stderr). The buffer is not
// released here; the caller is expected to cudaFree it.
template<typename T = int>
T *init_data(std::initializer_list<T> init) {
  // A zero-byte request cannot produce a usable buffer, so skip the allocation
  // instead of handing memcpy a null destination.
  if (init.size() == 0)
    return nullptr;

  const size_t Bytes = sizeof(T) * init.size();
  T *Ptr = nullptr;
  if (cudaMallocManaged(&Ptr, Bytes) != cudaSuccess || Ptr == nullptr) {
    std::cerr << "init_data: cudaMallocManaged failed\n";
    return nullptr;
  }

  // The initializer_list storage is contiguous, so one raw copy is enough.
  memcpy(Ptr, init.begin(), Bytes);
  return Ptr;
}
| 26 | + |
18 | 27 | template<typename T = int>
|
19 | 28 | void init_data(T* data, int num) {
|
20 | 29 | T host_data[DATA_NUM];
|
@@ -248,13 +257,88 @@ bool test_max(){
|
248 | 257 | return true;
|
249 | 258 | }
|
250 | 259 |
|
// Stream inserter for cub::KeyValuePair<int, int>: writes the pair as
// "[key, value]" and returns the stream so insertions can be chained.
std::ostream &operator<<(std::ostream &os, const cub::KeyValuePair<int, int> &kv) {
  return os << '[' << kv.key << ", " << kv.value << ']';
}
| 264 | + |
// Runs cub::DeviceSegmentedReduce::ArgMin over three segments of `in`
// delimited by `offset` ([0, 3), the empty segment [3, 3) and [3, 7)) and
// compares every output pair with `expected`.
//
// Returns true when all pairs match. On a mismatch it prints the expected
// pairs followed by the pairs actually produced and returns false. All
// buffers allocated by the test are released before returning.
bool test_arg_min() {
  int num_segs = 3;
  int *offset = init_data({0, 3, 3, 7});
  int *in = init_data({8, 6, 7, 5, 3, 0, 9});

  cub::KeyValuePair<int, int> *out = init_data<cub::KeyValuePair<int, int>>({{}, {}, {}});
  // One {key, value} pair per segment: the key is the position of the minimum
  // inside its segment, the value is the minimum itself. The empty middle
  // segment is expected to produce {1, INT_MAX}.
  cub::KeyValuePair<int, int> expected[] = {{1, 6}, {1, INT_MAX}, {2, 0}};

  // CHECK-DPCT1026 DPCT1026:{{.*}}: The call to cub::DeviceSegmentedReduce::ArgMin was removed because this call is redundant in SYCL.
  // CHECK: dpct::segmented_reduce_argmin(oneapi::dpl::execution::device_policy(q_ct1), in, out, num_segs, offset, offset + 1);
  void *tmp_storage = nullptr;
  size_t tmp_storage_size = 0;
  // First call is made with a null tmp_storage; the size it reports in
  // tmp_storage_size is then used to allocate the scratch buffer.
  cub::DeviceSegmentedReduce::ArgMin(tmp_storage, tmp_storage_size, in, out, num_segs, offset, offset + 1);
  cudaMalloc(&tmp_storage, tmp_storage_size);
  cub::DeviceSegmentedReduce::ArgMin(tmp_storage, tmp_storage_size, in, out, num_segs, offset, offset + 1);
  // Wait for the reduction to finish before the host reads `out`.
  cudaDeviceSynchronize();

  auto cmp = [](const cub::KeyValuePair<int, int> &lhs, const cub::KeyValuePair<int, int> &rhs) -> bool {
    return lhs.value == rhs.value && lhs.key == rhs.key;
  };

  const bool ok = std::equal(out, out + num_segs, expected, cmp);
  if (!ok) {
    std::cout << "ArgMin verify failed!\n";
    std::cout << "expect: ";
    std::for_each(expected, expected + num_segs, [](const auto &v) { std::cout << v << " "; });
    std::cout << "\n";
    std::cout<< "current result: ";
    // Print the pairs actually produced, not the expectation again.
    std::for_each(out, out + num_segs, [](const auto &v) { std::cout << v << " "; });
    std::cout << "\n";
  }

  // Release every allocation made by this test; cudaFree accepts null.
  cudaFree(tmp_storage);
  cudaFree(out);
  cudaFree(in);
  cudaFree(offset);
  return ok;
}
| 298 | + |
// Runs cub::DeviceSegmentedReduce::ArgMax over three segments of `in`
// delimited by `offset` ([0, 3), the empty segment [3, 3) and [3, 7)) and
// compares every output pair with `expected`.
//
// Returns true when all pairs match. On a mismatch it prints the expected
// pairs followed by the pairs actually produced and returns false. All
// buffers allocated by the test are released before returning.
bool test_arg_max() {
  int num_segs = 3;
  int *offset = init_data({0, 3, 3, 7});
  int *in = init_data({8, 6, 7, 5, 3, 0, 9});

  cub::KeyValuePair<int, int> *out = init_data<cub::KeyValuePair<int, int>>({{}, {}, {}});
  // One {key, value} pair per segment: the key is the position of the maximum
  // inside its segment, the value is the maximum itself. The empty middle
  // segment is expected to produce {1, INT_MIN}.
  cub::KeyValuePair<int, int> expected[] = {{0, 8}, {1, INT_MIN}, {3, 9}};

  // CHECK-DPCT1026 DPCT1026:{{.*}}: The call to cub::DeviceSegmentedReduce::ArgMax was removed because this call is redundant in SYCL.
  // CHECK: dpct::segmented_reduce_argmax(oneapi::dpl::execution::device_policy(q_ct1), in, out, num_segs, offset, offset + 1);
  void *tmp_storage = nullptr;
  size_t tmp_storage_size = 0;
  // First call is made with a null tmp_storage; the size it reports in
  // tmp_storage_size is then used to allocate the scratch buffer.
  cub::DeviceSegmentedReduce::ArgMax(tmp_storage, tmp_storage_size, in, out, num_segs, offset, offset + 1);
  cudaMalloc(&tmp_storage, tmp_storage_size);
  cub::DeviceSegmentedReduce::ArgMax(tmp_storage, tmp_storage_size, in, out, num_segs, offset, offset + 1);
  // Wait for the reduction to finish before the host reads `out`.
  cudaDeviceSynchronize();

  auto cmp = [](const cub::KeyValuePair<int, int> &lhs, const cub::KeyValuePair<int, int> &rhs) -> bool {
    return lhs.value == rhs.value && lhs.key == rhs.key;
  };

  const bool ok = std::equal(out, out + num_segs, expected, cmp);
  if (!ok) {
    std::cout << "ArgMax verify failed!\n";
    std::cout << "expect: ";
    std::for_each(expected, expected + num_segs, [](const auto &v) { std::cout << v << " "; });
    std::cout << "\n";
    std::cout<< "current result: ";
    // Print the pairs actually produced, not the expectation again.
    std::for_each(out, out + num_segs, [](const auto &v) { std::cout << v << " "; });
    std::cout << "\n";
  }

  // Release every allocation made by this test; cudaFree accepts null.
  cudaFree(tmp_storage);
  cudaFree(out);
  cudaFree(in);
  cudaFree(offset);
  return ok;
}
| 332 | + |
251 | 333 | int main() {
|
252 | 334 | bool Result = true;
|
253 | 335 | Result = test_reduce_1() && Result;
|
254 | 336 | Result = test_sum_1() && Result;
|
255 | 337 | Result = test_sum_2() && Result;
|
256 | 338 | Result = test_min() && Result;
|
257 | 339 | Result = test_max() && Result;
|
| 340 | + Result = test_arg_min() && Result; |
| 341 | + Result = test_arg_max() && Result; |
258 | 342 | if(Result) {
|
259 | 343 | std::cout << "cub_device Pass" << std::endl;
|
260 | 344 | return 0;
|
|
0 commit comments