Skip to content

Commit 47609ab

Browse files
authored
Document transform.h and fix cpplint errors (#9913)
1 parent b43d87c commit 47609ab

File tree

2 files changed

+50
-18
lines changed

2 files changed

+50
-18
lines changed

paddle/fluid/platform/details/device_ptr_cast.h renamed to paddle/fluid/platform/details/cuda_transform_iterator_cast.h

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,22 @@ limitations under the License. */
1818
#error device_ptr_cast must be include by .cu file
1919
#endif
2020

21-
#include <thrust/device_ptr.h>
21+
#include <type_traits> // For std::remove_pointer and std::is_pointer.
22+
23+
#include "thrust/device_ptr.h"
2224

2325
namespace paddle {
2426
namespace platform {
2527
namespace details {
28+
29+
// PointerToThrustDevicePtr has two specializations, one casts a (CUDA
30+
// device) pointer into thrust::device_ptr, the other keeps rest types
31+
// un-casted.
2632
template <typename T, bool is_ptr>
27-
struct DevicePtrCast;
33+
struct PointerToThrustDevicePtr;
2834

2935
template <typename T>
30-
struct DevicePtrCast<T, true> {
36+
struct PointerToThrustDevicePtr<T, true> {
3137
using ELEM = typename std::remove_pointer<T>::type;
3238
using RTYPE = thrust::device_ptr<ELEM>;
3339

@@ -37,17 +43,26 @@ struct DevicePtrCast<T, true> {
3743
};
3844

3945
template <typename T>
40-
struct DevicePtrCast<T, false> {
46+
struct PointerToThrustDevicePtr<T, false> {
4147
using RTYPE = T;
4248
inline RTYPE operator()(RTYPE it) const { return it; }
4349
};
4450

45-
// Cast T to thrust::device_ptr if T is a pointer.
46-
// Otherwise, e.g., T is a iterator, return T itself.
51+
// CastToCUDATransformIterator casts a pointer to thrust::device_ptr
52+
// so it could be used as the iterator of thrust::transform. It
53+
// doesn't cast other types.
54+
//
55+
// We need CastToCUDATransformIterator because, quite often, we
56+
// want to use device memory pointers as transform iterators, e.g., to
57+
// transform a block of float32 to float16. In this case, we want
58+
// CastToCUDATransformIterator to cast float16/32 pointers to
59+
// thrust::device_ptr, otherwise they cannot work as the iterator
60+
// required by thrust::transform. At the same time, we don't want to
61+
// cast thrust::device_ptr to thrust::device_ptr repeatedly.
4762
template <typename T>
48-
auto DevPtrCast(T t) ->
49-
typename DevicePtrCast<T, std::is_pointer<T>::value>::RTYPE {
50-
DevicePtrCast<T, std::is_pointer<T>::value> cast;
63+
auto CastToCUDATransformIterator(T t) ->
64+
typename PointerToThrustDevicePtr<T, std::is_pointer<T>::value>::RTYPE {
65+
PointerToThrustDevicePtr<T, std::is_pointer<T>::value> cast;
5166
return cast(t);
5267
}
5368

paddle/fluid/platform/transform.h

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,44 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <algorithm>
18+
#include <type_traits>
19+
1720
#include "paddle/fluid/platform/device_context.h"
1821
#include "paddle/fluid/platform/enforce.h"
1922
#include "paddle/fluid/platform/hostdevice.h"
2023
#include "paddle/fluid/platform/place.h"
2124

22-
#include <algorithm>
23-
#include <type_traits>
2425
#ifdef __NVCC__
2526
#include <thrust/execution_policy.h>
2627
#include <thrust/transform.h>
27-
#include "paddle/fluid/platform/details/device_ptr_cast.h"
28+
#include "paddle/fluid/platform/details/cuda_transform_iterator_cast.h"
2829
#endif
2930

3031
namespace paddle {
3132
namespace platform {
3233

33-
// Transform on host or device. It provides the same API in std library.
34+
// Transform applies a unary or a binary functor on each element in a
35+
// range defined by a pair of iterators.
36+
//
37+
// - The specialization for CPU calls std::transform.
38+
// - The specialization for CUDA calls thrust::transform.
39+
//
40+
// NOTE: We need InputIter and OutputIter to be defined as
41+
// different types, because the InputIter points to op's inputs and
42+
// OutputIter points to op's outputs.
43+
//
44+
// NOTE: We don't assume InputIter to be const InputType* and
45+
// OutputIter to be OutputType*, because we might use an iterator
46+
// class, paddle::fluid::operators::RowwiseTransformIterator.
3447
template <typename DeviceContext>
3548
struct Transform {
49+
// The unary version.
3650
template <typename InputIter, typename OutputIter, typename UnaryOperation>
3751
void operator()(const DeviceContext& context, InputIter first, InputIter last,
3852
OutputIter result, UnaryOperation op);
3953

54+
// The binary version.
4055
template <typename InputIter1, typename InputIter2, typename OutputIter,
4156
typename BinaryOperation>
4257
void operator()(const DeviceContext& context, InputIter1 first1,
@@ -70,8 +85,9 @@ struct Transform<platform::CUDADeviceContext> {
7085
auto place = context.GetPlace();
7186
PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
7287
thrust::transform(thrust::cuda::par.on(context.stream()),
73-
details::DevPtrCast(first), details::DevPtrCast(last),
74-
details::DevPtrCast(result), op);
88+
details::CastToCUDATransformIterator(first),
89+
details::CastToCUDATransformIterator(last),
90+
details::CastToCUDATransformIterator(result), op);
7591
}
7692

7793
template <typename InputIter1, typename InputIter2, typename OutputIter,
@@ -82,9 +98,10 @@ struct Transform<platform::CUDADeviceContext> {
8298
auto place = context.GetPlace();
8399
PADDLE_ENFORCE(is_gpu_place(place), "It must use GPU place.");
84100
thrust::transform(thrust::cuda::par.on(context.stream()),
85-
details::DevPtrCast(first1), details::DevPtrCast(last1),
86-
details::DevPtrCast(first2), details::DevPtrCast(result),
87-
op);
101+
details::CastToCUDATransformIterator(first1),
102+
details::CastToCUDATransformIterator(last1),
103+
details::CastToCUDATransformIterator(first2),
104+
details::CastToCUDATransformIterator(result), op);
88105
}
89106
};
90107
#endif

0 commit comments

Comments
 (0)