Skip to content

Commit b8afff4

Browse files
authored
[SYCL] Add missing double support in cross() (#3017)
This patch adds missing double3 and double4 support in geometric function cross(). Currently, it only supports float types. Signed-off-by: Nawrin Sultana <[email protected]>
1 parent d8c8e08 commit b8afff4

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

sycl/include/CL/sycl/builtins.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -950,8 +950,8 @@ mul24(T x, T y) __NOEXC {
950950
// half3 cross (half3 p0, half3 p1)
951951
// half4 cross (half4 p0, half4 p1)
952952
template <typename T>
953-
detail::enable_if_t<detail::is_gencrossfloat<T>::value, T> cross(T p0,
954-
T p1) __NOEXC {
953+
detail::enable_if_t<detail::is_gencross<T>::value, T> cross(T p0,
954+
T p1) __NOEXC {
955955
return __sycl_std::__invoke_cross<T>(p0, p1);
956956
}
957957

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
2+
// RUN: %RUN_ON_HOST %t.out
3+
4+
#include <CL/sycl.hpp>
5+
#include <cassert>
6+
#include <cmath>
7+
8+
bool isEqualTo(double x, double y, double epsilon = 0.001) {
9+
return std::fabs(x - y) <= epsilon;
10+
}
11+
12+
int main(int argc, const char **argv) {
13+
cl::sycl::cl_double4 r{0};
14+
{
15+
cl::sycl::buffer<cl::sycl::cl_double4, 1> BufR(&r, cl::sycl::range<1>(1));
16+
cl::sycl::queue myQueue;
17+
myQueue.submit([&](cl::sycl::handler &cgh) {
18+
auto AccR = BufR.get_access<cl::sycl::access::mode::write>(cgh);
19+
cgh.single_task<class crossD4>([=]() {
20+
AccR[0] = cl::sycl::cross(
21+
cl::sycl::cl_double4{
22+
2.5,
23+
3.0,
24+
4.0,
25+
0.0,
26+
},
27+
cl::sycl::cl_double4{
28+
5.2,
29+
6.0,
30+
7.0,
31+
0.0,
32+
});
33+
});
34+
});
35+
}
36+
cl::sycl::cl_double r1 = r.x();
37+
cl::sycl::cl_double r2 = r.y();
38+
cl::sycl::cl_double r3 = r.z();
39+
cl::sycl::cl_double r4 = r.w();
40+
41+
assert(isEqualTo(r1, -3.0));
42+
assert(isEqualTo(r2, 3.3));
43+
assert(isEqualTo(r3, -0.6));
44+
assert(isEqualTo(r4, 0.0));
45+
46+
return 0;
47+
}

0 commit comments

Comments
 (0)