
Commit a7bab8c

qedawkins authored and keshavvinayak01 committed
[Codegen] Add pattern to bubble bitcast past extract_slice (iree-org#21518)
It's easiest to handle bitcasts if we're able to fold them into input bindings. Since we'll want an analogous pattern to "fuse" bitcasts when they aren't foldable (and do them at the vector level), this is only run when propagating reshapes by expansion earlier on.

Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 3f37e32 commit a7bab8c
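
In IR terms, the new pattern hoists a bitcast above the tensor.extract_slice that produces its operand, so the bitcast lands directly on the original tensor. A minimal before/after sketch, distilled from the first test case in this commit (only the innermost dimension is rebitcast: 6 x i8 = 48 bits = 3 x i16; %src stands in for the slice's source value):

// Before: the bitcast consumes a slice of the source.
%0 = tensor.extract_slice %src [0, 0] [2, 6] [1, 1] : tensor<3x6xi8> to tensor<2x6xi8>
%1 = iree_tensor_ext.bitcast %0 : tensor<2x6xi8> -> tensor<2x3xi16>

// After: the bitcast applies to the whole source, and the slice's innermost
// static size is rewritten to the new element count (6 -> 3).
%0 = iree_tensor_ext.bitcast %src : tensor<3x6xi8> -> tensor<3x3xi16>
%1 = tensor.extract_slice %0 [0, 0] [2, 3] [1, 1] : tensor<3x3xi16> to tensor<2x3xi16>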

2 files changed: +160 -1

compiler/src/iree/compiler/Codegen/Common/PropagateReshapesByExpansion.cpp

Lines changed: 76 additions & 1 deletion
@@ -11,6 +11,9 @@
 #include "iree/compiler/Codegen/Utils/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 namespace mlir::iree_compiler {
@@ -294,6 +297,76 @@ struct ExpandDestinationForallOp final
   }
 };
 
+/// This pattern rewrites bitcast(extract_slice) into extract_slice(bitcast)
+/// in an attempt to move the bitcast closer to the loads. There is a related
+/// pattern that does the reverse when folding the bitcast is not possible and
+/// should be applied later.
+struct SwapInnerBitcastWithExtractSlice
+    : OpRewritePattern<IREE::TensorExt::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IREE::TensorExt::BitCastOp bitcastOp,
+                                PatternRewriter &rewriter) const override {
+    Value bitcastSrc = bitcastOp.getSource();
+    auto sliceOp = bitcastSrc.getDefiningOp<tensor::ExtractSliceOp>();
+    if (!sliceOp) {
+      return rewriter.notifyMatchFailure(bitcastOp, "non-slice producer");
+    }
+
+    auto bitcastSrcType = cast<RankedTensorType>(bitcastSrc.getType());
+    auto bitcastResType = cast<RankedTensorType>(bitcastOp.getType());
+
+    // Verify that only the innermost dimension is changed by the bitcast by
+    // comparing dynamic and static sizes for equality.
+    if (bitcastOp.getSourceDims() != bitcastOp.getResultDims() ||
+        bitcastSrcType.getShape().drop_back() !=
+            bitcastResType.getShape().drop_back() ||
+        ShapedType::isDynamic(bitcastSrcType.getShape().back())) {
+      return rewriter.notifyMatchFailure(
+          bitcastOp, "bitcast affects more than inner most dim");
+    }
+
+    // Fail if the innermost dim is sliced or if this is an encoded tensor.
+    RankedTensorType sliceInputType = sliceOp.getSource().getType();
+    if (sliceInputType.getEncoding() ||
+        sliceInputType.getRank() != bitcastSrcType.getRank() ||
+        sliceInputType.getShape().back() != bitcastSrcType.getShape().back()) {
+      return rewriter.notifyMatchFailure(
+          bitcastOp,
+          "inner dimension is sliced or rank reducing or tensor is encoded");
+    }
+
+    int64_t newInnerSize = bitcastResType.getShape().back();
+    SmallVector<int64_t> newBitcastShape(sliceInputType.getShape());
+    newBitcastShape.back() = newInnerSize;
+
+    auto newBitcastType =
+        RankedTensorType::get(newBitcastShape, bitcastResType.getElementType());
+
+    // Get the dynamic sizes of the slice source. Extracting a slice can remove
+    // dynamic dimensions or introduce new ones, so a new list of sizes is
+    // needed.
+    SmallVector<OpFoldResult> newMixedSizes =
+        tensor::getMixedSizes(rewriter, sliceOp.getLoc(), sliceOp.getSource());
+    SmallVector<Value> sliceSourceDynamicSizes;
+    SmallVector<int64_t> sliceSourceStaticSizes;
+    dispatchIndexOpFoldResults(newMixedSizes, sliceSourceDynamicSizes,
+                               sliceSourceStaticSizes);
+
+    Value newBitcast = rewriter.create<IREE::TensorExt::BitCastOp>(
+        bitcastOp.getLoc(), newBitcastType, sliceOp.getSource(),
+        sliceSourceDynamicSizes, sliceSourceDynamicSizes);
+    SmallVector<int64_t> newSizes(sliceOp.getStaticSizes());
+    newSizes.back() = newInnerSize;
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+        bitcastOp, bitcastResType, newBitcast, sliceOp.getOffsets(),
+        sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
+        newSizes, sliceOp.getStaticStrides());
+
+    return success();
+  }
+};
+
 struct PropagateReshapesByExpansionPass final
     : impl::PropagateReshapesByExpansionPassBase<
          PropagateReshapesByExpansionPass> {
@@ -341,7 +414,9 @@ void PropagateReshapesByExpansionPass::runOnOperation() {
   tensor::ExpandShapeOp::getCanonicalizationPatterns(bubbleExpandShapePatterns,
                                                      context);
   populateReshapeToInterfaceTensorPatterns(bubbleExpandShapePatterns);
-  bubbleExpandShapePatterns.add<ExpandDestinationForallOp>(context);
+  bubbleExpandShapePatterns
+      .add<ExpandDestinationForallOp, SwapInnerBitcastWithExtractSlice>(
+          context);
 
   if (failed(applyPatternsGreedily(getOperation(),
                                    std::move(bubbleExpandShapePatterns)))) {
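
A note on the dynamic-shape handling above: the dynamic sizes for the new bitcast are recomputed from the slice source with tensor::getMixedSizes rather than reused from the old bitcast, because slicing can drop dynamic dimensions or introduce new ones. A sketch of what the rewrite emits for a dynamic outer dimension, mirroring the @swap_inner_bitcast_dynamic_source test below (%src again stands in for the slice source):

%c0 = arith.constant 0 : index
// getMixedSizes materializes the dynamic extent of the source...
%dim = tensor.dim %src, %c0 : tensor<?x6xi8>
// ...which feeds both the source and result dims of the new bitcast, since
// only the static innermost dimension changes.
%0 = iree_tensor_ext.bitcast %src : tensor<?x6xi8>{%dim} -> tensor<?x3xi16>{%dim}
%1 = tensor.extract_slice %0 [0, 0] [2, 3] [1, 1] : tensor<?x3xi16> to tensor<2x3xi16>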

compiler/src/iree/compiler/Codegen/Common/test/propagate_reshapes_by_expansion.mlir

Lines changed: 84 additions & 0 deletions
@@ -341,3 +341,87 @@ func.func @expand_dest_forall_no_crash_issue_20736(%arg0: tensor<16x8x48x32x3x96
 // CHECK: scf.forall
 // CHECK-NOT: tensor.collapse_shape
 // CHECK: tensor.parallel_insert_slice
+
+// -----
+
+func.func @swap_inner_bitcast(%arg0: tensor<3x6xi8>) -> tensor<2x3xi16> {
+  %0 = tensor.extract_slice %arg0 [0, 0] [2, 6] [1, 1] : tensor<3x6xi8> to tensor<2x6xi8>
+  %1 = iree_tensor_ext.bitcast %0 : tensor<2x6xi8> -> tensor<2x3xi16>
+  return %1 : tensor<2x3xi16>
+}
+
+// CHECK-LABEL: @swap_inner_bitcast
+// CHECK-SAME:  %[[ARG0:[A-Za-z0-9]+]]: tensor<3x6xi8>
+// CHECK-NEXT:  %[[BITCAST:.+]] = iree_tensor_ext.bitcast %[[ARG0]] : tensor<3x6xi8> -> tensor<3x3xi16>
+// CHECK-NEXT:  %[[SLICE:.+]] = tensor.extract_slice %[[BITCAST]]{{.*}} : tensor<3x3xi16> to tensor<2x3xi16>
+// CHECK-NEXT:  return %[[SLICE]]
+
+// -----
+
+func.func @no_swap_arbitrary_bitcast(%arg0: tensor<3x6xi8>) -> tensor<6xi16> {
+  %0 = tensor.extract_slice %arg0 [0, 0] [2, 6] [1, 1] : tensor<3x6xi8> to tensor<2x6xi8>
+  %1 = iree_tensor_ext.bitcast %0 : tensor<2x6xi8> -> tensor<6xi16>
+  return %1 : tensor<6xi16>
+}
+
+// CHECK-LABEL: @no_swap_arbitrary_bitcast
+// CHECK-SAME:  %[[ARG0:[A-Za-z0-9]+]]: tensor<3x6xi8>
+// CHECK-NEXT:  %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-NEXT:  %[[BITCAST:.+]] = iree_tensor_ext.bitcast %[[SLICE]]
+// CHECK-NEXT:  return %[[BITCAST]]
+
+// -----
+
+func.func @swap_inner_bitcast_dynamic_source(%arg0: tensor<?x6xi8>) -> tensor<2x3xi16> {
+  %0 = tensor.extract_slice %arg0 [0, 0] [2, 6] [1, 1] : tensor<?x6xi8> to tensor<2x6xi8>
+  %1 = iree_tensor_ext.bitcast %0 : tensor<2x6xi8> -> tensor<2x3xi16>
+  return %1 : tensor<2x3xi16>
+}
+
+// CHECK-LABEL: @swap_inner_bitcast_dynamic_source
+// CHECK-SAME:  %[[ARG0:[A-Za-z0-9]+]]: tensor<?x6xi8>
+// CHECK:       %[[DIM:.+]] = tensor.dim %[[ARG0]], %c0 : tensor<?x6xi8>
+// CHECK-NEXT:  %[[BITCAST:.+]] = iree_tensor_ext.bitcast %[[ARG0]] : tensor<?x6xi8>{%[[DIM]]} -> tensor<?x3xi16>{%[[DIM]]}
+// CHECK-NEXT:  %[[SLICE:.+]] = tensor.extract_slice %[[BITCAST]]{{.*}} : tensor<?x3xi16> to tensor<2x3xi16>
+// CHECK-NEXT:  return %[[SLICE]]
+
+// -----
+
+func.func @swap_inner_bitcast_dynamic_result(%arg0: tensor<3x6xi8>, %arg1: index) -> tensor<?x3xi16> {
+  %0 = tensor.extract_slice %arg0 [0, 0] [%arg1, 6] [1, 1] : tensor<3x6xi8> to tensor<?x6xi8>
+  %1 = iree_tensor_ext.bitcast %0 : tensor<?x6xi8>{%arg1} -> tensor<?x3xi16>{%arg1}
+  return %1 : tensor<?x3xi16>
+}
+
+// CHECK-LABEL: @swap_inner_bitcast_dynamic_result
+// CHECK-SAME:  %[[ARG0:[A-Za-z0-9]+]]: tensor<3x6xi8>
+// CHECK-SAME:  %[[ARG1:[A-Za-z0-9]+]]: index
+// CHECK-NEXT:  %[[BITCAST:.+]] = iree_tensor_ext.bitcast %[[ARG0]] : tensor<3x6xi8> -> tensor<3x3xi16>
+// CHECK-NEXT:  %[[SLICE:.+]] = tensor.extract_slice %[[BITCAST]]{{.*}} : tensor<3x3xi16> to tensor<?x3xi16>
+// CHECK-NEXT:  return %[[SLICE]]
+
+// -----
+
+func.func @no_swap_encoded_bitcast(%arg0: tensor<3x6xi8, 1>) -> tensor<2x3xi16, 1> {
+  %0 = tensor.extract_slice %arg0 [0, 0] [2, 6] [1, 1] : tensor<3x6xi8, 1> to tensor<2x6xi8, 1>
+  %1 = iree_tensor_ext.bitcast %0 : tensor<2x6xi8, 1> -> tensor<2x3xi16, 1>
+  return %1 : tensor<2x3xi16, 1>
+}
+
+// CHECK-LABEL: @no_swap_encoded_bitcast
+// CHECK-SAME:  %[[ARG0:[A-Za-z0-9]+]]: tensor<3x6xi8, 1 : i64>
+// CHECK-NEXT:  %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-NEXT:  iree_tensor_ext.bitcast %[[SLICE]]
+
+// -----
+
+func.func @no_swap_rank_reducing_slice(%arg0: tensor<3x6xi8>) -> tensor<3xi16> {
+  %0 = tensor.extract_slice %arg0 [0, 0] [1, 6] [1, 1] : tensor<3x6xi8> to tensor<6xi8>
+  %1 = iree_tensor_ext.bitcast %0 : tensor<6xi8> -> tensor<3xi16>
+  return %1 : tensor<3xi16>
+}
+
+// CHECK-LABEL: @no_swap_rank_reducing_slice
+// CHECK-SAME:  %[[ARG0:[A-Za-z0-9]+]]: tensor<3x6xi8>
+// CHECK-NEXT:  %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-NEXT:  iree_tensor_ext.bitcast %[[SLICE]]
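
The // ----- separators split this file into independent inputs, so each case above is verified in isolation. The file's RUN line sits at the top of the file, outside this diff's context; assuming it follows IREE's usual pass-flag naming for PropagateReshapesByExpansionPass (an assumption, not confirmed by this diff), it would look roughly like:

// RUN: iree-opt --split-input-file --iree-codegen-propagate-reshapes-by-expansion %s | FileCheck %s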
