Skip to content

Commit 37671ac

Browse files
committed
follow comments
1 parent 9e244a8 commit 37671ac

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

paddle/operators/elementwise_op_function.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ inline void get_mid_dims(const framework::DDim& x_dims,
6060
}
6161

6262
template <typename T, typename Place>
63-
struct RowwiseTransformIterator;
63+
class RowwiseTransformIterator;
6464
template <typename T, typename Place>
65-
struct MidWiseTransformIterator;
65+
class MidWiseTransformIterator;
6666

6767
template <typename T>
68-
struct RowwiseTransformIterator<T, platform::CPUPlace> {
68+
class RowwiseTransformIterator<T, platform::CPUPlace> {
69+
public:
6970
RowwiseTransformIterator(const T* ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
7071

7172
RowwiseTransformIterator<T, platform::CPUPlace>& operator++() {
@@ -86,13 +87,15 @@ struct RowwiseTransformIterator<T, platform::CPUPlace> {
8687

8788
const T& operator*() { return ptr_[i_]; }
8889

90+
private:
8991
const T* ptr_;
9092
int i_;
9193
int64_t n_;
9294
};
9395

9496
template <typename T>
95-
struct MidWiseTransformIterator<T, platform::CPUPlace> {
97+
class MidWiseTransformIterator<T, platform::CPUPlace> {
98+
public:
9699
MidWiseTransformIterator(const T* ptr, int n, int post)
97100
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
98101

@@ -113,6 +116,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
113116

114117
const T& operator*() { return ptr_[i_]; }
115118

119+
private:
116120
const T* ptr_;
117121
int i_;
118122
int64_t j_;
@@ -122,7 +126,7 @@ struct MidWiseTransformIterator<T, platform::CPUPlace> {
122126

123127
#ifdef __NVCC__
124128
template <typename T>
125-
struct RowwiseTransformIterator<T, platform::GPUPlace>
129+
class RowwiseTransformIterator<T, platform::GPUPlace>
126130
: public thrust::iterator_adaptor<
127131
RowwiseTransformIterator<T, platform::GPUPlace>, const T*> {
128132
public:
@@ -142,7 +146,7 @@ struct RowwiseTransformIterator<T, platform::GPUPlace>
142146
};
143147

144148
template <typename T>
145-
struct MidWiseTransformIterator<T, platform::GPUPlace>
149+
class MidWiseTransformIterator<T, platform::GPUPlace>
146150
: public thrust::iterator_adaptor<
147151
MidWiseTransformIterator<T, platform::GPUPlace>, const T*> {
148152
public:
@@ -164,7 +168,8 @@ struct MidWiseTransformIterator<T, platform::GPUPlace>
164168
#endif
165169

166170
template <typename Functor, typename T, typename Place>
167-
struct TransformFunctor {
171+
class TransformFunctor {
172+
public:
168173
TransformFunctor(const framework::Tensor* x, const framework::Tensor* y,
169174
framework::Tensor* z, const platform::DeviceContext& ctx,
170175
Functor func)
@@ -192,6 +197,7 @@ struct TransformFunctor {
192197
z_, func_);
193198
}
194199

200+
private:
195201
const T* x_;
196202
const T* y_;
197203
T* z_;

0 commit comments

Comments
 (0)