Skip to content

Commit bbc4157

Browse files
MaheshRavishankar authored and memfrob committed
[mlir][Linalg] Fix element type of results when folding reshapes.
Fixing a minor bug which led to the element type of the output being modified when folding reshapes with a generic op. Differential Revision: https://reviews.llvm.org/D101942
1 parent 42088b2 commit bbc4157

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,9 +1129,12 @@ struct PushExpandingReshape : public OpRewritePattern<GenericOpTy> {
11291129
SmallVector<Value> newOutputs;
11301130
SmallVector<Type> newOutputTypes;
11311131
for (auto output : op.outputs()) {
1132+
auto newOutputType = RankedTensorType::get(
1133+
reshapeFound.getSrcType().getShape(),
1134+
output.getType().template cast<RankedTensorType>().getElementType());
11321135
Value newOutput = rewriter.create<TensorReshapeOp>(
1133-
op->getLoc(), reshapeFound.getSrcType(), output, reassociation);
1134-
newOutputTypes.push_back(newOutput.getType());
1136+
op->getLoc(), newOutputType, output, reassociation);
1137+
newOutputTypes.push_back(newOutputType);
11351138
newOutputs.push_back(newOutput);
11361139
}
11371140
// 5. Create a new generic op with lowerer rank.

mlir/test/Dialect/Linalg/fusion-push-reshape.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,40 @@ func @reshape_negative(%A: tensor<12544x16xf32>, %B: tensor<112xf32>) -> tensor<
8888
} -> tensor<112x112x16xf32>
8989
return %22 : tensor<112x112x16xf32>
9090
}
91+
92+
// -----
93+
94+
func @type_correctness(%arg0 : tensor<6x5xi32>, %arg1 : tensor<5xf32>,
95+
%arg2 : tensor<5xf32>) -> tensor<2x3x5xf32> {
96+
%cst_6 = constant 1.000000e+00 : f32
97+
%cst_7 = constant 7.000000e+00 : f32
98+
%cst_8 = constant 1.1920929E-7 : f32
99+
%25 = linalg.tensor_reshape %arg0
100+
[affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
101+
: tensor<6x5xi32> into tensor<2x3x5xi32>
102+
%26 = linalg.init_tensor [2, 3, 5] : tensor<2x3x5xf32>
103+
%28 = linalg.generic {
104+
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
105+
affine_map<(d0, d1, d2) -> (d2)>,
106+
affine_map<(d0, d1, d2) -> (d2)>,
107+
affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
108+
iterator_types = ["parallel", "parallel", "parallel"]}
109+
ins(%25, %arg1, %arg2 : tensor<2x3x5xi32>, tensor<5xf32>, tensor<5xf32>)
110+
outs(%26 : tensor<2x3x5xf32>) {
111+
^bb0(%arg6: i32, %arg7: f32, %arg8: f32, %arg9: f32): // no predecessors
112+
%29 = sitofp %arg6 : i32 to f32
113+
%30 = addf %arg7, %cst_8 : f32
114+
%31 = divf %cst_7, %30 : f32
115+
%32 = divf %cst_6, %31 : f32
116+
%33 = mulf %29, %32 : f32
117+
%34 = addf %33, %arg8 : f32
118+
linalg.yield %34 : f32
119+
} -> tensor<2x3x5xf32>
120+
return %28 : tensor<2x3x5xf32>
121+
}
122+
// CHECK-LABEL: func @type_correctness
123+
// CHECK: %[[OP:.+]] = linalg.generic
124+
// CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<6x5xi32>, tensor<5xf32>, tensor<5xf32>)
125+
// CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>)
126+
// CHECK: linalg.tensor_reshape %[[OP]]
127+
// CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32>

0 commit comments

Comments
 (0)