@@ -60,12 +60,13 @@ inline void get_mid_dims(const framework::DDim& x_dims,
60
60
}
61
61
62
62
template <typename T, typename Place>
63
- struct RowwiseTransformIterator ;
63
+ class RowwiseTransformIterator ;
64
64
template <typename T, typename Place>
65
- struct MidWiseTransformIterator ;
65
+ class MidWiseTransformIterator ;
66
66
67
67
template <typename T>
68
- struct RowwiseTransformIterator <T, platform::CPUPlace> {
68
+ class RowwiseTransformIterator <T, platform::CPUPlace> {
69
+ public:
69
70
RowwiseTransformIterator (const T* ptr, int n) : ptr_(ptr), i_(0 ), n_(n) {}
70
71
71
72
RowwiseTransformIterator<T, platform::CPUPlace>& operator ++() {
@@ -86,13 +87,15 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
86
87
87
88
const T& operator *() { return ptr_[i_]; }
88
89
90
+ private:
89
91
const T* ptr_;
90
92
int i_;
91
93
int64_t n_;
92
94
};
93
95
94
96
template <typename T>
95
- struct MidWiseTransformIterator <T, platform::CPUPlace> {
97
+ class MidWiseTransformIterator <T, platform::CPUPlace> {
98
+ public:
96
99
MidWiseTransformIterator (const T* ptr, int n, int post)
97
100
: ptr_(ptr), i_(0 ), j_(0 ), n_(n), post_(post) {}
98
101
@@ -113,6 +116,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
113
116
114
117
const T& operator *() { return ptr_[i_]; }
115
118
119
+ private:
116
120
const T* ptr_;
117
121
int i_;
118
122
int64_t j_;
@@ -122,7 +126,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
122
126
123
127
#ifdef __NVCC__
124
128
template <typename T>
125
- struct RowwiseTransformIterator <T, platform::GPUPlace>
129
+ class RowwiseTransformIterator <T, platform::GPUPlace>
126
130
: public thrust::iterator_adaptor<
127
131
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
128
132
public:
@@ -142,7 +146,7 @@ struct RowwiseTransformIterator<T, platform::GPUPlace>
142
146
};
143
147
144
148
template <typename T>
145
- struct MidWiseTransformIterator <T, platform::GPUPlace>
149
+ class MidWiseTransformIterator <T, platform::GPUPlace>
146
150
: public thrust::iterator_adaptor<
147
151
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
148
152
public:
@@ -164,7 +168,8 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
164
168
#endif
165
169
166
170
template <typename Functor, typename T, typename Place>
167
- struct TransformFunctor {
171
+ class TransformFunctor {
172
+ public:
168
173
TransformFunctor (const framework::Tensor* x, const framework::Tensor* y,
169
174
framework::Tensor* z, const platform::DeviceContext& ctx,
170
175
Functor func)
@@ -192,6 +197,7 @@ struct TransformFunctor {
192
197
z_, func_);
193
198
}
194
199
200
+ private:
195
201
const T* x_;
196
202
const T* y_;
197
203
T* z_;
0 commit comments