@@ -14,29 +14,44 @@ limitations under the License. */
14
14
15
15
#pragma once
16
16
17
+ #include < algorithm>
18
+ #include < type_traits>
19
+
17
20
#include " paddle/fluid/platform/device_context.h"
18
21
#include " paddle/fluid/platform/enforce.h"
19
22
#include " paddle/fluid/platform/hostdevice.h"
20
23
#include " paddle/fluid/platform/place.h"
21
24
22
- #include < algorithm>
23
- #include < type_traits>
24
25
#ifdef __NVCC__
25
26
#include < thrust/execution_policy.h>
26
27
#include < thrust/transform.h>
27
- #include " paddle/fluid/platform/details/device_ptr_cast .h"
28
+ #include " paddle/fluid/platform/details/cuda_transform_iterator_cast .h"
28
29
#endif
29
30
30
31
namespace paddle {
31
32
namespace platform {
32
33
33
- // Transform on host or device. It provides the same API in std library.
34
+ // Transform applys a unary or a binary functor on each element in a
35
+ // range defined by a pair of iterators.
36
+ //
37
+ // - The specialization for CPU calls std::transform.
38
+ // - The specialization for CUDA calls thrust::tranform.
39
+ //
40
+ // NOTE: We need to define InputIter and OutputIter defined as
41
+ // different types, because the InputIter points op's inputs and
42
+ // OutputIter pints to op's outputs.
43
+ //
44
+ // NOTE: We don't assume that InputIter to be const InputType* and
45
+ // OutputIter to be OutputType*, because we might use a iterator
46
+ // class, paddle::fluid::operators::RowwiseTRansformIterator.
34
47
template <typename DeviceContext>
35
48
struct Transform {
49
+ // The unary version.
36
50
template <typename InputIter, typename OutputIter, typename UnaryOperation>
37
51
void operator ()(const DeviceContext& context, InputIter first, InputIter last,
38
52
OutputIter result, UnaryOperation op);
39
53
54
+ // The binary version.
40
55
template <typename InputIter1, typename InputIter2, typename OutputIter,
41
56
typename BinaryOperation>
42
57
void operator ()(const DeviceContext& context, InputIter1 first1,
@@ -70,8 +85,9 @@ struct Transform<platform::CUDADeviceContext> {
70
85
auto place = context.GetPlace ();
71
86
PADDLE_ENFORCE (is_gpu_place (place), " It must use GPU place." );
72
87
thrust::transform (thrust::cuda::par.on (context.stream ()),
73
- details::DevPtrCast (first), details::DevPtrCast (last),
74
- details::DevPtrCast (result), op);
88
+ details::CastToCUDATransformIterator (first),
89
+ details::CastToCUDATransformIterator (last),
90
+ details::CastToCUDATransformIterator (result), op);
75
91
}
76
92
77
93
template <typename InputIter1, typename InputIter2, typename OutputIter,
@@ -82,9 +98,10 @@ struct Transform<platform::CUDADeviceContext> {
82
98
auto place = context.GetPlace ();
83
99
PADDLE_ENFORCE (is_gpu_place (place), " It must use GPU place." );
84
100
thrust::transform (thrust::cuda::par.on (context.stream ()),
85
- details::DevPtrCast (first1), details::DevPtrCast (last1),
86
- details::DevPtrCast (first2), details::DevPtrCast (result),
87
- op);
101
+ details::CastToCUDATransformIterator (first1),
102
+ details::CastToCUDATransformIterator (last1),
103
+ details::CastToCUDATransformIterator (first2),
104
+ details::CastToCUDATransformIterator (result), op);
88
105
}
89
106
};
90
107
#endif
0 commit comments