Skip to content

Commit 2e355f0

Browse files
authored
Fix attribute naming for momentum_op (#5453)
* Fix attribute naming for momentum_op * Fix minor typo in comment * Fix attribute name * Fix names in test_optimizer * Fix python wrapper
1 parent c88f98c commit 2e355f0

File tree

5 files changed

+7
-7
lines changed

5 files changed

+7
-7
lines changed

paddle/operators/momentum_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
7575
AddOutput("VelocityOut", "(Tensor) Output updated velocity");
7676

7777
AddAttr<float>("mu", "(float) Momentum coefficient");
78-
AddAttr<bool>("useNesterov",
78+
AddAttr<bool>("use_nesterov",
7979
"(bool, default false) "
8080
"Use Nesterov Momentum")
8181
.SetDefault(false);

paddle/operators/momentum_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
3434
velocity_out->mutable_data<T>(ctx.GetPlace());
3535

3636
float mu = ctx.Attr<float>("mu");
37-
bool use_nesterov = ctx.Attr<bool>("useNesterov");
37+
bool use_nesterov = ctx.Attr<bool>("use_nesterov");
3838

3939
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
4040
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);

python/paddle/v2/framework/optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def _append_optimize_op(self, block, param_and_grad):
297297
"VelocityOut": velocity_acc
298298
},
299299
attrs={"mu": self._momentum,
300-
"useNesterov": self._use_nesterov})
300+
"use_nesterov": self._use_nesterov})
301301

302302
return momentum_op
303303

python/paddle/v2/framework/tests/test_momentum_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_check_output(self):
3737

3838

3939
class TestMomentumOp2(OpTest):
40-
'''Test Momentum with defaukt values for attributes
40+
'''Test Momentum with default values for attributes
4141
'''
4242

4343
def setUp(self):
@@ -57,7 +57,7 @@ def setUp(self):
5757
'LearningRate': learning_rate
5858
}
5959

60-
self.attrs = {'mu': mu, 'useNesterov': use_nesterov}
60+
self.attrs = {'mu': mu, 'use_nesterov': use_nesterov}
6161

6262
velocity_out = mu * velocity + grad
6363
if use_nesterov:

python/paddle/v2/framework/tests/test_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_vanilla_momentum_optimizer(self):
9898
self.assertEqual(len(opts), 1)
9999
sgd_op = opts[0]
100100
self.assertEqual(sgd_op.type, "momentum")
101-
self.assertFalse(sgd_op.attr('useNesterov'))
101+
self.assertFalse(sgd_op.attr('use_nesterov'))
102102

103103
# Check accumulators
104104
accumulators = momentum_optimizer.get_accumulators()
@@ -143,7 +143,7 @@ def test_nesterov_momentum_optimizer(self):
143143
self.assertEqual(len(opts), 1)
144144
sgd_op = opts[0]
145145
self.assertEqual(sgd_op.type, "momentum")
146-
self.assertTrue(sgd_op.attr('useNesterov'))
146+
self.assertTrue(sgd_op.attr('use_nesterov'))
147147

148148
# Check accumulators
149149
accumulators = momentum_optimizer.get_accumulators()

0 commit comments

Comments
 (0)