@@ -72,6 +72,25 @@ class StackOpMaker : public framework::OpProtoAndCheckerMaker {
7272 }
7373};
7474
// Per-element functor for the forward stack copy.
//
// Template parameters:
//   VecXType - container indexable as x[k][offset]; stored by value in the
//              functor (copied from the constructor argument).
//   T        - element type of the destination buffer.
//
// For a non-negative flat index `idx`, the index is decomposed as
//   idx == (outer * n + source) * post + inner,   0 <= inner < post,
// and the functor performs
//   y[idx] = x[source][outer * post + inner].
template <typename VecXType, typename T>
struct StackFunctor {
  // x: indexable collection of source buffers; y: destination buffer;
  // n: divisor used to select the source (presumably the number of stacked
  // inputs -- confirm against the caller); post: length of the innermost run.
  HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
      : x_(x), y_(y), n_(n), post_(post) {}

  // Writes exactly one destination element, y[idx].
  HOSTDEVICE void operator()(int idx) {
    const int outer = idx / (n_ * post_);      // leading (outer) coordinate
    const int inner = idx % post_;             // position inside the run
    const int source = idx / post_ - outer * n_;  // which x[...] to read
    const int src_offset = outer * post_ + inner;
    y_[idx] = x_[source][src_offset];
  }

 private:
  VecXType x_;
  T *y_;
  int n_;
  int post_;
};
93+
7594template <typename VecDxType, typename T>
7695struct StackGradFunctor {
7796 HOSTDEVICE StackGradFunctor (const VecDxType &dx, const T *dy, int n, int post)
@@ -91,6 +110,14 @@ struct StackGradFunctor {
91110 int post_;
92111};
93112
113+ template <typename DeviceContext, typename VecXType, typename T>
114+ static inline void StackFunctorForRange (const DeviceContext &ctx,
115+ const VecXType &x, T *y, int total_num,
116+ int n, int post) {
117+ platform::ForRange<DeviceContext> for_range (ctx, total_num);
118+ for_range (StackFunctor<VecXType, T>(x, y, n, post));
119+ }
120+
94121template <typename DeviceContext, typename VecDxType, typename T>
95122static inline void StackGradFunctorForRange (const DeviceContext &ctx,
96123 const VecDxType &dx, const T *dy,
0 commit comments