@@ -484,7 +484,6 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
484
484
}
485
485
}
486
486
487
- bool is_one_result = result_pattern_graph.owned_op_call ().size () == 1 ;
488
487
// topo order visit result_pattern_graph
489
488
GraphTopo graph_topo_visit (&result_pattern_graph);
490
489
graph_topo_visit.WalkGraphNodesTopoOrder ([&](const OpCall& op_call) {
@@ -496,8 +495,9 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
496
495
}
497
496
498
497
// set insert point
499
- size_t max_input_op_index = 0UL ;
500
- pir::Operation* max_index_op = nullptr ;
498
+ // 1. get result pattern max-idx of input op
499
+ size_t max_res_idx = 0UL ;
500
+ pir::Operation* max_res_idx_op = nullptr ;
501
501
for (const Tensor* input : op_call.inputs ()) {
502
502
if (input->is_none ()) {
503
503
continue ;
@@ -507,18 +507,16 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
507
507
pir::Operation* ir_input_op = ir_val.defining_op ();
508
508
if (op_2_temp_program_index.count (ir_input_op) == 0 ) {
509
509
// do nothing
510
- } else if (max_input_op_index <
511
- op_2_temp_program_index.at (ir_input_op)) {
512
- max_input_op_index = op_2_temp_program_index.at (ir_input_op);
513
- max_index_op = ir_input_op;
514
- } else if (max_input_op_index ==
515
- op_2_temp_program_index.at (ir_input_op)) {
516
- const auto & ops_vec = temp_program[max_input_op_index];
510
+ } else if (max_res_idx < op_2_temp_program_index.at (ir_input_op)) {
511
+ max_res_idx = op_2_temp_program_index.at (ir_input_op);
512
+ max_res_idx_op = ir_input_op;
513
+ } else if (max_res_idx == op_2_temp_program_index.at (ir_input_op)) {
514
+ const auto & ops_vec = temp_program[max_res_idx];
517
515
for (auto it = ops_vec.begin (); it != ops_vec.end (); it++) {
518
- if (*it == max_index_op ) {
516
+ if (*it == max_res_idx_op ) {
519
517
break ;
520
518
} else if (*it == ir_input_op) {
521
- max_index_op = ir_input_op;
519
+ max_res_idx_op = ir_input_op;
522
520
break ;
523
521
} else {
524
522
// do nothing
@@ -530,51 +528,28 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
530
528
}
531
529
}
532
530
533
- if (is_one_result && !source_pattern_graph.owned_op_call ().empty ()) {
534
- // 1. get source pattern min-idx op
535
- pir::Operation* min_src_idx_op = src_match_ctx.IrOperation (
536
- source_pattern_graph.owned_op_call ()[0 ].get ());
537
- size_t min_src_idx = op_2_temp_program_index[min_src_idx_op];
538
- for (const auto & src_owned_op_call :
539
- source_pattern_graph.owned_op_call ()) {
540
- pir::Operation* src_owned_op =
541
- src_match_ctx.IrOperation (src_owned_op_call.get ());
542
- size_t src_owned_op_idx = op_2_temp_program_index[src_owned_op];
543
- if (min_src_idx > src_owned_op_idx) {
544
- min_src_idx = src_owned_op_idx;
545
- min_src_idx_op = src_owned_op;
546
- }
547
- }
548
- // 2. insert new op at point max(max_input_op_index+1, min_src_idx)
549
- if (min_src_idx > max_input_op_index) {
550
- rewriter.set_insertion_point (min_src_idx_op);
551
- max_input_op_index = op_2_temp_program_index[min_src_idx_op];
552
- } else {
553
- rewriter.SetInsertionPointAfter (max_index_op);
554
- }
555
- VLOG (6 ) << " (" << op_call.name () << " ) insert at idx "
556
- << std::max (max_input_op_index + 1 , min_src_idx);
557
- } else {
558
- if (max_input_op_index == 0UL ) {
559
- VLOG (6 ) << " Not found producer op for (" << op_call.name () << " )" ;
560
- pir::Operation* source_pattern_first_op = src_match_ctx.IrOperation (
561
- source_pattern_graph.owned_op_call ()[0 ].get ());
562
- max_input_op_index = op_2_temp_program_index[source_pattern_first_op];
563
- rewriter.set_insertion_point (source_pattern_first_op);
564
- } else {
565
- rewriter.SetInsertionPointAfter (max_index_op);
531
+ // 2. get source pattern min-idx op
532
+ pir::Operation* min_src_idx_op = src_match_ctx.IrOperation (
533
+ source_pattern_graph.owned_op_call ()[0 ].get ());
534
+ size_t min_src_idx = op_2_temp_program_index[min_src_idx_op];
535
+ for (const auto & src_owned_op_call : source_pattern_graph.owned_op_call ()) {
536
+ pir::Operation* src_owned_op =
537
+ src_match_ctx.IrOperation (src_owned_op_call.get ());
538
+ size_t src_owned_op_idx = op_2_temp_program_index[src_owned_op];
539
+ if (min_src_idx > src_owned_op_idx) {
540
+ min_src_idx = src_owned_op_idx;
541
+ min_src_idx_op = src_owned_op;
566
542
}
567
543
}
568
544
569
- pir::Operation* new_op =
570
- CreateOperation (op_call, src_match_ctx, rewriter, &res_match_ctx);
571
-
572
- size_t new_max_input_op_index = max_input_op_index + 1 ;
573
- op_2_temp_program_index[new_op] = new_max_input_op_index;
574
- if (new_max_input_op_index >= temp_program.size ()) {
575
- temp_program.emplace_back ();
545
+ // 3. insert new op at point max(max_res_idx+1, min_src_idx)
546
+ if (min_src_idx > max_res_idx) {
547
+ rewriter.set_insertion_point (min_src_idx_op);
548
+ } else {
549
+ rewriter.SetInsertionPointAfter (max_res_idx_op);
576
550
}
577
- temp_program[new_max_input_op_index].push_back (new_op);
551
+
552
+ CreateOperation (op_call, src_match_ctx, rewriter, &res_match_ctx);
578
553
});
579
554
580
555
return res_match_ctx;
0 commit comments