Skip to content

Commit d3309f6

Browse files
authored
[DRR] Fix drr rewrite order when result pattern op be inserted (#64784)
* fix drr rewrite insert op order * fix
1 parent 9998434 commit d3309f6

File tree

1 file changed

+28
-53
lines changed

1 file changed

+28
-53
lines changed

paddle/fluid/pir/drr/src/rewrite_pattern.cc

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,6 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
484484
}
485485
}
486486

487-
bool is_one_result = result_pattern_graph.owned_op_call().size() == 1;
488487
// topo order visit result_pattern_graph
489488
GraphTopo graph_topo_visit(&result_pattern_graph);
490489
graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) {
@@ -496,8 +495,9 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
496495
}
497496

498497
// set insert point
499-
size_t max_input_op_index = 0UL;
500-
pir::Operation* max_index_op = nullptr;
498+
// 1. get result pattern max-idx of input op
499+
size_t max_res_idx = 0UL;
500+
pir::Operation* max_res_idx_op = nullptr;
501501
for (const Tensor* input : op_call.inputs()) {
502502
if (input->is_none()) {
503503
continue;
@@ -507,18 +507,16 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
507507
pir::Operation* ir_input_op = ir_val.defining_op();
508508
if (op_2_temp_program_index.count(ir_input_op) == 0) {
509509
// do nothing
510-
} else if (max_input_op_index <
511-
op_2_temp_program_index.at(ir_input_op)) {
512-
max_input_op_index = op_2_temp_program_index.at(ir_input_op);
513-
max_index_op = ir_input_op;
514-
} else if (max_input_op_index ==
515-
op_2_temp_program_index.at(ir_input_op)) {
516-
const auto& ops_vec = temp_program[max_input_op_index];
510+
} else if (max_res_idx < op_2_temp_program_index.at(ir_input_op)) {
511+
max_res_idx = op_2_temp_program_index.at(ir_input_op);
512+
max_res_idx_op = ir_input_op;
513+
} else if (max_res_idx == op_2_temp_program_index.at(ir_input_op)) {
514+
const auto& ops_vec = temp_program[max_res_idx];
517515
for (auto it = ops_vec.begin(); it != ops_vec.end(); it++) {
518-
if (*it == max_index_op) {
516+
if (*it == max_res_idx_op) {
519517
break;
520518
} else if (*it == ir_input_op) {
521-
max_index_op = ir_input_op;
519+
max_res_idx_op = ir_input_op;
522520
break;
523521
} else {
524522
// do nothing
@@ -530,51 +528,28 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
530528
}
531529
}
532530

533-
if (is_one_result && !source_pattern_graph.owned_op_call().empty()) {
534-
// 1. get source pattern min-idx op
535-
pir::Operation* min_src_idx_op = src_match_ctx.IrOperation(
536-
source_pattern_graph.owned_op_call()[0].get());
537-
size_t min_src_idx = op_2_temp_program_index[min_src_idx_op];
538-
for (const auto& src_owned_op_call :
539-
source_pattern_graph.owned_op_call()) {
540-
pir::Operation* src_owned_op =
541-
src_match_ctx.IrOperation(src_owned_op_call.get());
542-
size_t src_owned_op_idx = op_2_temp_program_index[src_owned_op];
543-
if (min_src_idx > src_owned_op_idx) {
544-
min_src_idx = src_owned_op_idx;
545-
min_src_idx_op = src_owned_op;
546-
}
547-
}
548-
// 2. insert new op at point max(max_input_op_index+1, min_src_idx)
549-
if (min_src_idx > max_input_op_index) {
550-
rewriter.set_insertion_point(min_src_idx_op);
551-
max_input_op_index = op_2_temp_program_index[min_src_idx_op];
552-
} else {
553-
rewriter.SetInsertionPointAfter(max_index_op);
554-
}
555-
VLOG(6) << "(" << op_call.name() << ") insert at idx "
556-
<< std::max(max_input_op_index + 1, min_src_idx);
557-
} else {
558-
if (max_input_op_index == 0UL) {
559-
VLOG(6) << "Not found producer op for (" << op_call.name() << ")";
560-
pir::Operation* source_pattern_first_op = src_match_ctx.IrOperation(
561-
source_pattern_graph.owned_op_call()[0].get());
562-
max_input_op_index = op_2_temp_program_index[source_pattern_first_op];
563-
rewriter.set_insertion_point(source_pattern_first_op);
564-
} else {
565-
rewriter.SetInsertionPointAfter(max_index_op);
531+
// 2. get source pattern min-idx op
532+
pir::Operation* min_src_idx_op = src_match_ctx.IrOperation(
533+
source_pattern_graph.owned_op_call()[0].get());
534+
size_t min_src_idx = op_2_temp_program_index[min_src_idx_op];
535+
for (const auto& src_owned_op_call : source_pattern_graph.owned_op_call()) {
536+
pir::Operation* src_owned_op =
537+
src_match_ctx.IrOperation(src_owned_op_call.get());
538+
size_t src_owned_op_idx = op_2_temp_program_index[src_owned_op];
539+
if (min_src_idx > src_owned_op_idx) {
540+
min_src_idx = src_owned_op_idx;
541+
min_src_idx_op = src_owned_op;
566542
}
567543
}
568544

569-
pir::Operation* new_op =
570-
CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
571-
572-
size_t new_max_input_op_index = max_input_op_index + 1;
573-
op_2_temp_program_index[new_op] = new_max_input_op_index;
574-
if (new_max_input_op_index >= temp_program.size()) {
575-
temp_program.emplace_back();
545+
// 3. insert new op at point max(max_res_idx+1, min_src_idx)
546+
if (min_src_idx > max_res_idx) {
547+
rewriter.set_insertion_point(min_src_idx_op);
548+
} else {
549+
rewriter.SetInsertionPointAfter(max_res_idx_op);
576550
}
577-
temp_program[new_max_input_op_index].push_back(new_op);
551+
552+
CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
578553
});
579554

580555
return res_match_ctx;

0 commit comments

Comments
 (0)