Skip to content

Commit 373f649

Browse files
committed
add comment and unit test
test=develop
1 parent 6705046 commit 373f649

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

python/paddle/fluid/optimizer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,15 @@ def _create_optimization_pass(self,
249249

250250
def _process_distribute_lookuptable(self, param_grads, loss,
251251
startup_program):
252+
"""
253+
Because distribute lookup table only support SGD optimizer for now, not support
254+
other optimizer and regularization, so we should find the table parameter out,
255+
and avoid to add regularization and other op for it, and add sgd optimize op
256+
for it independently.
257+
:param param_grads(list((Var, Var))): list of (param, grad) pair.
258+
:param loss: the loss variable.
259+
:param startup_program: the startup program
260+
"""
252261
program = loss.block.program
253262
table_name = find_distributed_lookup_table(program)
254263
table_param = None

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ def transpiler_test_impl(self):
641641
# 5 save table
642642
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
643643

644-
trainer, _ = self.get_trainer(config)
644+
trainer, trainer_startup = self.get_trainer(config)
645645
self.assertEqual(len(trainer.blocks), 1)
646646
ops = [
647647
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
@@ -655,6 +655,16 @@ def transpiler_test_impl(self):
655655
'recv', 'concat'
656656
]
657657
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
658+
startup_ops = [
659+
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
660+
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
661+
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
662+
'fill_constant', 'fill_constant', 'uniform_random',
663+
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
664+
'fake_init'
665+
]
666+
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
667+
startup_ops)
658668

659669

660670
class TestDistLookupTableSliceSize(TestDistLookupTableBase):

0 commit comments

Comments
 (0)