We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5bd1e73 commit e03b574Copy full SHA for e03b574
paddle/operators/momentum_op.h
@@ -44,15 +44,11 @@ class MomentumOpKernel : public framework::OpKernel<T> {
44
auto g = framework::EigenVector<T>::Flatten(*grad);
45
auto* lr = learning_rate->data<T>();
46
47
- auto place = ctx.GetEigenDevice<platform::CPUPlace>();
48
-
49
- Eigen::DSizes<int, 1> grad_dsize(grad->numel());
50
51
- v_out.device(place) = v * mu + g;
+ v_out = v * mu + g;
52
if (use_nesterov) {
53
- p_out.device(place) = p - (g - v_out * mu) * lr[0];
+ p_out = p - (g - v_out * mu) * lr[0];
54
} else {
55
- p_out.device(place) = p - lr[0] * v_out;
+ p_out = p - lr[0] * v_out;
56
}
57
58
};
0 commit comments