Skip to content

Commit 0ebaa16

Browse files
authored
[SYCLomatic] Fix compile error of member function operator() is not marked const for migrated code of thrust::make_transform_iterator (#2236)
Signed-off-by: chenwei.sun <[email protected]>
1 parent e72f53c commit 0ebaa16

File tree

2 files changed

+110
-4
lines changed

2 files changed

+110
-4
lines changed

clang/lib/DPCT/APINamesThrust.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ CALL_FACTORY_ENTRY("thrust::make_permutation_iterator",
462462
// thrust::make_transform_iterator
463463
CALL_FACTORY_ENTRY("thrust::make_transform_iterator",
464464
CALL("oneapi::dpl::make_transform_iterator", ARG(0),
465-
ARG(1)))
465+
THRUST_FUNCTOR(1)))
466466

467467
// thrust::norm
468468
CALL_FACTORY_ENTRY("thrust::norm", CALL("std::norm", ARG(0)))

clang/test/dpct/thrust-cast.cu

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@
99
// CHECK-NEXT: #include <dpct/dpct.hpp>
1010
// CHECK-NEXT: #include <complex>
1111
// CHECK-NEXT: #include <dpct/dpl_utils.hpp>
12+
#include <cuda_runtime.h>
1213
#include <thrust/complex.h>
14+
#include <thrust/device_free.h>
15+
#include <thrust/device_malloc.h>
1316
#include <thrust/device_ptr.h>
1417
#include <thrust/device_vector.h>
15-
#include <thrust/device_malloc.h>
16-
#include <thrust/device_free.h>
17-
#include <cuda_runtime.h>
18+
#include <thrust/functional.h>
19+
#include <thrust/iterator/counting_iterator.h>
20+
#include <thrust/iterator/transform_iterator.h>
21+
#include <thrust/scatter.h>
1822

1923
// CHECK: void kernel(std::complex<double> *det) {}
2024
__global__ void kernel(thrust::complex<double> *det) {}
@@ -55,3 +59,105 @@ void foo(thrust::device_vector<float> &vec, const int i, const int j) {
5559
//CHECK-NEXT: });
5660
kernel2<<<1, 1>>>(thrust::raw_pointer_cast(vec.data()) + i * j);
5761
}
62+
63+
//CHECK: /*
64+
//CHECK-NEXT: DPCT1044:{{[0-9]+}}: thrust::unary_function was removed because std::unary_function has been deprecated in C++11. You may need to remove references to typedefs from thrust::unary_function in the class definition.
65+
//CHECK-NEXT: */
66+
//CHECK-NEXT: struct transpose_102_index {
67+
//CHECK-NEXT: const size_t m, n, p;
68+
//CHECK-NEXT: transpose_102_index(size_t _m, size_t _n, size_t _p) : m(_m), n(_n), p(_p) {}
69+
//CHECK-NEXT: size_t operator()(size_t linear_index) const {
70+
//CHECK-NEXT: size_t i = linear_index / (n * p);
71+
//CHECK-NEXT: size_t rmdr = linear_index % (n * p);
72+
//CHECK-NEXT: size_t j = rmdr / p;
73+
//CHECK-NEXT: size_t k = rmdr % p;
74+
//CHECK-NEXT: return m * p * j + p * i + k;
75+
//CHECK-NEXT: }
76+
//CHECK-NEXT: };
77+
struct transpose_102_index : public thrust::unary_function<size_t, size_t> {
78+
const size_t m, n, p;
79+
__host__ __device__ transpose_102_index(size_t _m, size_t _n, size_t _p) : m(_m), n(_n), p(_p) {}
80+
__host__ __device__ size_t operator()(size_t linear_index) {
81+
size_t i = linear_index / (n * p);
82+
size_t rmdr = linear_index % (n * p);
83+
size_t j = rmdr / p;
84+
size_t k = rmdr % p;
85+
return m * p * j + p * i + k;
86+
}
87+
};
88+
89+
//CHECK: /*
90+
//CHECK-NEXT: DPCT1044:{{[0-9]+}}: thrust::unary_function was removed because std::unary_function has been deprecated in C++11. You may need to remove references to typedefs from thrust::unary_function in the class definition.
91+
//CHECK-NEXT: */
92+
//CHECK-NEXT:struct transpose_201_index {
93+
//CHECK-NEXT: const size_t m, n, p;
94+
//CHECK-NEXT: transpose_201_index(size_t _m, size_t _n, size_t _p) : m(_m), n(_n), p(_p) {}
95+
//CHECK-NEXT: size_t operator()(size_t linear_index) const {
96+
//CHECK-NEXT: size_t i = linear_index / (n * p);
97+
//CHECK-NEXT: size_t rmdr = linear_index % (n * p);
98+
//CHECK-NEXT: size_t j = rmdr / p;
99+
//CHECK-NEXT: size_t k = rmdr % p;
100+
//CHECK-NEXT: return m * n * k + n * i + j;
101+
//CHECK-NEXT: }
102+
//CHECK-NEXT:};
103+
struct transpose_201_index : public thrust::unary_function<size_t, size_t> {
104+
const size_t m, n, p;
105+
__host__ __device__ transpose_201_index(size_t _m, size_t _n, size_t _p) : m(_m), n(_n), p(_p) {}
106+
__host__ __device__ size_t operator()(size_t linear_index) {
107+
size_t i = linear_index / (n * p);
108+
size_t rmdr = linear_index % (n * p);
109+
size_t j = rmdr / p;
110+
size_t k = rmdr % p;
111+
return m * n * k + n * i + j;
112+
}
113+
};
114+
115+
//CHECK: /*
116+
//CHECK-NEXT: DPCT1044:{{[0-9]+}}: thrust::unary_function was removed because std::unary_function has been deprecated in C++11. You may need to remove references to typedefs from thrust::unary_function in the class definition.
117+
//CHECK-NEXT: */
118+
//CHECK-NEXT:struct transpose_10_index {
119+
//CHECK-NEXT: const size_t m, n;
120+
//CHECK-NEXT: transpose_10_index(size_t _m, size_t _n) : m(_m), n(_n) {}
121+
//CHECK-NEXT: size_t operator()(size_t linear_index) const {
122+
//CHECK-NEXT: size_t i = linear_index / n;
123+
//CHECK-NEXT: size_t rmdr = linear_index % n;
124+
//CHECK-NEXT: size_t j = rmdr;
125+
//CHECK-NEXT: return j * m + i;
126+
//CHECK-NEXT: }
127+
//CHECK-NEXT:};
128+
struct transpose_10_index : public thrust::unary_function<size_t, size_t> {
129+
const size_t m, n;
130+
__host__ __device__ transpose_10_index(size_t _m, size_t _n) : m(_m), n(_n) {}
131+
__host__ __device__ size_t operator()(size_t linear_index) {
132+
size_t i = linear_index / n;
133+
size_t rmdr = linear_index % n;
134+
size_t j = rmdr;
135+
return j * m + i;
136+
}
137+
};
138+
139+
void transpose_102(size_t m, size_t n, size_t p, thrust::device_vector<float> &src,
140+
thrust::device_vector<float> &dst) {
141+
thrust::counting_iterator<size_t> indices(0);
142+
thrust::scatter(src.begin(), src.end(),
143+
thrust::make_transform_iterator(indices, transpose_102_index(m, n, p)),
144+
dst.begin());
145+
}
146+
147+
void transpose_201(size_t m, size_t n, size_t p, thrust::device_vector<float> &src,
148+
thrust::device_vector<float> &dst) {
149+
thrust::counting_iterator<size_t> indices(0);
150+
thrust::scatter(src.begin(), src.end(),
151+
thrust::make_transform_iterator(indices, transpose_201_index(m, n, p)),
152+
dst.begin());
153+
}
154+
155+
template <typename T>
156+
void transpose_2d(size_t m, size_t n, thrust::device_vector<T> &src, thrust::device_vector<T> &dst) {
157+
static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value || std::is_same<T, __half>::value,
158+
"T must be float type or integer type");
159+
thrust::counting_iterator<size_t> indices(0);
160+
thrust::scatter(src.begin(), src.end(),
161+
thrust::make_transform_iterator(indices, transpose_10_index(m, n)),
162+
dst.begin());
163+
}

0 commit comments

Comments
 (0)