From 67649194b893a9a017082964d285056f4c6656fa Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Mon, 14 Oct 2024 13:29:42 -0500
Subject: [PATCH 1/6] [mlir] Fold expand of cast

Sink tensor.cast ops through tensor.expand_shape ops when doing so makes
the expand op more static. This allows other ops further down the program
to infer their shapes.
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 31 +++++++++++++++++++++-
 mlir/test/Dialect/Tensor/canonicalize.mlir | 14 ++++++++++
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 4d6c5965c4fcc..9be647f687e60 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1982,6 +1982,35 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+struct FoldExpandOfCast : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(expandOp.getResultType().getShape(),
+                       expandOp.getOutputShape(), rewriter);
+    std::optional<SmallVector<int64_t>> constantOutputShape =
+        getConstantIntValues(outputOfr);
+    if (!constantOutputShape.has_value()) {
+      return failure();
+    }
+    auto newType = RankedTensorType::get(
+        constantOutputShape.value(), expandOp.getSrcType().getElementType());
+
+    auto newExpand = rewriter.create<ExpandShapeOp>(
+        castOp.getLoc(), newType, castOp.getSource(),
+        expandOp.getReassociationIndices());
+    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
+                                        newExpand.getResult());
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1989,7 +2018,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldReshapeWithConstant<ExpandShapeOp>,
+      FoldExpandOfCast, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0aa2d33ef17ed..1509d26151119 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2718,3 +2718,17 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
   %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
   return %pack : tensor<128x?x100x16x1xf16>
 }
+
+// -----
+
+func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
+    -> tensor<?x?x?xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_expand_of_cast
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
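
For illustration (example IR hand-written for this note, not taken from the
patch): when every output_shape operand is a constant, the FoldExpandOfCast
pattern above rewrites

    // Before: a generalizing cast feeding a dynamic expand.
    %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
    %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
        : tensor<?x?xf32> into tensor<?x?x?xf32>

into a fully static expand followed by a (static -> dynamic) cast that
adjacent casts or shape-aware consumers can then fold away:

    // After: the expand is static; only the trailing cast stays dynamic.
    %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [10, 1, 10]
        : tensor<10x10xf32> into tensor<10x1x10xf32>
    %1 = tensor.cast %0 : tensor<10x1x10xf32> to tensor<?x?x?xf32>
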
From 3f4c7bb63fc16dcfa809ae917d039c25782c7cb9 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Tue, 15 Oct 2024 02:19:37 +0000
Subject: [PATCH 2/6] Convert to static expand_shape

When the expand's output_shape operands can be determined, convert it
to a static expand_shape op and insert cast ops. The top cast will be
(dynamic -> static), allowing it to be propagated upwards, and the
bottom cast will be (static -> dynamic), allowing it to propagate down
(or cancel with adjacent tensor.cast ops).

[skip ci]
---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 61 ++++++++++++++++------
 mlir/test/Dialect/Tensor/canonicalize.mlir | 14 +++++
 2 files changed, 60 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 9be647f687e60..96384385b6a06 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1983,29 +1983,60 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
   }
 };
 
-struct FoldExpandOfCast : public OpRewritePattern<ExpandShapeOp> {
+struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
   using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                 PatternRewriter &rewriter) const override {
-    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
-    if (!canFoldIntoConsumerOp(castOp))
-      return failure();
+    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
+    SmallVector<Value> dynamicOutputShape;
+    auto outputIt = expandOp.getOutputShape().begin();
+    for (auto [i, staticShape] : llvm::enumerate(newOutputShape)) {
+      if (!ShapedType::isDynamic(staticShape))
+        continue;
 
-    SmallVector<OpFoldResult> outputOfr =
-        getMixedValues(expandOp.getResultType().getShape(),
-                       expandOp.getOutputShape(), rewriter);
-    std::optional<SmallVector<int64_t>> constantOutputShape =
-        getConstantIntValues(outputOfr);
-    if (!constantOutputShape.has_value()) {
+      APInt cst;
+      Value val = *outputIt;
+      ++outputIt;
+      if (matchPattern(val, m_ConstantInt(&cst))) {
+        newOutputShape[i] = cst.getSExtValue();
+      } else {
+        dynamicOutputShape.push_back(val);
+      }
+    }
+
+    // Couldn't match any values, nothing to change
+    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
       return failure();
+
+    // Calculate the input shape from the output
+    SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
+    for (uint64_t inDim = 0; inDim < newInputShape.size(); inDim++) {
+      for (auto outDim : reassoc[inDim]) {
+        auto ofr = newOutputShape[outDim];
+        if (ShapedType::isDynamic(ofr)) {
+          newInputShape[inDim] = ShapedType::kDynamic;
+          break;
+        }
+        newInputShape[inDim] *= ofr;
+      }
     }
-    auto newType = RankedTensorType::get(
-        constantOutputShape.value(), expandOp.getSrcType().getElementType());
 
+    // `inputCast` can be propagated up and the final cast can be propagated
+    // down.
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
+    auto inputType = RankedTensorType::get(
+        newInputShape, expandOp.getSrcType().getElementType());
+    auto outputType = RankedTensorType::get(
+        newOutputShape, expandOp.getSrcType().getElementType());
+    auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
+                                             expandOp.getSrc());
     auto newExpand = rewriter.create<ExpandShapeOp>(
-        castOp.getLoc(), newType, castOp.getSource(),
-        expandOp.getReassociationIndices());
+        expandOp.getLoc(), outputType, inputCast.getResult(),
+        expandOp.getReassociationIndices(), outputOfr);
     rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                         newExpand.getResult());
     return success();
@@ -2018,7 +2049,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldExpandOfCast, FoldReshapeWithConstant<ExpandShapeOp>,
+      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 1509d26151119..52dcfd1d427d9 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2732,3 +2732,17 @@ func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
 }
 // CHECK-LABEL: func.func @fold_expand_of_cast
 // CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+
+// -----
+
+func.func @fold_expand_of_cast_dynamic(%arg0 : tensor<?x10xf32>)
+    -> tensor<?x?x?xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_expand_of_cast_dynamic
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
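
Sketch of the rewrite introduced here (example IR written for this note;
%c2 and %c5 are assumed index constants, %d an unknown index value): an
expand whose first two output sizes are known but whose third is not, such as

    // Before: output sizes 2 and 5 are known, %d is not.
    %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c2, %c5, %d]
        : tensor<?x?xf32> into tensor<?x?x?xf32>

becomes a maximally static expand bracketed by the two casts described in
the commit message:

    // After: (dynamic -> static) cast, static expand, (static -> dynamic) cast.
    %s = tensor.cast %0 : tensor<?x?xf32> to tensor<10x?xf32>
    %e = tensor.expand_shape %s [[0, 1], [2]] output_shape [2, 5, %d]
        : tensor<10x?xf32> into tensor<2x5x?xf32>
    %1 = tensor.cast %e : tensor<2x5x?xf32> to tensor<?x?x?xf32>
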
From 4e16a9764cc0b125d3b851fd077865fe50b62003 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Tue, 15 Oct 2024 21:58:24 +0000
Subject: [PATCH 3/6] Redo logic to ensure cast gets folded

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp   | 46 +++++++++++++++-------
 mlir/test/Dialect/Tensor/canonicalize.mlir | 38 +++++++++++++++---
 2 files changed, 64 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 96384385b6a06..ee0e8c2d20122 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -24,6 +24,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -1988,20 +1989,41 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
 
   LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                 PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    const ArrayRef<int64_t> castSrcShape =
+        castOp.getSource().getType().getShape();
+    const SmallVector<ReassociationIndices, 4> reassoc =
+        expandOp.getReassociationIndices();
+
     SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
     SmallVector<Value> dynamicOutputShape;
     auto outputIt = expandOp.getOutputShape().begin();
-    for (auto [i, staticShape] : llvm::enumerate(newOutputShape)) {
-      if (!ShapedType::isDynamic(staticShape))
-        continue;
 
-      APInt cst;
-      Value val = *outputIt;
-      ++outputIt;
-      if (matchPattern(val, m_ConstantInt(&cst))) {
-        newOutputShape[i] = cst.getSExtValue();
-      } else {
-        dynamicOutputShape.push_back(val);
+    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
+      for (const uint64_t outDim : innerReassoc) {
+        if (!ShapedType::isDynamic(newOutputShape[outDim]))
+          continue;
+
+        // If the cast's src type is dynamic, don't infer any of the
+        // corresponding expanded dimensions. `tensor.expand_shape` requires at
+        // least one of the expanded dimensions to be dynamic if the input is
+        // dynamic.
+        Value val = *outputIt;
+        ++outputIt;
+        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
+          dynamicOutputShape.push_back(val);
+          continue;
+        }
+
+        APInt cst;
+        if (matchPattern(val, m_ConstantInt(&cst))) {
+          newOutputShape[outDim] = cst.getSExtValue();
+        } else {
+          dynamicOutputShape.push_back(val);
+        }
       }
     }
 
@@ -2010,8 +2032,6 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
       return failure();
 
     // Calculate the input shape from the output
-    SmallVector<ReassociationIndices, 4> reassoc =
-        expandOp.getReassociationIndices();
     SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
     for (uint64_t inDim = 0; inDim < newInputShape.size(); inDim++) {
       for (auto outDim : reassoc[inDim]) {
@@ -2024,8 +2044,6 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
       }
     }
 
-    // `inputCast` can be propagated up and the final cast can be propagated
-    // down.
     SmallVector<OpFoldResult> outputOfr =
         getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
     auto inputType = RankedTensorType::get(
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 52dcfd1d427d9..63f394a14d389 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2722,20 +2722,22 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
 // -----
 
 func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
-    -> tensor<?x?x?xf32> {
+    -> tensor<10x1x10xf32> {
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
   %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
       : tensor<?x?xf32> into tensor<?x?x?xf32>
-  return %1 : tensor<?x?x?xf32>
+  %2 = tensor.cast %1 : tensor<?x?x?xf32> to tensor<10x1x10xf32>
+  return %2 : tensor<10x1x10xf32>
 }
 // CHECK-LABEL: func.func @fold_expand_of_cast
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK: %[[RES:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK: return %[[RES]]
 
 // -----
 
-func.func @fold_expand_of_cast_dynamic(%arg0 : tensor<?x10xf32>)
+func.func @sink_expand_of_cast(%arg0 : tensor<?x10xf32>)
     -> tensor<?x?x?xf32> {
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
   %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL: func.func @fold_expand_of_cast_dynamic
-// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
+// CHECK-LABEL: func.func @sink_expand_of_cast
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%[[C10]], %[[C1]], 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK: return %[[RES]]
+
+// -----
+
+func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index, %arg2 : index)
+    -> tensor<?x?x?xf32> {
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @partial_sink_expand_of_cast
+// CHECK: %[[CAST:.+]] = tensor.cast
+// CHECK-SAME: tensor<10x10xf32> to tensor<?x10xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK-SAME: output_shape [%{{.*}}, %{{.*}}, 10]
+// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
+// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
+// CHECK: return %[[RES]]
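
Why the castSrcShape guard above matters (illustrative IR, not from the
patch): tensor.expand_shape requires at least one expanded dimension of a
group to stay dynamic when the corresponding source dimension is dynamic.
Given

    // Source dim 0 is dynamic, so group [0, 1] must keep a dynamic dim.
    %0 = tensor.cast %arg0 : tensor<?x10xf32> to tensor<?x?xf32>
    %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
        : tensor<?x?xf32> into tensor<?x?x?xf32>

folding %c10 and %c1 into group [0, 1] would statify dimensions that expand
the dynamic source dimension, producing invalid IR. With the guard, only
group [2] is made static:

    // After: group [2] folds to 10; group [0, 1] keeps its SSA sizes.
    %e = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%c10, %c1, 10]
        : tensor<?x10xf32> into tensor<?x?x10xf32>
    %1 = tensor.cast %e : tensor<?x?x10xf32> to tensor<?x?x?xf32>
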
From c6e1139536a622ab2b46f98a85336db4cc7fa404 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Tue, 22 Oct 2024 16:16:56 +0000
Subject: [PATCH 4/6] Drop const qualifier

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index ee0e8c2d20122..c7a675733311a 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1993,9 +1993,8 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
     if (!canFoldIntoConsumerOp(castOp))
       return failure();
 
-    const ArrayRef<int64_t> castSrcShape =
-        castOp.getSource().getType().getShape();
-    const SmallVector<ReassociationIndices, 4> reassoc =
+    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
+    SmallVector<ReassociationIndices, 4> reassoc =
         expandOp.getReassociationIndices();
 
     SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
@@ -2003,7 +2002,7 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
     auto outputIt = expandOp.getOutputShape().begin();
 
     for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
-      for (const uint64_t outDim : innerReassoc) {
+      for (uint64_t outDim : innerReassoc) {
         if (!ShapedType::isDynamic(newOutputShape[outDim]))
           continue;
 
From 8cf255b1552413ec568d172644baf00b30fb64a4 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Tue, 22 Oct 2024 21:36:44 +0000
Subject: [PATCH 5/6] Add comment to rewrite pattern

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index c7a675733311a..6bf01b2ee1b9f 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1984,6 +1984,10 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
   }
 };
 
+/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
+/// matching constant output_shape operands of the expand. This makes the
+/// `tensor.expand_shape` more static and creates a consumer cast that can be
+/// propagated further.
 struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
   using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
 
From f8670ebc77393d1c7b061dfe7ebb208eddf1c663 Mon Sep 17 00:00:00 2001
From: Ian Wood
Date: Thu, 24 Oct 2024 20:37:13 +0000
Subject: [PATCH 6/6] Use llvm::seq

---
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 6bf01b2ee1b9f..9d12ebd307725 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2036,7 +2036,7 @@ struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
 
     // Calculate the input shape from the output
     SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
-    for (uint64_t inDim = 0; inDim < newInputShape.size(); inDim++) {
+    for (auto inDim : llvm::seq<uint64_t>(0, newInputShape.size())) {
       for (auto outDim : reassoc[inDim]) {
         auto ofr = newOutputShape[outDim];
         if (ShapedType::isDynamic(ofr)) {
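
Note on this last change (the explicit <uint64_t> template argument is
reconstructed here, as llvm::seq cannot deduce a type from mixed int and
size_t arguments): llvm::seq<uint64_t>(0, newInputShape.size()) yields the
half-open range [0, newInputShape.size()), so the range-based loop visits
exactly the same indices as the manual counting loop it replaces.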