Skip to content

Commit 69cbbb5

Browse files
committed
fix: fuse locations of double reshapes when folding.
1 parent ad4697c commit 69cbbb5

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,6 +1107,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
11071107
if (auto reshapeOp = llvm::dyn_cast_if_present<tosa::ReshapeOp>(
11081108
getInput1().getDefiningOp())) {
11091109
getInput1Mutable().assign(reshapeOp.getInput1());
1110+
1111+
// Fuse locations so that first ReshapeOp location isn't lost.
1112+
getResult().getDefiningOp()->setLoc(
1113+
mlir::FusedLoc::get(getContext(), {reshapeOp->getLoc(), getLoc()}));
11101114
return getResult();
11111115
}
11121116

mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,45 @@ func.func @canonicalize_optimize_sqrt_reciprocal_bf16(%arg0: tensor<1x5x1x1xbf16
4343
return %2 : tensor<1x5x1x1xbf16>
4444
}
4545
#loc0 = loc("Pow_B")
46-
#loc1 = loc("Reciprocal_C")
46+
#loc1 = loc("Reciprocal_C")
47+
48+
// -----
49+
50+
// CHECK-LABEL: @reshape_canonicalize_double
51+
func.func @reshape_canonicalize_double(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
52+
// CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 5>} {{.*}} loc([[LOC:.*]])
53+
// CHECK: return %[[VAL_1]]
54+
%0 = tosa.reshape %arg0 {new_shape = array<i64: 5, -1>}: (tensor<?x10xf32>) -> tensor<5x?xf32> loc(#loc0)
55+
%1 = tosa.reshape %0 {new_shape = array<i64: -1, 5>}: (tensor<5x?xf32>) -> tensor<?x5xf32> loc(#loc1)
56+
return %1 : tensor<?x5xf32>
57+
}
58+
#loc0 = loc("reshape1")
59+
#loc1 = loc("reshape2")
60+
61+
// CHECK-DAG: #[[A:.*]] = loc("reshape1")
62+
// CHECK-DAG: #[[B:.*]] = loc("reshape2")
63+
// CHECK-DAG: [[LOC]] = loc(fused[#[[A]], #[[B]]])
64+
65+
// -----
66+
67+
// CHECK-LABEL: @reshape_canonicalize_double_fused_locs
68+
func.func @reshape_canonicalize_double_fused_locs(%arg0: tensor<?x10xf32>) -> tensor<?x5xf32> {
69+
// CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array<i64: -1, 5>} {{.*}} loc([[LOC:.*]])
70+
// CHECK: return %[[VAL_1]]
71+
%0 = tosa.reshape %arg0 {new_shape = array<i64: 5, -1>}: (tensor<?x10xf32>) -> tensor<5x?xf32> loc(#fused_loc0)
72+
%1 = tosa.reshape %0 {new_shape = array<i64: -1, 5>}: (tensor<5x?xf32>) -> tensor<?x5xf32> loc(#fused_loc1)
73+
return %1 : tensor<?x5xf32>
74+
}
75+
#loc0 = loc("reshape1_1")
76+
#loc1 = loc("reshape1_2")
77+
#loc2 = loc("reshape2_1")
78+
#loc3 = loc("reshape2_2")
79+
80+
// CHECK-DAG: #[[A:.*]] = loc("reshape1_1")
81+
// CHECK-DAG: #[[B:.*]] = loc("reshape1_2")
82+
// CHECK-DAG: #[[C:.*]] = loc("reshape2_1")
83+
// CHECK-DAG: #[[D:.*]] = loc("reshape2_2")
84+
// CHECK-DAG: [[LOC]] = loc(fused[#[[A]], #[[B]], #[[C]], #[[D]]])
85+
86+
#fused_loc0 = loc(fused[#loc0, #loc1])
87+
#fused_loc1 = loc(fused[#loc2, #loc3])

0 commit comments

Comments
 (0)