
Commit 7970ab9

Merge pull request #12544 from jacquesqiao/dist-lookup-table-only-support-sgd
dist lookup table only support sgd
2 parents 2d036c4 + 111bde9

File tree

2 files changed: +111 -7 lines changed


python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 105 additions & 0 deletions
@@ -359,5 +359,110 @@ def transpiler_test_impl(self):
                          ["sum", "scale", "scale", "elementwise_add", "momentum"])
 
 
+class TestDistLookupTableBase(TranspilerTest):
+    def network_with_table(self, is_sparse, is_distributed):
+        def emb_pool(ids):
+            table_size = 1000
+            emb_size = 64
+            emb = fluid.layers.embedding(
+                input=ids,
+                size=[table_size, emb_size],
+                dtype='float32',
+                param_attr='shared_w',  # share the parameter between pools
+                is_sparse=is_sparse,
+                is_distributed=is_distributed)
+            pool = fluid.layers.sequence_pool(input=emb, pool_type='average')
+            return pool
+
+        title_ids = fluid.layers.data(
+            name='title_ids', shape=[1], dtype='int64', lod_level=1)
+        brand_ids = fluid.layers.data(
+            name='brand_ids', shape=[1], dtype='int64', lod_level=1)
+        title_emb = emb_pool(title_ids)
+        brand_emb = emb_pool(brand_ids)
+        fc0 = fluid.layers.concat(input=[title_emb, brand_emb], axis=1)
+        predict = fluid.layers.fc(input=fc0,
+                                  size=2,
+                                  act=None,
+                                  param_attr=fluid.ParamAttr(name='fc_w'),
+                                  bias_attr=fluid.ParamAttr(name='fc_b'))
+
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+        cost = fluid.layers.cross_entropy(input=predict, label=label)
+        avg_cost = fluid.layers.mean(cost)
+        optimizer = fluid.optimizer.Adam(learning_rate=0.003)
+        optimizer.minimize(avg_cost)
+
+
+class TestLocalLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 3)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for table adam
+        # NOTE: if param is not a selected-rows var, the grad will be scaled to grad / trainer_num
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "adam", "scale", "scale"])
+
+        trainer = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'lookup_table', 'sequence_pool', 'lookup_table', 'sequence_pool',
+            'concat', 'mul', 'elementwise_add', 'cross_entropy', 'mean',
+            'fill_constant', 'mean_grad', 'cross_entropy_grad',
+            'elementwise_add_grad', 'send', 'mul_grad', 'send', 'concat_grad',
+            'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sum', 'split_selected_rows', 'send',
+            'send_barrier', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
+class TestDistLookupTable(TestDistLookupTableBase):
+    def net_conf(self):
+        self.network_with_table(is_sparse=True, is_distributed=True)
+
+    def transpiler_test_impl(self):
+        pserver1, startup1 = self.get_pserver(self.pserver1_ep)
+
+        self.assertEqual(len(pserver1.blocks), 6)
+        # 0 listen_and_serv
+        # 1 optimize for fc_w or fc_b adam
+        self.assertEqual([op.type for op in pserver1.blocks[1].ops],
+                         ["sum", "scale", "adam", "scale", "scale"])
+        # 2 optimize for table sgd
+        self.assertEqual([op.type for op in pserver1.blocks[2].ops],
+                         ["sum", "sgd"])
+        # 3 prefetch -> lookup_sparse_table for data0
+        self.assertEqual([op.type for op in pserver1.blocks[3].ops],
+                         ["lookup_sparse_table"])
+        # 4 prefetch -> lookup_sparse_table for data1
+        self.assertEqual([op.type for op in pserver1.blocks[4].ops],
+                         ["lookup_sparse_table"])
+        # 5 save table
+        self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
+
+        trainer = self.get_trainer()
+        self.assertEqual(len(trainer.blocks), 1)
+        ops = [
+            'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
+            'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
+            'elementwise_add', 'cross_entropy', 'mean', 'fill_constant',
+            'mean_grad', 'cross_entropy_grad', 'elementwise_add_grad', 'send',
+            'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
+            'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
+            'sum', 'split_ids', 'send', 'send_barrier', 'recv', 'recv',
+            'fetch_barrier'
+        ]
+        self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
+
+
 if __name__ == "__main__":
     unittest.main()
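
The two new cases can also be run on their own with the standard unittest loader. A minimal sketch, assuming Paddle Fluid is installed and the script is run from python/paddle/fluid/tests/unittests:

# Run only the new lookup-table cases; assumes Paddle Fluid is installed
# and the working directory is python/paddle/fluid/tests/unittests.
import unittest

from test_dist_transpiler import TestDistLookupTable, TestLocalLookupTable

loader = unittest.defaultTestLoader
suite = unittest.TestSuite([
    loader.loadTestsFromTestCase(TestLocalLookupTable),
    loader.loadTestsFromTestCase(TestDistLookupTable),
])
unittest.TextTestRunner(verbosity=2).run(suite)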

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 6 additions & 7 deletions
@@ -896,8 +896,6 @@ def _create_table_optimize_block(self, pserver_index, pserver_program,
             self.table_name
         ][0]
         table_opt_block = pserver_program.create_block(pre_block_idx)
-        # only support sgd now
-        assert table_opt_op.type == "sgd"
 
         if self.sync_mode:
             # create grad vars in pserver program
@@ -937,11 +935,12 @@ def _create_table_optimize_block(self, pserver_index, pserver_program,
             "LearningRate": [lr_var]
         }
         outputs = {"ParamOut": [param_var]}
-        table_opt_block.append_op(
-            type=table_opt_op.type,
-            inputs=inputs,
-            outputs=outputs,
-            attrs=table_opt_op.attrs)
+        # only support sgd now
+        import logging
+        logging.warn(
+            "distributed lookup table only supports the sgd optimizer; "
+            "using sgd instead of " + table_opt_op.type)
+        table_opt_block.append_op(type="sgd", inputs=inputs, outputs=outputs)
 
         # add table parameter gradient and its block id to grad_to_block_id
         grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx))
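
Rather than asserting (and crashing) on a non-sgd optimizer, the transpiler now warns and substitutes sgd for the table's optimize op. The pattern can be isolated as below; a minimal, self-contained sketch, where force_sgd_for_table and SUPPORTED_TABLE_OPT are illustrative names, not Paddle API:

# Self-contained sketch of the warn-and-substitute pattern adopted above;
# the function and constant names are illustrative, not part of Paddle.
import logging

SUPPORTED_TABLE_OPT = "sgd"

def force_sgd_for_table(op_type):
    """Return the op type to append to the table optimize block."""
    if op_type != SUPPORTED_TABLE_OPT:
        # Same behavior as the diff: keep going, but tell the user their
        # chosen optimizer is ignored for the distributed lookup table.
        logging.warning(
            "distributed lookup table only supports the sgd optimizer; "
            "using sgd instead of %s", op_type)
    return SUPPORTED_TABLE_OPT

print(force_sgd_for_table("adam"))  # warns, then prints "sgd"

This is exactly what TestDistLookupTable asserts: pserver block 2 contains ["sum", "sgd"] even though the network was built with Adam.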
