Skip to content

Commit 6caea45

Browse files
committed
add TestFtrlOptimizer
1 parent ca341db commit 6caea45

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

python/paddle/fluid/tests/unittests/test_optimizer.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,5 +434,71 @@ def test_decayed_adagrad_optimizer(self):
434434
self.assertAlmostEqual(init_ops[1].attr('value'), 0.0)
435435

436436

437+
class TestFtrlOptimizer(unittest.TestCase):
    """Unit test for FtrlOptimizer.

    Builds a tiny forward graph (mean(mul(x, y))), runs the optimizer's
    optimization pass over it, and checks:
      * the optimizer ops emitted (fill_constant, elementwise_mul, ftrl),
      * the accumulators the optimizer registers (squared / linear),
      * the ops and learning-rate value placed into the init program.
    """

    class MockFtrl(optimizer.FtrlOptimizer):
        """Subclass that exposes FtrlOptimizer's private accumulator
        state so the test can inspect it."""

        def get_accumulators(self):
            # Private dict of accumulator-name -> {param-name: var}.
            return self._accumulators

        def get_squared_str(self):
            # Name key of the squared-gradient accumulator.
            return self._squared_acc_str

        def get_linear_str(self):
            # Name key of the linear accumulator.
            return self._linear_acc_str

    def test_ftrl_optimizer(self):
        init_program = framework.Program()
        program = framework.Program()
        block = program.global_block()
        # Forward graph: mul.out = mul(mul.x, mul.y); mean.out = mean(mul.out).
        # mul.x is the only trainable parameter.
        mul_x = block.create_parameter(
            dtype="float32",
            shape=[5, 10],
            lod_level=0,
            name="mul.x",
            optimize_attr={'learning_rate': 1.1})
        mul_y = block.create_var(
            dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
        mul_out = block.create_var(
            dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
        block.append_op(
            type="mul",
            inputs={"X": mul_x,
                    "Y": mul_y},
            outputs={"Out": mul_out},
            attrs={"x_num_col_dims": 1})
        mean_out = block.create_var(
            dtype="float32", shape=[1], lod_level=0, name="mean.out")
        block.append_op(
            type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
        learning_rate = 0.01
        ftrl_optimizer = self.MockFtrl(
            learning_rate=learning_rate, l1=0.0, l2=0.0, lr_power=-0.5)
        params_grads = append_backward(mean_out)
        # Exactly one trainable parameter -> one (param, grad) pair,
        # and no accumulators exist before the optimization pass runs.
        self.assertEqual(len(params_grads), 1)
        self.assertEqual(len(ftrl_optimizer.get_accumulators()), 0)
        opts = ftrl_optimizer.create_optimization_pass(params_grads, mul_out,
                                                       init_program)
        self.assertEqual(len(opts), 3)
        self.assertEqual([op.type for op in opts],
                         ["fill_constant", "elementwise_mul", "ftrl"])

        # Check accumulators: FTRL keeps a squared and a linear accumulator,
        # each with one entry keyed by the parameter's name.
        accumulators = ftrl_optimizer.get_accumulators()
        self.assertEqual(len(accumulators), 2)
        # assertIn gives a precise failure message, unlike assertTrue(x in y).
        self.assertIn(ftrl_optimizer.get_squared_str(), accumulators)
        self.assertIn(ftrl_optimizer.get_linear_str(), accumulators)
        squared_acc = accumulators[ftrl_optimizer.get_squared_str()]
        linear_acc = accumulators[ftrl_optimizer.get_linear_str()]
        self.assertEqual(len(squared_acc), 1)
        self.assertEqual(len(linear_acc), 1)
        self.assertIn(mul_x.name, squared_acc)
        self.assertIn(mul_x.name, linear_acc)

        # Check init_program: first op fills the learning-rate constant.
        init_ops = init_program.global_block().ops
        self.assertEqual(len(init_ops), 3)
        self.assertEqual(init_ops[0].type, "fill_constant")
        self.assertAlmostEqual(init_ops[0].attr('value'), learning_rate)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)