Skip to content

Commit 6115c15

Browse files
authored
[SYCLomatic #1641] Add tests for dpct::unique_count (#602)
Signed-off-by: Matthew Michel <[email protected]>
1 parent 5919ec7 commit 6115c15

File tree

2 files changed

+160
-0
lines changed

2 files changed

+160
-0
lines changed

help_function/help_function.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@
148148
<test testName="onedpl_test_uninitialized_fill" configFile="config/TEMPLATE_help_function_skip_cuda_backend.xml" />
149149
<test testName="onedpl_test_unique_by_key_copy" configFile="config/TEMPLATE_help_function_skip_cuda_backend.xml" />
150150
<test testName="onedpl_test_unique_by_key" configFile="config/TEMPLATE_help_function_skip_cuda_backend.xml" />
151+
<test testName="onedpl_test_unique_count" configFile="config/TEMPLATE_help_function_skip_cuda_backend.xml" />
151152
<test testName="onedpl_test_unique" configFile="config/TEMPLATE_help_function_skip_cuda_backend.xml" />
152153
<test testName="onedpl_test_vector" configFile="config/TEMPLATE_help_function_skip_cuda_backend.xml" />
153154
<test testName="onedpl_test_group_exchange" configFile="config/TEMPLATE_help_function.xml" />
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// ====------ onedpl_test_unique_count.cpp---------- -*- C++ -* ----===////
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//
8+
// ===----------------------------------------------------------------------===//
9+
10+
#include <oneapi/dpl/execution>
11+
#include <oneapi/dpl/algorithm>
12+
13+
#include <dpct/dpct.hpp>
14+
#include <dpct/dpl_utils.hpp>
15+
16+
#include <sycl/sycl.hpp>
17+
18+
#include <iostream>
19+
#include <type_traits>
20+
21+
template <typename String, typename _T1, typename _T2>
22+
int
23+
ASSERT_EQUAL(String msg, _T1&& X, _T2&& Y)
24+
{
25+
if (X != Y)
26+
{
27+
std::cout << "FAIL: " << msg << " - (" << X << "," << Y << ")" << std::endl;
28+
return 1;
29+
}
30+
return 0;
31+
}
32+
33+
int
34+
test_passed(int failing_elems, std::string test_name)
35+
{
36+
if (failing_elems == 0)
37+
{
38+
std::cout << "PASS: " << test_name << std::endl;
39+
return 0;
40+
}
41+
return 1;
42+
}
43+
44+
template <typename Buffer>
45+
void
46+
iota_buffer(Buffer& dst_buf, int start_index, int end_index, int offset)
47+
{
48+
auto dst = dst_buf.get_host_access();
49+
for (int i = start_index; i != end_index; ++i)
50+
{
51+
dst[i] = i + offset;
52+
}
53+
}
54+
55+
struct uint32_wrapper
56+
{
57+
::std::uint32_t val;
58+
};
59+
60+
int
61+
main()
62+
{
63+
// used to detect failures
64+
int failed_tests = 0;
65+
int num_failing = 0;
66+
std::string test_name = "";
67+
sycl::queue q(dpct::get_default_queue());
68+
auto policy = oneapi::dpl::execution::make_device_policy(q);
69+
70+
// 1. Test call with n runs of the form: [1, 2, 2, 3, 3, 3, ..., n]
71+
{
72+
test_name = "Testing 7 runs of form [1, 2, 2, 3, 3, 3, ..., 7]";
73+
std::size_t n = 7;
74+
sycl::buffer<std::uint32_t> src{sycl::range<1>((n * (n + 1)) / 2)};
75+
{
76+
auto acc = src.get_host_access();
77+
for (std::size_t i = 1; i <= n; ++i)
78+
for (std::size_t j = 0; j < i; ++j)
79+
acc[j + (i * (i - 1) / 2)] = i;
80+
}
81+
auto res = dpct::unique_count(policy, dpl::begin(src), dpl::end(src));
82+
auto local_failures = ASSERT_EQUAL(test_name, res, n);
83+
test_passed(local_failures, test_name);
84+
num_failing += local_failures;
85+
}
86+
// 2. Test case where each element is distinct run
87+
{
88+
test_name = "Testing 20 runs of form [1, 2, 3, ..., 19, 20]";
89+
std::size_t n = 20;
90+
sycl::buffer<std::uint32_t> src(n);
91+
iota_buffer(src, 0, n, 1);
92+
auto res = dpct::unique_count(policy, dpl::begin(src), dpl::end(src));
93+
auto local_failures = ASSERT_EQUAL(test_name, res, n);
94+
test_passed(local_failures, test_name);
95+
num_failing += local_failures;
96+
}
97+
// 3. Test 1 element case
98+
{
99+
test_name = "Testing 1 runs of form [30]";
100+
std::size_t n = 1;
101+
sycl::buffer<std::uint32_t> src(n);
102+
{
103+
auto acc = src.get_host_access();
104+
acc[0] = 30;
105+
}
106+
auto res = dpct::unique_count(policy, dpl::begin(src), dpl::end(src));
107+
auto local_failures = ASSERT_EQUAL(test_name, res, n);
108+
test_passed(local_failures, test_name);
109+
num_failing += local_failures;
110+
}
111+
// 4. Test 0 element case
112+
{
113+
test_name = "Testing 0 runs of form []";
114+
std::size_t n = 0;
115+
sycl::buffer<std::uint32_t> src(n);
116+
auto res = dpct::unique_count(policy, dpl::begin(src), dpl::end(src));
117+
auto local_failures = ASSERT_EQUAL(test_name, res, 0);
118+
test_passed(local_failures, test_name);
119+
num_failing += local_failures;
120+
}
121+
// 5. Test custom predicate
122+
{
123+
auto is_in_group_of_three = [](auto fst, auto snd) {
124+
using T = ::std::decay_t<decltype(fst)>;
125+
return static_cast<T>(fst / 3) == static_cast<T>(snd / 3);
126+
};
127+
test_name = "Testing custom predicate grouping runs of length 3";
128+
std::size_t n = 21;
129+
sycl::buffer<std::uint32_t> src(n);
130+
iota_buffer(src, 0, n, 0);
131+
auto res = dpct::unique_count(policy, dpl::begin(src), dpl::end(src), is_in_group_of_three);
132+
auto local_failures = ASSERT_EQUAL(test_name, res, n / 3);
133+
test_passed(local_failures, test_name);
134+
num_failing += local_failures;
135+
}
136+
// 6. Test custom predicate with custom type
137+
{
138+
auto is_equal_uint32_wrapper = [](uint32_wrapper fst, uint32_wrapper snd) { return fst.val == snd.val; };
139+
test_name = "Testing 16 runs of custom predicate and custom datatype";
140+
std::size_t n = 32;
141+
sycl::buffer<uint32_wrapper> src(n);
142+
{
143+
auto acc = src.get_host_access();
144+
for (int i = 0; i < n; ++i)
145+
acc[i] = uint32_wrapper{static_cast<uint32_t>(i / 2)};
146+
}
147+
auto res = dpct::unique_count(policy, dpl::begin(src), dpl::end(src), is_equal_uint32_wrapper);
148+
auto local_failures = ASSERT_EQUAL(test_name, res, n / 2);
149+
test_passed(local_failures, test_name);
150+
num_failing += local_failures;
151+
}
152+
153+
std::cout << std::endl << failed_tests << " failing test(s) detected." << std::endl;
154+
if (failed_tests == 0)
155+
{
156+
return 0;
157+
}
158+
return 1;
159+
}

0 commit comments

Comments
 (0)