@@ -85,7 +85,7 @@ struct IdentityGrad {
85
85
HOSTDEVICE T operator ()(T x, T y, T out, T dout) const { return dout; }
86
86
};
87
87
88
- template <typename DeviceContext, typename T>
88
+ template <typename DeviceContext, typename T>
89
89
void default_elementwise_add_grad (const framework::ExecutionContext& ctx,
90
90
const framework::Tensor* x,
91
91
const framework::Tensor* y,
@@ -100,16 +100,15 @@ void default_elementwise_add_grad(const framework::ExecutionContext& ctx,
100
100
IdentityGrad<T>());
101
101
}
102
102
103
- template <typename DeviceContext, typename T>
103
+ template <typename DeviceContext, typename T>
104
104
typename std::enable_if<
105
105
std::is_floating_point<T>::value &&
106
106
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
107
107
elementwise_add_grad (const framework::ExecutionContext& ctx,
108
- const framework::Tensor* x,
109
- const framework::Tensor* y,
108
+ const framework::Tensor* x, const framework::Tensor* y,
110
109
const framework::Tensor* out,
111
- const framework::Tensor* dout,
112
- framework::Tensor* dx, framework::Tensor* dy) {
110
+ const framework::Tensor* dout, framework::Tensor* dx,
111
+ framework::Tensor* dy) {
113
112
auto blas = math::GetBlas<DeviceContext, T>(ctx);
114
113
115
114
if (dx) {
@@ -123,16 +122,15 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
123
122
}
124
123
}
125
124
126
- template <typename DeviceContext, typename T>
125
+ template <typename DeviceContext, typename T>
127
126
typename std::enable_if<
128
127
!std::is_floating_point<T>::value ||
129
128
!std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
130
129
elementwise_add_grad (const framework::ExecutionContext& ctx,
131
- const framework::Tensor* x,
132
- const framework::Tensor* y,
130
+ const framework::Tensor* x, const framework::Tensor* y,
133
131
const framework::Tensor* out,
134
- const framework::Tensor* dout,
135
- framework::Tensor* dx, framework::Tensor* dy) {
132
+ const framework::Tensor* dout, framework::Tensor* dx,
133
+ framework::Tensor* dy) {
136
134
default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
137
135
}
138
136
@@ -152,8 +150,8 @@ class ElementwiseAddGradKernel : public framework::OpKernel<T> {
152
150
if (platform::is_cpu_place (ctx.GetPlace ()) && (x->dims () == y->dims ())) {
153
151
elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
154
152
} else {
155
- default_elementwise_add_grad<DeviceContext, T>(
156
- ctx, x, y, out, dout, dx, dy);
153
+ default_elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
154
+ dy);
157
155
}
158
156
}
159
157
};
0 commit comments