@@ -72,25 +72,6 @@ class StackOpMaker : public framework::OpProtoAndCheckerMaker {
72
72
}
73
73
};
74
74
75
- template <typename VecXType, typename T>
76
- struct StackFunctor {
77
- HOSTDEVICE StackFunctor (const VecXType &x, T *y, int n, int post)
78
- : x_(x), y_(y), n_(n), post_(post) {}
79
-
80
- HOSTDEVICE void operator ()(int idx) {
81
- int i = idx / (n_ * post_);
82
- int which_x = idx / post_ - i * n_;
83
- int x_index = i * post_ + idx % post_;
84
- y_[idx] = x_[which_x][x_index];
85
- }
86
-
87
- private:
88
- VecXType x_;
89
- T *y_;
90
- int n_;
91
- int post_;
92
- };
93
-
94
75
template <typename VecDxType, typename T>
95
76
struct StackGradFunctor {
96
77
HOSTDEVICE StackGradFunctor (const VecDxType &dx, const T *dy, int n, int post)
@@ -110,14 +91,6 @@ struct StackGradFunctor {
110
91
int post_;
111
92
};
112
93
113
- template <typename DeviceContext, typename VecXType, typename T>
114
- static inline void StackFunctorForRange (const DeviceContext &ctx,
115
- const VecXType &x, T *y, int total_num,
116
- int n, int post) {
117
- platform::ForRange<DeviceContext> for_range (ctx, total_num);
118
- for_range (StackFunctor<VecXType, T>(x, y, n, post));
119
- }
120
-
121
94
template <typename DeviceContext, typename VecDxType, typename T>
122
95
static inline void StackGradFunctorForRange (const DeviceContext &ctx,
123
96
const VecDxType &dx, const T *dy,
0 commit comments