|
16 | 16 | #include "paddle/framework/eigen.h"
|
17 | 17 | #include "paddle/framework/op_registry.h"
|
18 | 18 | #include "paddle/framework/operator.h"
|
| 19 | +#include "paddle/platform/transform.h" |
| 20 | + |
| 21 | +#ifdef __NVCC__ |
| 22 | +#include <thrust/iterator/iterator_adaptor.h> |
| 23 | +#endif |
19 | 24 |
|
20 | 25 | #include "paddle/operators/math/math_function.h"
|
21 | 26 |
|
@@ -54,6 +59,153 @@ inline void get_mid_dims(const framework::DDim& x_dims,
|
54 | 59 | }
|
55 | 60 | }
|
56 | 61 |
|
| 62 | +template <typename T, typename Place> |
| 63 | +class RowwiseTransformIterator; |
| 64 | +template <typename T, typename Place> |
| 65 | +class MidWiseTransformIterator; |
| 66 | + |
| 67 | +template <typename T> |
| 68 | +class RowwiseTransformIterator<T, platform::CPUPlace> { |
| 69 | + public: |
| 70 | + RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {} |
| 71 | + |
| 72 | + RowwiseTransformIterator<T, platform::CPUPlace>& operator++() { |
| 73 | + ++i_; |
| 74 | + i_ %= n_; |
| 75 | + return *this; |
| 76 | + } |
| 77 | + |
| 78 | + bool operator==( |
| 79 | + const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const { |
| 80 | + return (ptr_ + i_) == &(*rhs); |
| 81 | + } |
| 82 | + |
| 83 | + bool operator!=( |
| 84 | + const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const { |
| 85 | + return (ptr_ + i_) != &(*rhs); |
| 86 | + } |
| 87 | + |
| 88 | + const T& operator*() { return ptr_[i_]; } |
| 89 | + |
| 90 | + private: |
| 91 | + const T* ptr_; |
| 92 | + int i_; |
| 93 | + int64_t n_; |
| 94 | +}; |
| 95 | + |
| 96 | +template <typename T> |
| 97 | +class MidWiseTransformIterator<T, platform::CPUPlace> { |
| 98 | + public: |
| 99 | + MidWiseTransformIterator(const T* ptr, int n, int post) |
| 100 | + : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} |
| 101 | + |
| 102 | + MidWiseTransformIterator<T, platform::CPUPlace>& operator++() { |
| 103 | + i_ = (++j_ / post_) % n_; |
| 104 | + return *this; |
| 105 | + } |
| 106 | + |
| 107 | + bool operator==( |
| 108 | + const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const { |
| 109 | + return (ptr_ + i_) == &(*rhs); |
| 110 | + } |
| 111 | + |
| 112 | + bool operator!=( |
| 113 | + const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const { |
| 114 | + return (ptr_ + i_) != &(*rhs); |
| 115 | + } |
| 116 | + |
| 117 | + const T& operator*() { return ptr_[i_]; } |
| 118 | + |
| 119 | + private: |
| 120 | + const T* ptr_; |
| 121 | + int i_; |
| 122 | + int64_t j_; |
| 123 | + int64_t n_; |
| 124 | + int post_; |
| 125 | +}; |
| 126 | + |
| 127 | +#ifdef __NVCC__ |
| 128 | +template <typename T> |
| 129 | +class RowwiseTransformIterator<T, platform::GPUPlace> |
| 130 | + : public thrust::iterator_adaptor< |
| 131 | + RowwiseTransformIterator<T, platform::GPUPlace>, const T*> { |
| 132 | + public: |
| 133 | + typedef thrust::iterator_adaptor< |
| 134 | + RowwiseTransformIterator<T, platform::GPUPlace>, const T*> |
| 135 | + super_t; |
| 136 | + HOSTDEVICE RowwiseTransformIterator(const T* x, int n) |
| 137 | + : super_t(x), begin_(x), n_(n){}; |
| 138 | + friend class thrust::iterator_core_access; |
| 139 | + |
| 140 | + private: |
| 141 | + unsigned int n_; |
| 142 | + const T* begin_; |
| 143 | + HOSTDEVICE typename super_t::reference dereference() const { |
| 144 | + return *(begin_ + (this->base() - begin_) % n_); |
| 145 | + } |
| 146 | +}; |
| 147 | + |
| 148 | +template <typename T> |
| 149 | +class MidWiseTransformIterator<T, platform::GPUPlace> |
| 150 | + : public thrust::iterator_adaptor< |
| 151 | + MidWiseTransformIterator<T, platform::GPUPlace>, const T*> { |
| 152 | + public: |
| 153 | + typedef thrust::iterator_adaptor< |
| 154 | + MidWiseTransformIterator<T, platform::GPUPlace>, const T*> |
| 155 | + super_t; |
| 156 | + HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post) |
| 157 | + : super_t(x), begin_(x), n_(n), post_(post){}; |
| 158 | + friend class thrust::iterator_core_access; |
| 159 | + |
| 160 | + private: |
| 161 | + unsigned int post_; |
| 162 | + unsigned int n_; |
| 163 | + const T* begin_; |
| 164 | + HOSTDEVICE typename super_t::reference dereference() const { |
| 165 | + return *(begin_ + (((this->base() - begin_) / post_) % n_)); |
| 166 | + } |
| 167 | +}; |
| 168 | +#endif |
| 169 | + |
| 170 | +template <typename Functor, typename T, typename Place> |
| 171 | +class TransformFunctor { |
| 172 | + public: |
| 173 | + TransformFunctor(const framework::Tensor* x, const framework::Tensor* y, |
| 174 | + framework::Tensor* z, const platform::DeviceContext& ctx, |
| 175 | + Functor func) |
| 176 | + : x_(x->data<T>()), |
| 177 | + y_(y->data<T>()), |
| 178 | + z_(z->mutable_data<T>(ctx.GetPlace())), |
| 179 | + nx_(x->numel()), |
| 180 | + ctx_(ctx), |
| 181 | + func_(func) {} |
| 182 | + |
| 183 | + inline void Run() const { |
| 184 | + platform::Transform<Place> trans; |
| 185 | + trans(ctx_, x_, x_ + nx_, y_, z_, func_); |
| 186 | + } |
| 187 | + |
| 188 | + inline void RunRowWise(int n, int pre) const { |
| 189 | + platform::Transform<Place> trans; |
| 190 | + trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, Place>(y_, n), z_, |
| 191 | + func_); |
| 192 | + } |
| 193 | + |
| 194 | + inline void RunMidWise(int n, int pre, int post) const { |
| 195 | + platform::Transform<Place> trans; |
| 196 | + trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator<T, Place>(y_, n, post), |
| 197 | + z_, func_); |
| 198 | + } |
| 199 | + |
| 200 | + private: |
| 201 | + const T* x_; |
| 202 | + const T* y_; |
| 203 | + T* z_; |
| 204 | + int64_t nx_; |
| 205 | + const platform::DeviceContext& ctx_; |
| 206 | + Functor func_; |
| 207 | +}; |
| 208 | + |
57 | 209 | #define EIGEN_FUNCTOR(name, eigen_op) \
|
58 | 210 | struct Eigen##name##Functor { \
|
59 | 211 | template <typename Place, typename T> \
|
|
0 commit comments