Skip to content

Commit dbba3c0

Browse files
authored
[SYCLomatic] Migrate thrust::zip_iterator to dpct::zip_iterator (#563)
Introduce help type dpct::zip_iterator which only accept std::tuple as template argument to migrate thrust::zip_iterator. Signed-off-by: Jiang, Zhiwei <[email protected]>
1 parent d768db6 commit dbba3c0

File tree

6 files changed

+67
-5
lines changed

6 files changed

+67
-5
lines changed

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2172,7 +2172,7 @@ void TypeInDeclRule::registerMatcher(MatchFinder &MF) {
21722172
"cudaDataType_t", "cudaDataType", "cublasComputeType_t",
21732173
"cublasAtomicsMode_t", "CUmem_advise_enum", "CUmem_advise",
21742174
"thrust::tuple_element", "thrust::tuple_size", "cublasMath_t",
2175-
"cudaPointerAttributes")
2175+
"cudaPointerAttributes", "thrust::zip_iterator")
21762176
)))))
21772177
.bind("cudaTypeDef"),
21782178
this);

clang/lib/DPCT/MapNames.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ void MapNames::setExplicitNamespaceMap() {
282282
std::make_shared<TypeNameRule>("std::tuple_size")},
283283
{"thrust::swap",
284284
std::make_shared<TypeNameRule>("std::swap")},
285+
{"thrust::zip_iterator",
286+
std::make_shared<TypeNameRule>(
287+
getDpctNamespace() + "zip_iterator",
288+
HelperFeatureEnum::DplExtrasIterators_zip_iterator)},
285289
{"cusolverDnHandle_t",
286290
std::make_shared<TypeNameRule>(getClNamespace() + "queue*")},
287291
{"cusolverEigType_t", std::make_shared<TypeNameRule>("int64_t")},

clang/runtime/dpct-rt/include/dpl_extras/iterators.h.inc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,29 @@ template <typename KeyTp, typename _ValueTp> struct make_key_value_pair {
212212
};
213213
// DPCT_LABEL_END
214214

215+
// DPCT_LABEL_BEGIN|zip_iterator_impl|dpct::detail
216+
// DPCT_DEPENDENCY_EMPTY
217+
// DPCT_CODE
218+
template <class T> struct __zip_iterator_impl;
219+
template <class... Ts> struct __zip_iterator_impl<std::tuple<Ts...>> {
220+
using type = oneapi::dpl::zip_iterator<Ts...>;
221+
};
222+
// DPCT_LABEL_END
223+
215224
} // end namespace detail
216225

226+
// DPCT_LABEL_BEGIN|zip_iterator|dpct
227+
// DPCT_DEPENDENCY_BEGIN
228+
// DplExtrasIterators|zip_iterator_impl
229+
// DPCT_DEPENDENCY_END
230+
// DPCT_CODE
231+
// dpct::zip_iterator can only accept std::tuple type as template argument for
232+
// compatibility purpose. Please use oneapi::dpl::zip_iterator if you want to
233+
// pass iterator's types directly.
234+
template <typename... Ts>
235+
using zip_iterator = typename detail::__zip_iterator_impl<Ts...>::type;
236+
// DPCT_LABEL_END
237+
217238
// DPCT_LABEL_BEGIN|arg_index_input_iterator|dpct
218239
// DPCT_DEPENDENCY_BEGIN
219240
// DplExtrasIterators|make_key_value_pair

clang/test/dpct/helper_files_ref/include/dpl_extras/iterators.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,19 @@ template <typename KeyTp, typename _ValueTp> struct make_key_value_pair {
146146
}
147147
};
148148

149+
template <class T> struct __zip_iterator_impl;
150+
template <class... Ts> struct __zip_iterator_impl<std::tuple<Ts...>> {
151+
using type = oneapi::dpl::zip_iterator<Ts...>;
152+
};
153+
149154
} // end namespace detail
150155

156+
// dpct::zip_iterator can only accept std::tuple type as template argument for
157+
// compatibility purpose. Please use oneapi::dpl::zip_iterator if you want to
158+
// pass iterator's types directly.
159+
template <typename... Ts>
160+
using zip_iterator = typename detail::__zip_iterator_impl<Ts...>::type;
161+
151162
// arg_index_input_iterator is an iterator over a input iterator, with a index.
152163
// When dereferenced, it returns a key_value_pair, which can be interrogated for
153164
// the index key or the value from the input iterator
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: dpct --format-range=none --use-custom-helper=api -out-root %T/DplExtrasIterators/api_test6_out %s --cuda-include-path="%cuda-path/include" -- -x cuda --cuda-host-only -std=c++17 -fno-delayed-template-parsing
2+
// RUN: grep "IsCalled" %T/DplExtrasIterators/api_test6_out/MainSourceFiles.yaml | wc -l > %T/DplExtrasIterators/api_test6_out/count.txt
3+
// RUN: FileCheck --input-file %T/DplExtrasIterators/api_test6_out/count.txt --match-full-lines %s
4+
// RUN: rm -rf %T/DplExtrasIterators/api_test6_out
5+
6+
// CHECK: 4
7+
// TEST_FEATURE: DplExtrasIterators_zip_iterator
8+
9+
#include <thrust/iterator/zip_iterator.h>
10+
#include <thrust/tuple.h>
11+
12+
template<typename int_iterator>
13+
void foo() {
14+
typedef thrust::tuple<int_iterator, int_iterator> iterator_tuple;
15+
typedef thrust::zip_iterator<iterator_tuple> int_zip_iterator;
16+
}
17+
18+
int main() {
19+
return 0;
20+
}

clang/test/dpct/thrust-for-h2o4gpu.cu

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,14 +225,20 @@ void foo() {
225225
{
226226
//CHECK: dpct::device_vector<int> int_in(3);
227227
//CHECK-NEXT: dpct::device_vector<float> float_in(3);
228-
//CHECK-NEXT: auto ret = oneapi::dpl::make_zip_iterator(std::make_tuple(int_in.begin(), float_in.begin()));
228+
//CHECK-NEXT: typedef dpct::device_vector<int>::iterator int_iterator;
229+
//CHECK-NEXT: typedef dpct::device_vector<float>::iterator float_iterator;
230+
//CHECK-NEXT: typedef std::tuple<int_iterator, float_iterator> iterator_tuple;
231+
//CHECK-NEXT: dpct::zip_iterator<iterator_tuple> ret = oneapi::dpl::make_zip_iterator(std::make_tuple(int_in.begin(), float_in.begin()));
229232
//CHECK-NEXT: auto arg = std::make_tuple(int_in.begin(), float_in.begin());
230-
//CHECK-NEXT: auto ret_1 = oneapi::dpl::make_zip_iterator(arg);
233+
//CHECK-NEXT: dpct::zip_iterator<iterator_tuple> ret_1 = oneapi::dpl::make_zip_iterator(arg);
231234
thrust::device_vector<int> int_in(3);
232235
thrust::device_vector<float> float_in(3);
233-
auto ret = thrust::make_zip_iterator(thrust::make_tuple(int_in.begin(), float_in.begin()));
236+
typedef thrust::device_vector<int>::iterator int_iterator;
237+
typedef thrust::device_vector<float>::iterator float_iterator;
238+
typedef thrust::tuple<int_iterator, float_iterator> iterator_tuple;
239+
thrust::zip_iterator<iterator_tuple> ret = thrust::make_zip_iterator(thrust::make_tuple(int_in.begin(), float_in.begin()));
234240
auto arg = thrust::make_tuple(int_in.begin(), float_in.begin());
235-
auto ret_1 = thrust::make_zip_iterator(arg);
241+
thrust::zip_iterator<iterator_tuple> ret_1 = thrust::make_zip_iterator(arg);
236242
}
237243

238244
{

0 commit comments

Comments
 (0)