Skip to content

Commit 856c26f

Browse files
authored
fix elementwise (#13146)
1 parent 4fa3cee commit 856c26f

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

paddle/fluid/operators/elementwise_op_function.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616

1717
#include <glog/logging.h>
1818
#include <algorithm>
19+
#include <iterator>
1920
#include <vector>
2021
#include "paddle/fluid/framework/eigen.h"
2122
#include "paddle/fluid/framework/op_registry.h"
@@ -94,8 +95,11 @@ class RowwiseTransformIterator;
9495
template <typename T, typename DeviceContext>
9596
class MidWiseTransformIterator;
9697

98+
// NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17
9799
template <typename T>
98-
class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
100+
class RowwiseTransformIterator<T, platform::CPUDeviceContext>
101+
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
102+
T *, T &> {
99103
public:
100104
RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {}
101105

@@ -126,7 +130,9 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
126130
};
127131

128132
template <typename T>
129-
class MidWiseTransformIterator<T, platform::CPUDeviceContext> {
133+
class MidWiseTransformIterator<T, platform::CPUDeviceContext>
134+
: public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t,
135+
T *, T &> {
130136
public:
131137
MidWiseTransformIterator(const T *ptr, int n, int post)
132138
: ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {}
@@ -479,8 +485,13 @@ void ElemwiseGradComputeNoBroadcast(
479485
const framework::Tensor &dout, int axis, framework::Tensor *dx,
480486
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
481487
size_t N = static_cast<size_t>(framework::product(x_dim));
488+
#if !defined(_WIN32)
482489
platform::ForRange<DeviceContext> for_range(
483490
ctx.template device_context<DeviceContext>(), N);
491+
#else
492+
platform::ForRange<DeviceContext> for_range(
493+
ctx.device_context<DeviceContext>(), N);
494+
#endif // !_WIN32
484495
for_range(ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
485496
x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), dx_op, dy_op,
486497
dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
@@ -633,13 +644,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
633644

634645
template <typename Functor, typename DeviceContext, typename T,
635646
typename OutType = T>
647+
636648
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
637649
const framework::Tensor *x,
638650
const framework::Tensor *y, int axis, Functor func,
639651
framework::Tensor *z) {
640652
TransformFunctor<Functor, T, DeviceContext, OutType> functor(
641653
x, y, z, ctx.template device_context<DeviceContext>(), func);
642-
643654
auto x_dims = x->dims();
644655
auto y_dims_untrimed = y->dims();
645656
PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),

0 commit comments

Comments
 (0)