Skip to content

Commit a941786

Browse files
committed
replace concat_op with merge_ids_op
1 parent 509cb0b commit a941786

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
618618
if op.type == LOOKUP_TABLE_TYPE:
619619
continue_search_lookup_table_op = True
620620

621-
op_index = list(all_ops).index(op)
621+
lookup_table_op_index = list(all_ops).index(op)
622622
ids_name = op.input("Ids")
623623
out_name = op.output("Out")
624624

@@ -637,7 +637,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
637637

638638
# insert split_ids_op
639639
program.global_block().insert_op(
640-
index=op_index,
640+
index=lookup_table_op_index,
641641
type="split_ids",
642642
inputs={
643643
'Ids': [
@@ -649,7 +649,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
649649

650650
# insert prefetch_op
651651
program.global_block().insert_op(
652-
index=op_index + 1,
652+
index=lookup_table_op_index + 1,
653653
type="prefetch",
654654
inputs={'X': self.prefetch_input_vars},
655655
outputs={"Out": self.prefetch_output_vars},
@@ -660,16 +660,21 @@ def _replace_lookup_table_op_with_prefetch(self, program,
660660

661661
# insert concat_op
662662
program.global_block().insert_op(
663-
index=op_index + 2,
664-
type="concat",
665-
inputs={'X': self.prefetch_output_vars},
663+
index=lookup_table_op_index + 2,
664+
type="merge_ids",
665+
inputs={
666+
'Ids': [
667+
program.global_block().vars[varname]
668+
for varname in ids_name
669+
],
670+
'X': self.prefetch_output_vars
671+
},
666672
outputs={
667673
"Out": [
668674
program.global_block().vars[varname]
669675
for varname in out_name
670676
]
671-
},
672-
attrs={"axis": 0})
677+
})
673678

674679
# delete lookup_table_op
675680
delete_ops(program.global_block(), [op])

0 commit comments

Comments
 (0)