Skip to content

Commit 9b805c4

Browse files
authored
[CINN] Fix horizontal fusion with empty loop (#71550) (#71574)
* [CINN] Fix horizontal fusion with empty loop * update
1 parent 599dc84 commit 9b805c4

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

paddle/cinn/operator_fusion/pattern_fuser.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ static bool IsLoopFrameworkEqual(const StmtPattern& lhs,
365365
const auto& rhs_loops = GetLoopFramework(rhs);
366366
VLOG(4) << "lhs " << lhs_loops.DebugStr();
367367
VLOG(4) << "rhs " << rhs_loops.DebugStr();
368+
if (lhs_loops.loop.empty() || rhs_loops.loop.empty()) return false;
368369

369370
// TODO(huangjiyi): support horizontal fusion without reduce dims equal.
370371
const auto get_reduce_loop = [](const MaybeLoopFramework& loop) {

paddle/cinn/operator_fusion/pir_graph_analyzing/loop_axis_mapping.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ LoopAxisMapping CreateLoopAxisMappingForReshape(pir::Operation* op) {
563563
result.output_values.push_back(op->result(0));
564564
result.loop2output.resize(1);
565565
auto in_shape = GetCompatibleValueAllDims(op->operand_source(0));
566-
auto out_shape = GetValueAllDims(op->result(0));
566+
auto out_shape = GetCompatibleValueAllDims(op->result(0));
567567
result.loop = out_shape;
568568

569569
if (!ShapeProductEqual(in_shape, out_shape)) {

0 commit comments

Comments
 (0)