Skip to content

Commit 3644446

Browse files
author
chengduo
authored
Merge pull request #6229 from chengduoZH/profiling/updata_elementwise_op
[Profiling] Update elementwise op
2 parents 1b6804f + 37671ac commit 3644446

File tree

2 files changed

+190
-1
lines changed

2 files changed

+190
-1
lines changed

paddle/operators/elementwise_add_op.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,48 @@
1919
namespace paddle {
2020
namespace operators {
2121

22+
template <typename T>
23+
struct AddFunctor {
24+
HOSTDEVICE T operator()(T a, T b) const { return a + b; }
25+
};
26+
2227
template <typename Place, typename T>
2328
class ElementwiseAddKernel : public framework::OpKernel<T> {
2429
public:
2530
void Compute(const framework::ExecutionContext& ctx) const override {
26-
ElementwiseCompute<EigenAddFunctor, Place, T>(ctx);
31+
using Tensor = framework::Tensor;
32+
33+
auto* x = ctx.Input<Tensor>("X");
34+
auto* y = ctx.Input<Tensor>("Y");
35+
auto* z = ctx.Output<Tensor>("Out");
36+
z->mutable_data<T>(ctx.GetPlace());
37+
TransformFunctor<AddFunctor<T>, T, Place> functor(
38+
x, y, z, ctx.device_context(), AddFunctor<T>());
39+
40+
auto x_dims = x->dims();
41+
auto y_dims = y->dims();
42+
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
43+
"Rank of first input must >= rank of second input.");
44+
45+
if (x_dims == y_dims) {
46+
functor.Run();
47+
return;
48+
}
49+
50+
int axis = ctx.Attr<int>("axis");
51+
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
52+
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
53+
"Axis should be in range [0, x_dims)");
54+
55+
int pre, n, post;
56+
get_mid_dims(x_dims, y_dims, axis, pre, n, post);
57+
if (post == 1) {
58+
functor.RunRowWise(n, pre);
59+
return;
60+
} else {
61+
functor.RunMidWise(n, pre, post);
62+
return;
63+
}
2764
}
2865
};
2966

paddle/operators/elementwise_op_function.h

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
#include "paddle/framework/eigen.h"
1717
#include "paddle/framework/op_registry.h"
1818
#include "paddle/framework/operator.h"
19+
#include "paddle/platform/transform.h"
20+
21+
#ifdef __NVCC__
22+
#include <thrust/iterator/iterator_adaptor.h>
23+
#endif
1924

2025
#include "paddle/operators/math/math_function.h"
2126

@@ -54,6 +59,153 @@ inline void get_mid_dims(const framework::DDim& x_dims,
5459
}
5560
}
5661

62+
template <typename T, typename Place>
63+
class RowwiseTransformIterator;
64+
template <typename T, typename Place>
65+
class MidWiseTransformIterator;
66+
67+
template <typename T>
68+
class RowwiseTransformIterator<T, platform::CPUPlace> {
69+
public:
70+
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
71+
72+
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
73+
++i_;
74+
i_ %= n_;
75+
return *this;
76+
}
77+
78+
bool operator==(
79+
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
80+
return (ptr_ + i_) == &(*rhs);
81+
}
82+
83+
bool operator!=(
84+
const RowwiseTransformIterator<T, platform::CPUPlace>& rhs) const {
85+
return (ptr_ + i_) != &(*rhs);
86+
}
87+
88+
const T& operator*() { return ptr_[i_]; }
89+
90+
private:
91+
const T* ptr_;
92+
int i_;
93+
int64_t n_;
94+
};
95+
96+
template <typename T>
97+
class MidWiseTransformIterator<T, platform::CPUPlace> {
98+
public:
99+
MidWiseTransformIterator(const T* ptr, int n, int post)
100+
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
101+
102+
MidWiseTransformIterator<T, platform::CPUPlace>& operator++() {
103+
i_ = (++j_ / post_) % n_;
104+
return *this;
105+
}
106+
107+
bool operator==(
108+
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
109+
return (ptr_ + i_) == &(*rhs);
110+
}
111+
112+
bool operator!=(
113+
const MidWiseTransformIterator<T, platform::CPUPlace>& rhs) const {
114+
return (ptr_ + i_) != &(*rhs);
115+
}
116+
117+
const T& operator*() { return ptr_[i_]; }
118+
119+
private:
120+
const T* ptr_;
121+
int i_;
122+
int64_t j_;
123+
int64_t n_;
124+
int post_;
125+
};
126+
127+
#ifdef __NVCC__
128+
template <typename T>
129+
class RowwiseTransformIterator<T, platform::GPUPlace>
130+
: public thrust::iterator_adaptor<
131+
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
132+
public:
133+
typedef thrust::iterator_adaptor<
134+
RowwiseTransformIterator<T, platform::GPUPlace>, const T*>
135+
super_t;
136+
HOSTDEVICE RowwiseTransformIterator(const T* x, int n)
137+
: super_t(x), begin_(x), n_(n){};
138+
friend class thrust::iterator_core_access;
139+
140+
private:
141+
unsigned int n_;
142+
const T* begin_;
143+
HOSTDEVICE typename super_t::reference dereference() const {
144+
return *(begin_ + (this->base() - begin_) % n_);
145+
}
146+
};
147+
148+
template <typename T>
149+
class MidWiseTransformIterator<T, platform::GPUPlace>
150+
: public thrust::iterator_adaptor<
151+
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
152+
public:
153+
typedef thrust::iterator_adaptor<
154+
MidWiseTransformIterator<T, platform::GPUPlace>, const T*>
155+
super_t;
156+
HOSTDEVICE MidWiseTransformIterator(const T* x, int n, int post)
157+
: super_t(x), begin_(x), n_(n), post_(post){};
158+
friend class thrust::iterator_core_access;
159+
160+
private:
161+
unsigned int post_;
162+
unsigned int n_;
163+
const T* begin_;
164+
HOSTDEVICE typename super_t::reference dereference() const {
165+
return *(begin_ + (((this->base() - begin_) / post_) % n_));
166+
}
167+
};
168+
#endif
169+
170+
template <typename Functor, typename T, typename Place>
171+
class TransformFunctor {
172+
public:
173+
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
174+
framework::Tensor* z, const platform::DeviceContext& ctx,
175+
Functor func)
176+
: x_(x->data<T>()),
177+
y_(y->data<T>()),
178+
z_(z->mutable_data<T>(ctx.GetPlace())),
179+
nx_(x->numel()),
180+
ctx_(ctx),
181+
func_(func) {}
182+
183+
inline void Run() const {
184+
platform::Transform<Place> trans;
185+
trans(ctx_, x_, x_ + nx_, y_, z_, func_);
186+
}
187+
188+
inline void RunRowWise(int n, int pre) const {
189+
platform::Transform<Place> trans;
190+
trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, Place>(y_, n), z_,
191+
func_);
192+
}
193+
194+
inline void RunMidWise(int n, int pre, int post) const {
195+
platform::Transform<Place> trans;
196+
trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator<T, Place>(y_, n, post),
197+
z_, func_);
198+
}
199+
200+
private:
201+
const T* x_;
202+
const T* y_;
203+
T* z_;
204+
int64_t nx_;
205+
const platform::DeviceContext& ctx_;
206+
Functor func_;
207+
};
208+
57209
#define EIGEN_FUNCTOR(name, eigen_op) \
58210
struct Eigen##name##Functor { \
59211
template <typename Place, typename T> \

0 commit comments

Comments
 (0)