|
11 | 11 | #include "iree/compiler/Codegen/Utils/Utils.h" |
12 | 12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
13 | 13 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 14 | +#include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 15 | +#include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 16 | +#include "mlir/IR/BuiltinTypes.h" |
14 | 17 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
15 | 18 |
|
16 | 19 | namespace mlir::iree_compiler { |
@@ -294,6 +297,76 @@ struct ExpandDestinationForallOp final |
294 | 297 | } |
295 | 298 | }; |
296 | 299 |
|
 | 300 | +/// This pattern rewrites bitcast(extract_slice) into extract_slice(bitcast) in |
 | 301 | +/// an attempt to move the bitcast closer to the loads. A related pattern |
 | 302 | +/// performs the reverse rewrite for cases where folding the bitcast is not |
 | 303 | +/// possible; it should be applied later. |
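 | | +/// |
 | | +/// For example (illustrative IR; the element types and sizes are made up): |
 | | +/// |
 | | +///   %slice = tensor.extract_slice %src[0, 0] [4, 8] [1, 1] |
 | | +///       : tensor<16x8xf32> to tensor<4x8xf32> |
 | | +///   %cast = iree_tensor_ext.bitcast %slice |
 | | +///       : tensor<4x8xf32> -> tensor<4x16xf16> |
 | | +/// |
 | | +/// becomes |
 | | +/// |
 | | +///   %cast = iree_tensor_ext.bitcast %src |
 | | +///       : tensor<16x8xf32> -> tensor<16x16xf16> |
 | | +///   %slice = tensor.extract_slice %cast[0, 0] [4, 16] [1, 1] |
 | | +///       : tensor<16x16xf16> to tensor<4x16xf16> |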
| 304 | +struct SwapInnerBitcastWithExtractSlice |
| 305 | + : OpRewritePattern<IREE::TensorExt::BitCastOp> { |
| 306 | + using OpRewritePattern::OpRewritePattern; |
| 307 | + |
| 308 | + LogicalResult matchAndRewrite(IREE::TensorExt::BitCastOp bitcastOp, |
| 309 | + PatternRewriter &rewriter) const override { |
| 310 | + Value bitcastSrc = bitcastOp.getSource(); |
| 311 | + auto sliceOp = bitcastSrc.getDefiningOp<tensor::ExtractSliceOp>(); |
| 312 | + if (!sliceOp) { |
| 313 | + return rewriter.notifyMatchFailure(bitcastOp, "non-slice producer"); |
| 314 | + } |
| 315 | + |
| 316 | + auto bitcastSrcType = cast<RankedTensorType>(bitcastSrc.getType()); |
| 317 | + auto bitcastResType = cast<RankedTensorType>(bitcastOp.getType()); |
| 318 | + |
 | 319 | + // Verify that the bitcast changes only the (static) innermost dimension |
 | 320 | + // by comparing the dynamic and leading static sizes for equality. |
| 321 | + if (bitcastOp.getSourceDims() != bitcastOp.getResultDims() || |
| 322 | + bitcastSrcType.getShape().drop_back() != |
| 323 | + bitcastResType.getShape().drop_back() || |
| 324 | + ShapedType::isDynamic(bitcastSrcType.getShape().back())) { |
| 325 | + return rewriter.notifyMatchFailure( |
 | 326 | + bitcastOp, "bitcast affects more than the innermost dim"); |
| 327 | + } |
| 328 | + |
 | 329 | + // Fail on rank-reducing slices, sliced innermost dims, and encoded tensors. |
| 330 | + RankedTensorType sliceInputType = sliceOp.getSource().getType(); |
| 331 | + if (sliceInputType.getEncoding() || |
| 332 | + sliceInputType.getRank() != bitcastSrcType.getRank() || |
| 333 | + sliceInputType.getShape().back() != bitcastSrcType.getShape().back()) { |
| 334 | + return rewriter.notifyMatchFailure( |
| 335 | + bitcastOp, |
| 336 | + "inner dimension is sliced or rank reducing or tensor is encoded"); |
| 337 | + } |
| 338 | + |
| 339 | + int64_t newInnerSize = bitcastResType.getShape().back(); |
| 340 | + SmallVector<int64_t> newBitcastShape(sliceInputType.getShape()); |
| 341 | + newBitcastShape.back() = newInnerSize; |
| 342 | + |
| 343 | + auto newBitcastType = |
| 344 | + RankedTensorType::get(newBitcastShape, bitcastResType.getElementType()); |
| 345 | + |
| 346 | + // Get the dynamic sizes of the slice source. Extracting a slice can remove |
| 347 | + // dynamic dimensions or introduce new ones, so a new list of sizes is |
| 348 | + // needed. |
| 349 | + SmallVector<OpFoldResult> newMixedSizes = |
| 350 | + tensor::getMixedSizes(rewriter, sliceOp.getLoc(), sliceOp.getSource()); |
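 | | + // Separate the mixed sizes into their dynamic value and static integer |
 | | + // components. |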
| 351 | + SmallVector<Value> sliceSourceDynamicSizes; |
| 352 | + SmallVector<int64_t> sliceSourceStaticSizes; |
| 353 | + dispatchIndexOpFoldResults(newMixedSizes, sliceSourceDynamicSizes, |
| 354 | + sliceSourceStaticSizes); |
| 355 | + |
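 | | + // Only the static innermost extent differs between the new bitcast's |
 | | + // source and result types, so both sides share the same dynamic sizes. |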
| 356 | + Value newBitcast = rewriter.create<IREE::TensorExt::BitCastOp>( |
| 357 | + bitcastOp.getLoc(), newBitcastType, sliceOp.getSource(), |
| 358 | + sliceSourceDynamicSizes, sliceSourceDynamicSizes); |
| 359 | + SmallVector<int64_t> newSizes(sliceOp.getStaticSizes()); |
| 360 | + newSizes.back() = newInnerSize; |
| 361 | + rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>( |
| 362 | + bitcastOp, bitcastResType, newBitcast, sliceOp.getOffsets(), |
| 363 | + sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), |
| 364 | + newSizes, sliceOp.getStaticStrides()); |
| 365 | + |
| 366 | + return success(); |
| 367 | + } |
| 368 | +}; |
| 369 | + |
297 | 370 | struct PropagateReshapesByExpansionPass final |
298 | 371 | : impl::PropagateReshapesByExpansionPassBase< |
299 | 372 | PropagateReshapesByExpansionPass> { |
@@ -341,7 +414,9 @@ void PropagateReshapesByExpansionPass::runOnOperation() { |
341 | 414 | tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns, |
342 | 415 | context); |
343 | 416 | populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns); |
344 | | - bubbleExpandShapePatterns.add<ExpandDestinationForallOp>(context); |
| 417 | + bubbleExpandShapePatterns |
| 418 | + .add<ExpandDestinationForallOp, SwapInnerBitcastWithExtractSlice>( |
| 419 | + context); |
345 | 420 |
|
346 | 421 | if (failed(applyPatternsGreedily(getOperation(), |
347 | 422 | std::move(bubbleExpandShapePatterns)))) { |