From 2e027d809cdcad0df674afd6cfae4a053d58359c Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 14 Apr 2025 13:39:40 -0700 Subject: [PATCH 1/4] move canonicalizers to folders --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 161 ++++++++------------- mlir/test/Dialect/Vector/canonicalize.mlir | 7 + 2 files changed, 66 insertions(+), 102 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bee5c1fd6ed58..18dbd1167995e 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3717,6 +3717,59 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { return getVector(); if (succeeded(foldExtractStridedOpFromInsertChain(*this))) return getResult(); + + Attribute foldInput = adaptor.getVector(); + if (!foldInput) { + return {}; + } + + // rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. + if (auto splat = llvm::dyn_cast(foldInput)) + DenseElementsAttr::get(getType(), splat.getSplatValue()); + + // rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp. + if (auto dense = llvm::dyn_cast(foldInput)) { + // TODO: Handle non-unit strides when they become available. + if (hasNonUnitStrides()) + return {}; + + Value sourceVector = getVector(); + auto sourceVecTy = llvm::cast(sourceVector.getType()); + ArrayRef sourceShape = sourceVecTy.getShape(); + SmallVector sourceStrides = computeStrides(sourceShape); + + VectorType sliceVecTy = getType(); + ArrayRef sliceShape = sliceVecTy.getShape(); + int64_t sliceRank = sliceVecTy.getRank(); + + // Expand offsets and sizes to match the vector rank. + SmallVector offsets(sliceRank, 0); + copy(getI64SubArray(getOffsets()), offsets.begin()); + + SmallVector sizes(sourceShape); + copy(getI64SubArray(getSizes()), sizes.begin()); + + // Calculate the slice elements by enumerating all slice positions and + // linearizing them. The enumeration order is lexicographic which yields a + // sequence of monotonically increasing linearized position indices. + auto denseValuesBegin = dense.value_begin(); + SmallVector sliceValues; + sliceValues.reserve(sliceVecTy.getNumElements()); + SmallVector currSlicePosition(offsets.begin(), offsets.end()); + do { + int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides); + assert(linearizedPosition < sourceVecTy.getNumElements() && + "Invalid index"); + sliceValues.push_back(*(denseValuesBegin + linearizedPosition)); + } while ( + succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets))); + + assert(static_cast(sliceValues.size()) == + sliceVecTy.getNumElements() && + "Invalid number of slice elements"); + return DenseElementsAttr::get(sliceVecTy, sliceValues); + } + return {}; } @@ -3781,98 +3834,6 @@ class StridedSliceConstantMaskFolder final } }; -// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. -class StridedSliceSplatConstantFolder final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, - PatternRewriter &rewriter) const override { - // Return if 'ExtractStridedSliceOp' operand is not defined by a splat - // ConstantOp. - Value sourceVector = extractStridedSliceOp.getVector(); - Attribute vectorCst; - if (!matchPattern(sourceVector, m_Constant(&vectorCst))) - return failure(); - - auto splat = llvm::dyn_cast(vectorCst); - if (!splat) - return failure(); - - auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(), - splat.getSplatValue()); - rewriter.replaceOpWithNewOp(extractStridedSliceOp, - newAttr); - return success(); - } -}; - -// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) -> -// ConstantOp. -class StridedSliceNonSplatConstantFolder final - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, - PatternRewriter &rewriter) const override { - // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat - // ConstantOp. - Value sourceVector = extractStridedSliceOp.getVector(); - Attribute vectorCst; - if (!matchPattern(sourceVector, m_Constant(&vectorCst))) - return failure(); - - // The splat case is handled by `StridedSliceSplatConstantFolder`. - auto dense = llvm::dyn_cast(vectorCst); - if (!dense || dense.isSplat()) - return failure(); - - // TODO: Handle non-unit strides when they become available. - if (extractStridedSliceOp.hasNonUnitStrides()) - return failure(); - - auto sourceVecTy = llvm::cast(sourceVector.getType()); - ArrayRef sourceShape = sourceVecTy.getShape(); - SmallVector sourceStrides = computeStrides(sourceShape); - - VectorType sliceVecTy = extractStridedSliceOp.getType(); - ArrayRef sliceShape = sliceVecTy.getShape(); - int64_t sliceRank = sliceVecTy.getRank(); - - // Expand offsets and sizes to match the vector rank. - SmallVector offsets(sliceRank, 0); - copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin()); - - SmallVector sizes(sourceShape); - copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin()); - - // Calculate the slice elements by enumerating all slice positions and - // linearizing them. The enumeration order is lexicographic which yields a - // sequence of monotonically increasing linearized position indices. - auto denseValuesBegin = dense.value_begin(); - SmallVector sliceValues; - sliceValues.reserve(sliceVecTy.getNumElements()); - SmallVector currSlicePosition(offsets.begin(), offsets.end()); - do { - int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides); - assert(linearizedPosition < sourceVecTy.getNumElements() && - "Invalid index"); - sliceValues.push_back(*(denseValuesBegin + linearizedPosition)); - } while ( - succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets))); - - assert(static_cast(sliceValues.size()) == - sliceVecTy.getNumElements() && - "Invalid number of slice elements"); - auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues); - rewriter.replaceOpWithNewOp(extractStridedSliceOp, - newAttr); - return success(); - } -}; - // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to // BroadcastOp(ExtractStrideSliceOp). class StridedSliceBroadcast final @@ -4016,8 +3977,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp. - results.add( context); } @@ -5657,10 +5617,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { // shape_cast(constant) -> constant if (auto splatAttr = - llvm::dyn_cast_if_present(adaptor.getSource())) { - return DenseElementsAttr::get(resultType, - splatAttr.getSplatValue()); - } + llvm::dyn_cast_if_present(adaptor.getSource())) + return splatAttr.reshape(getType()); // shape_cast(poison) -> poison if (llvm::dyn_cast_if_present(adaptor.getSource())) { @@ -6004,10 +5962,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { // Eliminate splat constant transpose ops. - if (auto attr = - llvm::dyn_cast_if_present(adaptor.getVector())) - if (attr.isSplat()) - return attr.reshape(getResultVectorType()); + if (auto splat = + llvm::dyn_cast_if_present(adaptor.getVector())) + return splat.reshape(getResultVectorType()); // Eliminate identity transpose ops. This happens when the dimensions of the // input vector remain in their original order after the transpose operation. diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 78b0ea78849e8..6556df22e069b 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1121,6 +1121,8 @@ func.func @bitcast_folding(%I1: vector<4x8xf32>, %I2: vector<2xi32>) -> (vector< return %0, %2 : vector<4x8xf32>, vector<2xi32> } +// ----- + // CHECK-LABEL: func @bitcast_f16_to_f32 // bit pattern: 0x40004000 // CHECK-DAG: %[[CST1:.+]] = arith.constant dense<2.00390625> : vector<4xf32> @@ -1135,6 +1137,8 @@ func.func @bitcast_f16_to_f32() -> (vector<4xf32>, vector<4xf32>) { return %cast0, %cast1: vector<4xf32>, vector<4xf32> } +// ----- + // CHECK-LABEL: func @bitcast_i8_to_i32 // bit pattern: 0xA0A0A0A0 // CHECK-DAG: %[[CST1:.+]] = arith.constant dense<-1600085856> : vector<4xi32> @@ -1710,6 +1714,7 @@ func.func @vector_multi_reduction_unit_dimensions(%source: vector<5x1x4x1x20xf32 } // ----- + // CHECK-LABEL: func.func @vector_multi_reduction_scalable( // CHECK-SAME: %[[VAL_0:.*]]: vector<1x[4]x1xf32>, // CHECK-SAME: %[[VAL_1:.*]]: vector<1x[4]xf32>, @@ -2251,6 +2256,8 @@ func.func @transpose_splat_constant() -> vector<8x4xf32> { return %0 : vector<8x4xf32> } +// ----- + // CHECK-LABEL: func @transpose_splat2( // CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> { // CHECK: %[[VAL_1:.*]] = vector.splat %[[VAL_0]] : vector<3x4xf32> From c0cebd2ecad27414d72f619e932c18c2511cba3e Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 15 Apr 2025 00:26:26 -0700 Subject: [PATCH 2/4] tidy --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 18dbd1167995e..1eb6daa402422 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3718,32 +3718,31 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { if (succeeded(foldExtractStridedOpFromInsertChain(*this))) return getResult(); + // All subsequent successful folds require a constant input. Attribute foldInput = adaptor.getVector(); - if (!foldInput) { + if (!foldInput) return {}; - } - // rewrite : ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. + // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. if (auto splat = llvm::dyn_cast(foldInput)) DenseElementsAttr::get(getType(), splat.getSplatValue()); - // rewrite ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp. + // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp. if (auto dense = llvm::dyn_cast(foldInput)) { // TODO: Handle non-unit strides when they become available. if (hasNonUnitStrides()) return {}; - Value sourceVector = getVector(); - auto sourceVecTy = llvm::cast(sourceVector.getType()); + VectorType sourceVecTy = getSourceVectorType(); ArrayRef sourceShape = sourceVecTy.getShape(); SmallVector sourceStrides = computeStrides(sourceShape); VectorType sliceVecTy = getType(); ArrayRef sliceShape = sliceVecTy.getShape(); - int64_t sliceRank = sliceVecTy.getRank(); + int64_t rank = sliceVecTy.getRank(); // Expand offsets and sizes to match the vector rank. - SmallVector offsets(sliceRank, 0); + SmallVector offsets(rank, 0); copy(getI64SubArray(getOffsets()), offsets.begin()); SmallVector sizes(sourceShape); @@ -3752,7 +3751,7 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { // Calculate the slice elements by enumerating all slice positions and // linearizing them. The enumeration order is lexicographic which yields a // sequence of monotonically increasing linearized position indices. - auto denseValuesBegin = dense.value_begin(); + const auto denseValuesBegin = dense.value_begin(); SmallVector sliceValues; sliceValues.reserve(sliceVecTy.getNumElements()); SmallVector currSlicePosition(offsets.begin(), offsets.end()); From 788f44cf917a4b17bcaa8cacaf7fef1786083bf6 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 15 Apr 2025 16:35:34 -0700 Subject: [PATCH 3/4] factorize to function --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 101 ++++++++++++----------- 1 file changed, 53 insertions(+), 48 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 1eb6daa402422..29d9d1f0f50ae 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3712,64 +3712,69 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { return failure(); } +namespace { + +// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp. +OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, + Attribute foldInput) { + + auto dense = llvm::dyn_cast_if_present(foldInput); + if (!dense) + return {}; + + // TODO: Handle non-unit strides when they become available. + if (op.hasNonUnitStrides()) + return {}; + + VectorType sourceVecTy = op.getSourceVectorType(); + ArrayRef sourceShape = sourceVecTy.getShape(); + SmallVector sourceStrides = computeStrides(sourceShape); + + VectorType sliceVecTy = op.getType(); + ArrayRef sliceShape = sliceVecTy.getShape(); + int64_t rank = sliceVecTy.getRank(); + + // Expand offsets and sizes to match the vector rank. + SmallVector offsets(rank, 0); + copy(getI64SubArray(op.getOffsets()), offsets.begin()); + + SmallVector sizes(sourceShape); + copy(getI64SubArray(op.getSizes()), sizes.begin()); + + // Calculate the slice elements by enumerating all slice positions and + // linearizing them. The enumeration order is lexicographic which yields a + // sequence of monotonically increasing linearized position indices. + const auto denseValuesBegin = dense.value_begin(); + SmallVector sliceValues; + sliceValues.reserve(sliceVecTy.getNumElements()); + SmallVector currSlicePosition(offsets.begin(), offsets.end()); + do { + int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides); + assert(linearizedPosition < sourceVecTy.getNumElements() && + "Invalid index"); + sliceValues.push_back(*(denseValuesBegin + linearizedPosition)); + } while (succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets))); + + assert(static_cast(sliceValues.size()) == + sliceVecTy.getNumElements() && + "Invalid number of slice elements"); + return DenseElementsAttr::get(sliceVecTy, sliceValues); +} +} // namespace + OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { if (getSourceVectorType() == getResult().getType()) return getVector(); if (succeeded(foldExtractStridedOpFromInsertChain(*this))) return getResult(); - // All subsequent successful folds require a constant input. - Attribute foldInput = adaptor.getVector(); - if (!foldInput) - return {}; - // ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. - if (auto splat = llvm::dyn_cast(foldInput)) + if (auto splat = + llvm::dyn_cast_if_present(adaptor.getVector())) DenseElementsAttr::get(getType(), splat.getSplatValue()); // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp. - if (auto dense = llvm::dyn_cast(foldInput)) { - // TODO: Handle non-unit strides when they become available. - if (hasNonUnitStrides()) - return {}; - - VectorType sourceVecTy = getSourceVectorType(); - ArrayRef sourceShape = sourceVecTy.getShape(); - SmallVector sourceStrides = computeStrides(sourceShape); - - VectorType sliceVecTy = getType(); - ArrayRef sliceShape = sliceVecTy.getShape(); - int64_t rank = sliceVecTy.getRank(); - - // Expand offsets and sizes to match the vector rank. - SmallVector offsets(rank, 0); - copy(getI64SubArray(getOffsets()), offsets.begin()); - - SmallVector sizes(sourceShape); - copy(getI64SubArray(getSizes()), sizes.begin()); - - // Calculate the slice elements by enumerating all slice positions and - // linearizing them. The enumeration order is lexicographic which yields a - // sequence of monotonically increasing linearized position indices. - const auto denseValuesBegin = dense.value_begin(); - SmallVector sliceValues; - sliceValues.reserve(sliceVecTy.getNumElements()); - SmallVector currSlicePosition(offsets.begin(), offsets.end()); - do { - int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides); - assert(linearizedPosition < sourceVecTy.getNumElements() && - "Invalid index"); - sliceValues.push_back(*(denseValuesBegin + linearizedPosition)); - } while ( - succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets))); - - assert(static_cast(sliceValues.size()) == - sliceVecTy.getNumElements() && - "Invalid number of slice elements"); - return DenseElementsAttr::get(sliceVecTy, sliceValues); - } - - return {}; + return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector()); } void ExtractStridedSliceOp::getOffsets(SmallVectorImpl &results) { From fc62a9b2445a581246fe1c14f602ad8041fe0ee4 Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 16 Apr 2025 10:17:58 -0700 Subject: [PATCH 4/4] prefer static to anonymous namespace in this file --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 29d9d1f0f50ae..bbe222e72bf24 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3712,11 +3712,10 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { return failure(); } -namespace { - // ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp. -OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, - Attribute foldInput) { +static OpFoldResult +foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, + Attribute foldInput) { auto dense = llvm::dyn_cast_if_present(foldInput); if (!dense) @@ -3760,7 +3759,6 @@ OpFoldResult foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op, "Invalid number of slice elements"); return DenseElementsAttr::get(sliceVecTy, sliceValues); } -} // namespace OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { if (getSourceVectorType() == getResult().getType())