diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985..ebb88bf695d4c 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -16,24 +16,6 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static llvm::DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
-  for (const auto &map : enumerate(reassociation)) {
-    unsigned startPos =
-        cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
-    unsigned endPos =
-        cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
-    for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
-      expandedDimToCollapsedDim[dim] = map.index();
-    }
-  }
-  return expandedDimToCollapsedDim;
-}
-
 /// For reshape op compute the shape at dimension `dimIndex` of the output in
 /// terms of shape of the `src`, when the reshape op is a collapsing
 /// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,86 +58,33 @@ static SmallVector<OpFoldResult> getCollapsedOutputShapeFromInputShape(
       }));
 }
 
-/// For an expanding reshape op, compute the value for a dimension of the output
-/// from the shape of the input.
-static OpFoldResult getExpandedOutputDimFromInputShape(
-    OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
-    llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
-  if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
-    // Static dimension: return Attribute.
-    return builder.getIndexAttr(dstStaticShape[dimIndex]);
-  }
-  unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
-  unsigned startPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
-          .getPosition();
-  unsigned endPos =
-      cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
-          .getPosition();
-  int64_t linearizedStaticDim = 1;
-  for (auto d :
-       llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
-    if (d.index() + startPos == static_cast<unsigned>(dimIndex))
-      continue;
-    assert(!ShapedType::isDynamic(d.value()) &&
-           "single dimension cannot be expanded into multiple dynamic "
-           "dimensions");
-    linearizedStaticDim *= d.value();
+struct ReifyCollapseShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
+          ReifyCollapseShapeOp, CollapseShapeOp> {
+  LogicalResult
+  reifyResultShapes(Operation *op, OpBuilder &b,
+                    ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
+    auto loc = op->getLoc();
+    auto collapseShape = cast<CollapseShapeOp>(op);
+    reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
+        b, loc, collapseShape.getSrc(),
+        collapseShape.getResultType().getShape(),
+        collapseShape.getReassociationMaps()));
+    return success();
   }
-  OpFoldResult sourceDim =
-      builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
-
-  // Dynamic dimension: return Value.
-  return affine::makeComposedAffineApply(
-             builder, loc,
-             AffineMap::get(
-                 0, 1,
-                 builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
-             sourceDim)
-      ->getResult(0);
-}
-
-/// Given the `src` of an expanding reshape op, the reassociation maps and the
-/// result type, compute the shape of the result of the reshape.
-static SmallVector<OpFoldResult> getExpandedOutputShapeFromInputShape(
-    OpBuilder &builder, Location loc, Value src,
-    ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
-  llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
-      getExpandedDimToCollapsedDimMap(reassociation);
-  return llvm::to_vector<4>(llvm::map_range(
-      llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
-        return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
-                                                  dstStaticShape, reassociation,
-                                                  expandedDimToCollapsedDim);
-      }));
-}
-
-static SmallVector<OpFoldResult>
-getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
-                                    ArrayRef<int64_t> dstStaticShape,
-                                    ArrayRef<AffineMap> reassocation) {
-  return dstStaticShape.size() >
-                 static_cast<size_t>(
-                     llvm::cast<ShapedType>(src.getType()).getRank())
-             ? getExpandedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation)
-             : getCollapsedOutputShapeFromInputShape(
-                   builder, loc, src, dstStaticShape, reassocation);
-}
+};
 
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
-    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
-          ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+struct ReifyExpandShapeOp
+    : public ReifyRankedShapedTypeOpInterface::ExternalModel<
+          ReifyExpandShapeOp, ExpandShapeOp> {
   LogicalResult
   reifyResultShapes(Operation *op, OpBuilder &b,
                     ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
     auto loc = op->getLoc();
-    auto reshapeOp = cast<OpTy>(op);
-    reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
-        b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
-        reshapeOp.getReassociationMaps()));
+    auto expandShape = cast<ExpandShapeOp>(op);
+    SmallVector<OpFoldResult> outputShape = getMixedValues(
+        expandShape.getStaticOutputShape(), expandShape.getOutputShape(), b);
+    reifiedReturnShapes.push_back(outputShape);
     return success();
   }
 };
@@ -202,10 +131,8 @@ struct ReifyPadOp
 void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
-    ExpandShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<ExpandShapeOp>>(*ctx);
-    CollapseShapeOp::attachInterface<
-        ReifyExpandOrCollapseShapeOp<CollapseShapeOp>>(*ctx);
+    ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+    CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
     PadOp::attachInterface<ReifyPadOp>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613..0595ac2492c97 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
   %3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
   return %1, %2, %3 : index, index, index
 }
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: func @dim_reshape_expansion
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
 // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
 // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
 
 // -----
 
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df..d3889f23e7d74 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
 // CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
 
 func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,8 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
   return %1 : tensor<2x3x5x4x?x7xf32>
 }
 // CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
 // CHECK-NEXT: return %[[INIT]]
 
 func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {