[mlir][tensor] Introduce FoldTensorCastUnPackOp #121393
Changes from 1 commit
```diff
@@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
         // Already a constant
         newMixedTileSizes.push_back(std::get<1>(it));
       } else {
-        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
-        assert(tileSize == shape && "tile size and dim size don't match!");
-        (void)tileSize;
+        assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+               "tile size and dim size don't match!");
         newMixedTileSizes.push_back(
             (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
       }
     }

     // Clone op.
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
     // this point. However, in practice, we use them for things that we'd like
     // to preserve. Implement a better abstraction.
     PackOp newOp = rewriter.create<PackOp>(
         op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
         newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
```
```diff
@@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
   }
 };

+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
```
Comment on lines +4902 to +4903

Contributor: nit: Most of the logic in this function is the same as for tensor.pack, but with the source type instead of the dest type. Could you refactor the logic a bit to try to share code from a single function (mainly for finding the new mixed tile sizes)?

Contributor (Author): Great point, sending update shortly. Thanks for the suggestion!
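One way such a shared helper could look. This is only a sketch of the suggestion above, not code from this commit: the name `getNewMixedTileSizes` is hypothetical, and it assumes the caller passes the already-updated packed type (the new dest type for `tensor.pack`, the new source type for `tensor.unpack`) plus the op's current mixed tiles.

```cpp
// Hypothetical helper (not part of this commit). Assumes the usual
// TensorOps.cpp context (mlir namespace, LLVM ADT helpers in scope).
// Derives updated mixed tile sizes from the trailing dims of the new,
// possibly more static, packed type.
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    // Dynamic dim in the new type: keep the original mixed size.
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }
    if (llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
      // Already a constant - nothing to update.
      newMixedTileSizes.push_back(std::get<1>(it));
    } else {
      // A dynamic SSA tile size whose value is now known statically.
      assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(
          rewriter.getIntegerAttr(rewriter.getIndexType(), shape));
    }
  }
  return newMixedTileSizes;
}
```

Both `matchAndRewrite` bodies would then reduce to computing the new operands and calling this helper with the appropriate type. The diff of the new pattern continues below.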
```diff
+
+  LogicalResult matchAndRewrite(UnPackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    Value sourceTensor = newOperands[0];
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      // If the current source shape is dynamic, just preserve this mixed
+      // size.
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+
+      // If the current source is static, update the dynamic mixed-size
+      // (provided the original value is dynamic).
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+        // Already a constant
+        newMixedTileSizes.push_back(std::get<1>(it));
+      } else {
+        assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
+               "tile size and dim size don't match!");
+        newMixedTileSizes.push_back(
+            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      }
+    }
+
+    // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
+    UnPackOp newOp = rewriter.create<UnPackOp>(
+        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getOuterDimsPerm());
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};
+
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
 ///
```
```diff
@@ -4890,7 +4969,8 @@ struct FoldTensorCastProducerOp
                                 PatternRewriter &rewriter) const override {

     // Reject tensor::PackOp - there's dedicated pattern for that instead.
-    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+    if (!foldTensorCastPrecondition(op) ||
+        isa<tensor::PackOp, tensor::UnPackOp>(*op))
       return failure();

     SmallVector<Type> newResultTypes(op->getResultTypes());
```

`tensor.unpack` now joins `tensor.pack` in the exclusion: the generic pattern only swaps operands and result types and cannot update an op's mixed tile sizes to stay consistent with the folded cast, so both relayout ops are routed to their dedicated patterns above.
```diff
@@ -4923,6 +5003,7 @@ struct FoldTensorCastProducerOp
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add<FoldTensorCastPackOp>(getContext());
+  results.add<FoldTensorCastUnPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }
```
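For context (not part of the diff): once registered here, the new pattern is picked up anywhere the tensor dialect's canonicalization patterns are collected, e.g. by the `-canonicalize` pass. A minimal sketch of driving them manually with the standard greedy rewrite driver:

```cpp
// Context-only sketch (not in this patch): collect the tensor dialect's
// canonicalization patterns - which now include FoldTensorCastUnPackOp -
// and apply them greedily to a module.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult canonicalizeTensorOps(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  ctx->getLoadedDialect<mlir::tensor::TensorDialect>()
      ->getCanonicalizationPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}
```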
```diff
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+
 // -----

 // CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
```
```diff
@@ -2814,6 +2815,26 @@ func.func @fold_cast_pack_dynamic_tile_size(

 // -----

+// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:      %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK:           %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK:           return %[[RES]] : tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+    %src: tensor<1x1x8x1xi32>,
+    %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+  %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+  %c8 = arith.constant 8 : index
+  %unpack = tensor.unpack %cast
+      inner_dims_pos = [0, 1]
+      inner_tiles = [%c8, 1]
+      into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+  return %unpack : tensor<7x?xi32>
+}
```
A review comment quotes the existing pack test for comparison:

```mlir
// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
// CHECK-SAME:      %[[DEST:.*]]: tensor<1x1x8x1xi32>,
// CHECK-SAME:      %[[SRC:.*]]: tensor<7x?xi32>,
// CHECK-SAME:      %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
// CHECK:           %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
// CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
// CHECK-SAME:        some_attr
// CHECK-SAME:        : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
// CHECK:           return %[[PACK]] : tensor<1x1x8x1xi32>
func.func @fold_cast_pack_dynamic_tile_size(
    %dest: tensor<1x1x8x1xi32>,
    %src: tensor<7x?xi32>,
    %pad: i32) -> tensor<1x1x8x1xi32> {
  %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
  %c8 = arith.constant 8 : index
  %pack = tensor.pack %src padding_value(%pad : i32)
      inner_dims_pos = [0, 1]
      inner_tiles = [%c8, 1]
      into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
  %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
  return %res : tensor<1x1x8x1xi32>
}
```

😂 Let me unify this.