@@ -16,6 +16,7 @@ limitations under the License. */
16
16
17
17
#include < glog/logging.h>
18
18
#include < algorithm>
19
+ #include < iterator>
19
20
#include < vector>
20
21
#include " paddle/fluid/framework/eigen.h"
21
22
#include " paddle/fluid/framework/op_registry.h"
@@ -94,8 +95,11 @@ class RowwiseTransformIterator;
94
95
template <typename T, typename DeviceContext>
95
96
class MidWiseTransformIterator ;
96
97
98
+ // NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17
97
99
template <typename T>
98
- class RowwiseTransformIterator <T, platform::CPUDeviceContext> {
100
+ class RowwiseTransformIterator <T, platform::CPUDeviceContext>
101
+ : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t ,
102
+ T *, T &> {
99
103
public:
100
104
RowwiseTransformIterator (const T *ptr, int n) : ptr_(ptr), i_(0 ), n_(n) {}
101
105
@@ -126,7 +130,9 @@ class RowwiseTransformIterator<T, platform::CPUDeviceContext> {
126
130
};
127
131
128
132
template <typename T>
129
- class MidWiseTransformIterator <T, platform::CPUDeviceContext> {
133
+ class MidWiseTransformIterator <T, platform::CPUDeviceContext>
134
+ : public std::iterator<std::random_access_iterator_tag, T, std::ptrdiff_t ,
135
+ T *, T &> {
130
136
public:
131
137
MidWiseTransformIterator (const T *ptr, int n, int post)
132
138
: ptr_(ptr), i_(0 ), j_(0 ), n_(n), post_(post) {}
@@ -479,8 +485,13 @@ void ElemwiseGradComputeNoBroadcast(
479
485
const framework::Tensor &dout, int axis, framework::Tensor *dx,
480
486
framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
481
487
size_t N = static_cast <size_t >(framework::product (x_dim));
488
+ #if !defined(_WIN32)
482
489
platform::ForRange<DeviceContext> for_range (
483
490
ctx.template device_context <DeviceContext>(), N);
491
+ #else
492
+ platform::ForRange<DeviceContext> for_range (
493
+ ctx.device_context <DeviceContext>(), N);
494
+ #endif // !_WIN32
484
495
for_range (ElemwiseGradNoBroadcast<T, DX_OP, DY_OP>{
485
496
x.data <T>(), y.data <T>(), out.data <T>(), dout.data <T>(), dx_op, dy_op,
486
497
dx == nullptr ? nullptr : dx->mutable_data <T>(ctx.GetPlace ()),
@@ -633,13 +644,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
633
644
634
645
template <typename Functor, typename DeviceContext, typename T,
635
646
typename OutType = T>
647
+
636
648
void ElementwiseComputeEx (const framework::ExecutionContext &ctx,
637
649
const framework::Tensor *x,
638
650
const framework::Tensor *y, int axis, Functor func,
639
651
framework::Tensor *z) {
640
652
TransformFunctor<Functor, T, DeviceContext, OutType> functor (
641
653
x, y, z, ctx.template device_context <DeviceContext>(), func);
642
-
643
654
auto x_dims = x->dims ();
644
655
auto y_dims_untrimed = y->dims ();
645
656
PADDLE_ENFORCE_GE (x_dims.size (), y_dims_untrimed.size (),
0 commit comments