Skip to content

Commit 670a68e

Browse files
authored
[mlir][tensor] Preserve encoding in CollapseShapeOp::build (llvm#173720)
This PR updates `CollapseShapeOp::build` so that when the result type is not explicitly provided, the inferred result type preserves the encoding of the source tensor.
1 parent a46cb15 commit 670a68e

File tree

4 files changed: +16 additions, -17 deletions

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1253,7 +1253,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
12531253
inferCollapsedType(RankedTensorType type, ArrayRef<AffineMap> reassociation);
12541254
static RankedTensorType
12551255
inferCollapsedType(RankedTensorType type,
1256-
SmallVector<ReassociationIndices> reassociation);
1256+
ArrayRef<ReassociationIndices> reassociation);
12571257
}];
12581258
let hasVerifier = 1;
12591259
}

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,11 +1900,9 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims(
19001900
applyPermutationMap(indexingMap, ArrayRef(loopBound));
19011901
Value result;
19021902
if (isa<MemRefType>(collapsedOpResult.getType())) {
1903-
MemRefType expandShapeResultType = MemRefType::get(
1904-
originalResultType.getShape(), originalResultType.getElementType());
19051903
result = memref::ExpandShapeOp::create(
1906-
rewriter, loc, expandShapeResultType, collapsedOpResult,
1907-
reassociation, resultShape);
1904+
rewriter, loc, originalResultType, collapsedOpResult, reassociation,
1905+
resultShape);
19081906
} else {
19091907
result = tensor::ExpandShapeOp::create(
19101908
rewriter, loc, originalResultType, collapsedOpResult, reassociation,

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,7 +1985,7 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
19851985
}
19861986

19871987
RankedTensorType CollapseShapeOp::inferCollapsedType(
1988-
RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
1988+
RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
19891989
return inferCollapsedType(
19901990
type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
19911991
type.getContext(), reassociation)));
@@ -2023,10 +2023,11 @@ CollapseShapeOp::inferCollapsedType(RankedTensorType type,
20232023
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
20242024
ArrayRef<ReassociationIndices> reassociation,
20252025
ArrayRef<NamedAttribute> attrs) {
2026-
auto resultType = inferCollapsedType(
2027-
llvm::cast<RankedTensorType>(src.getType()),
2028-
getSymbolLessAffineMaps(
2029-
convertReassociationIndicesToExprs(b.getContext(), reassociation)));
2026+
auto srcType = llvm::cast<RankedTensorType>(src.getType());
2027+
RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
2028+
auto resultType =
2029+
RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
2030+
srcType.getEncoding());
20302031
result.addAttribute(getReassociationAttrStrName(),
20312032
getReassociationIndicesAttribute(b, reassociation));
20322033
build(b, result, resultType, src, attrs);

mlir/test/Dialect/Linalg/collapse-dim.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,13 @@ func.func @uncollapsable_memref_projected_ops(%arg0: memref<1x24x32x8xf32>, %arg
168168
// CHECK-LABEL: func.func @linalg_copy(
169169
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x3x4x5xf32, 1 : i64>,
170170
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x2x3x4x5xf32, 3 : i64>) -> tensor<1x2x3x4x5xf32, 3 : i64> {
171-
// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32>
172-
// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32>
173-
// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
174-
// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32> into tensor<1x2x60xf32>
175-
// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32>) outs(%[[VAL_5]] : tensor<1x2x60xf32>) -> tensor<1x2x60xf32>
176-
// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32> into tensor<1x2x12x5xf32>
177-
// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32> into tensor<1x2x3x4x5xf32, 3 : i64>
171+
// CHECK: %[[VAL_2:.*]] = tensor.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 1 : i64> into tensor<1x2x12x5xf32, 1 : i64>
172+
// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3], [4]] : tensor<1x2x3x4x5xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
173+
// CHECK: %[[VAL_4:.*]] = tensor.collapse_shape %[[VAL_2]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 1 : i64> into tensor<1x2x60xf32, 1 : i64>
174+
// CHECK: %[[VAL_5:.*]] = tensor.collapse_shape %[[VAL_3]] {{\[\[}}0], [1], [2, 3]] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x60xf32, 3 : i64>
175+
// CHECK: %[[VAL_6:.*]] = linalg.copy ins(%[[VAL_4]] : tensor<1x2x60xf32, 1 : i64>) outs(%[[VAL_5]] : tensor<1x2x60xf32, 3 : i64>) -> tensor<1x2x60xf32, 3 : i64>
176+
// CHECK: %[[VAL_7:.*]] = tensor.expand_shape %[[VAL_6]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 2, 12, 5] : tensor<1x2x60xf32, 3 : i64> into tensor<1x2x12x5xf32, 3 : i64>
177+
// CHECK: %[[VAL_8:.*]] = tensor.expand_shape %[[VAL_7]] {{\[\[}}0], [1], [2, 3], [4]] output_shape [1, 2, 3, 4, 5] : tensor<1x2x12x5xf32, 3 : i64> into tensor<1x2x3x4x5xf32, 3 : i64>
178178
// CHECK: return %[[VAL_8]] : tensor<1x2x3x4x5xf32, 3 : i64>
179179
// CHECK: }
180180

0 commit comments

Comments (0)