@@ -618,7 +618,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
618
618
if op .type == LOOKUP_TABLE_TYPE :
619
619
continue_search_lookup_table_op = True
620
620
621
- op_index = list (all_ops ).index (op )
621
+ lookup_table_op_index = list (all_ops ).index (op )
622
622
ids_name = op .input ("Ids" )
623
623
out_name = op .output ("Out" )
624
624
@@ -637,7 +637,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
637
637
638
638
# insert split_ids_op
639
639
program .global_block ().insert_op (
640
- index = op_index ,
640
+ index = lookup_table_op_index ,
641
641
type = "split_ids" ,
642
642
inputs = {
643
643
'Ids' : [
@@ -649,7 +649,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
649
649
650
650
# insert prefetch_op
651
651
program .global_block ().insert_op (
652
- index = op_index + 1 ,
652
+ index = lookup_table_op_index + 1 ,
653
653
type = "prefetch" ,
654
654
inputs = {'X' : self .prefetch_input_vars },
655
655
outputs = {"Out" : self .prefetch_output_vars },
@@ -660,16 +660,21 @@ def _replace_lookup_table_op_with_prefetch(self, program,
660
660
661
661
# insert concat_op
662
662
program .global_block ().insert_op (
663
- index = op_index + 2 ,
664
- type = "concat" ,
665
- inputs = {'X' : self .prefetch_output_vars },
663
+ index = lookup_table_op_index + 2 ,
664
+ type = "merge_ids" ,
665
+ inputs = {
666
+ 'Ids' : [
667
+ program .global_block ().vars [varname ]
668
+ for varname in ids_name
669
+ ],
670
+ 'X' : self .prefetch_output_vars
671
+ },
666
672
outputs = {
667
673
"Out" : [
668
674
program .global_block ().vars [varname ]
669
675
for varname in out_name
670
676
]
671
- },
672
- attrs = {"axis" : 0 })
677
+ })
673
678
674
679
# delete lookup_table_op
675
680
delete_ops (program .global_block (), [op ])
0 commit comments