diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.td b/mlir/examples/toy/Ch3/mlir/ToyCombine.td index 8bd2b442d69f2..9501a418ee5ce 100644 --- a/mlir/examples/toy/Ch3/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.td @@ -43,7 +43,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), // Reshape(Constant(x)) = x' def ReshapeConstant : - NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()).getShape())">; def FoldConstantReshapeOptPattern : Pat< (ReshapeOp:$res (ConstantOp $arg)), (ConstantOp (ReshapeConstant $arg, $res))>; diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.td b/mlir/examples/toy/Ch4/mlir/ToyCombine.td index 11d783150ebe1..626bfa32bdae7 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.td @@ -42,7 +42,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), // Reshape(Constant(x)) = x' def ReshapeConstant : - NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()).getShape())">; def FoldConstantReshapeOptPattern : Pat< (ReshapeOp:$res (ConstantOp $arg)), (ConstantOp (ReshapeConstant $arg, $res))>; diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.td b/mlir/examples/toy/Ch5/mlir/ToyCombine.td index 11d783150ebe1..626bfa32bdae7 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.td @@ -42,7 +42,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), // Reshape(Constant(x)) = x' def ReshapeConstant : - NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()).getShape())">; def FoldConstantReshapeOptPattern : Pat< (ReshapeOp:$res (ConstantOp $arg)), (ConstantOp (ReshapeConstant $arg, $res))>; diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.td b/mlir/examples/toy/Ch6/mlir/ToyCombine.td index 11d783150ebe1..626bfa32bdae7 100644 --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.td @@ -42,7 +42,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), // Reshape(Constant(x)) = x' def ReshapeConstant : - NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()).getShape())">; def FoldConstantReshapeOptPattern : Pat< (ReshapeOp:$res (ConstantOp $arg)), (ConstantOp (ReshapeConstant $arg, $res))>; diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.td b/mlir/examples/toy/Ch7/mlir/ToyCombine.td index 11d783150ebe1..626bfa32bdae7 100644 --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.td +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.td @@ -42,7 +42,7 @@ def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)), // Reshape(Constant(x)) = x' def ReshapeConstant : - NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()))">; + NativeCodeCall<"$0.reshape(::llvm::cast($1.getType()).getShape())">; def FoldConstantReshapeOptPattern : Pat< (ReshapeOp:$res (ConstantOp $arg)), (ConstantOp (ReshapeConstant $arg, $res))>; diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h index 704e39e908841..abe40227b31dc 100644 --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -92,7 +92,8 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, // Reshape of a constant can be replaced with a new constant. if (auto elements = dyn_cast_or_null(operands.front())) - return elements.reshape(cast(reshapeOp.getResult().getType())); + return elements.reshape( + cast(reshapeOp.getResult().getType()).getShape()); // Fold if the producer reshape source has the same shape with at most 1 // dynamic dimension. diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index c07ade606a775..ee26537d20e8c 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -615,9 +615,9 @@ class DenseElementsAttr : public Attribute { //===--------------------------------------------------------------------===// /// Return a new DenseElementsAttr that has the same data as the current - /// attribute, but has been reshaped to 'newType'. The new type must have the - /// same total number of elements as well as element type. - DenseElementsAttr reshape(ShapedType newType); + /// attribute, but has been reshaped to 'newShape'. The new shape must have + /// the same total number of elements. + DenseElementsAttr reshape(ArrayRef newShape); /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but with a different shape for a splat type. The new type must diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 8d57ab6b59e79..f81832fcb981e 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -679,8 +679,9 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, MlirType shapedType) { - return wrap(llvm::cast(unwrap(attr)) - .reshape(llvm::cast(unwrap(shapedType)))); + return wrap( + llvm::cast(unwrap(attr)) + .reshape(llvm::cast(unwrap(shapedType)).getShape())); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 434d7df853a5e..dd11f4f2bafda 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -274,7 +274,7 @@ struct ConstantCompositeOpPattern final if (isa(srcType)) { dstAttrType = RankedTensorType::get(srcType.getNumElements(), srcType.getElementType()); - dstElementsAttr = dstElementsAttr.reshape(dstAttrType); + dstElementsAttr = dstElementsAttr.reshape(dstAttrType.getShape()); } else { // TODO: add support for large vectors. return failure(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 5758d8d5ef506..adeecb23528db 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1387,7 +1387,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { return {}; return operand.reshape( - llvm::cast(operand.getType()).clone(shapeVec)); + llvm::cast(operand.getType()).clone(shapeVec).getShape()); } return {}; @@ -1546,7 +1546,7 @@ OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { llvm::dyn_cast_if_present(adaptor.getInput1())) { if (input.isSplat() && resultTy.hasStaticShape() && input.getType().getElementType() == resultTy.getElementType()) - return input.reshape(resultTy); + return input.reshape(resultTy.getShape()); } // Transpose is not the identity transpose. diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp index db7a3c671dedc..9090080534bc7 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp @@ -193,7 +193,7 @@ TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input, RankedTensorType::get(newShape, oldType.getElementType()); if (input.isSplat()) { - return input.reshape(newType); + return input.reshape(newType.getShape()); } auto rawData = input.getRawData(); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 7d615bfc12984..3e8927069cf77 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6000,12 +6000,11 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { // shape_cast(constant) -> constant if (auto splatAttr = llvm::dyn_cast_if_present(adaptor.getSource())) - return splatAttr.reshape(getType()); + return splatAttr.reshape(getType().getShape()); // shape_cast(poison) -> poison - if (llvm::dyn_cast_if_present(adaptor.getSource())) { + if (llvm::dyn_cast_if_present(adaptor.getSource())) return ub::PoisonAttr::get(getContext()); - } return {}; } @@ -6346,7 +6345,8 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { // Eliminate splat constant transpose ops. if (auto splat = llvm::dyn_cast_if_present(adaptor.getVector())) - return splat.reshape(getResultVectorType()); + return DenseElementsAttr::get(getResultVectorType(), + splat.getSplatValue()); // Eliminate poison transpose ops. if (llvm::dyn_cast_if_present(adaptor.getVector())) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index fe17b3c0b2cfc..515dd14081626 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -36,7 +36,7 @@ linearizeConstAttr(Location loc, ConversionPatternRewriter &rewriter, loc, "Cannot linearize a constant scalable vector that's not a splat"); - return dstElementsAttr.reshape(resType); + return dstElementsAttr.reshape(resType.getShape()); } if (auto poisonAttr = dyn_cast(value)) diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index fd898b7493c7f..81b2213dd5a93 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1241,8 +1241,10 @@ ArrayRef DenseElementsAttr::getRawStringData() const { /// Return a new DenseElementsAttr that has the same data as the current /// attribute, but has been reshaped to 'newType'. The new type must have the /// same total number of elements as well as element type. -DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { +DenseElementsAttr DenseElementsAttr::reshape(ArrayRef newShape) { + ShapedType curType = getType(); + auto newType = curType.cloneWith(newShape, curType.getElementType()); if (curType == newType) return *this; diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 1604ebba190a1..5bdb0f8701cbf 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/TypeSwitch.h" +#include using namespace mlir; using namespace mlir::detail; @@ -244,10 +245,61 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) { return VectorType(); } -VectorType VectorType::cloneWith(std::optional> shape, +VectorType VectorType::cloneWith(std::optional> maybeShape, Type elementType) const { - return VectorType::get(shape.value_or(getShape()), elementType, - getScalableDims()); + + // Case where only the element type is modified: + if (!maybeShape.has_value()) + return VectorType::get(getShape(), elementType, getScalableDims()); + + ArrayRef shape = maybeShape.value(); + int64_t rankBefore = getRank(); + int64_t rankAfter = static_cast(shape.size()); + + // In the case where the rank is unchanged, the positions of the scalable + // dimensions are retained. + // Example: vector<4x[1]xf32> -> vector<1x[4]xi8> + if (rankBefore == rankAfter) + return VectorType::get(shape, elementType, getScalableDims()); + + // In the case where the rank increases, retain the scalable dimension + // position relative to front (outermost dimension). + // Example: vector<4x[1]xf32> -> vector<1x[2]x2x1xi8> + if (rankBefore < rankAfter) { + SmallVector newScalableDims(rankAfter, false); + std::copy(getScalableDims().begin(), getScalableDims().end(), + newScalableDims.begin() + (rankAfter - rankBefore)); + return VectorType::get(shape, elementType, newScalableDims); + } + + // In the case where the rank decreases, retain the first `rankAfter` scalable + // dimensions. Any scalable dimensions in the final `rankBefore - rankAfter` + // dimensions are packed into gaps, if possible. + // + // Examples: + // + // vector<4x[1]xf32> -> vector<[4]xi8> + // vector<[4]x1xf32> -> vector<[4]xi8> + // vector<[2]x3x[4]x5xf32> -> vector<[6]x[20]xi8> + // + // If the number of scalable dimensions excedes the number of dimensions in + // the new shape, there is an assertion failure. + assert(rankAfter < rankBefore); + SmallVector newScalableDims(getScalableDims().take_front(rankAfter)); + int nScalablesToRelocate = + llvm::count_if(getScalableDims().take_back(rankBefore - rankAfter), + [](bool b) { return b; }); + int currentIndex = newScalableDims.size() - 1; + while (nScalablesToRelocate > 0 && currentIndex >= 0) { + if (!newScalableDims[currentIndex]) { + newScalableDims[currentIndex] = true; + --nScalablesToRelocate; + } + } + + assert(nScalablesToRelocate == 0 && + "too many scalable dimensions for new (lower) rank"); + return VectorType::get(shape, elementType, newScalableDims); } //===----------------------------------------------------------------------===//