Skip to content

Commit a073601

Browse files
authored
[VectorExt] Implement masked vectorization for iree_linalg_ext.gather (iree-org#21189)
Implements masked vectorization for iree_linalg_ext.gather to iree_vector_ext.transfer_gather. The masked vector shape can be inferred from the dps result.
1 parent 43b0edf commit a073601

File tree

10 files changed

+159
-37
lines changed

10 files changed

+159
-37
lines changed

compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ getVectorSizes(Operation *op, bool useConfiguredVectorSizes) {
7373
return;
7474
vectorSizes = SmallVector<int64_t>(ty.getShape());
7575
})
76+
.Case<IREE::LinalgExt::GatherOp>([&](IREE::LinalgExt::GatherOp gatherOp) {
77+
std::optional<VectorizationTileSizes> result =
78+
inferSizesFromIR(gatherOp.getOutput());
79+
if (result) {
80+
vectorSizes = result->vectorSizes;
81+
}
82+
})
7683
.Default([&](Operation *) {});
7784

7885
if (vectorSizes) {
@@ -168,8 +175,8 @@ void GenericVectorizationPass::runOnOperation() {
168175
rewriter, cast<linalg::GenericOp>(op), vectorSizes, scalableVecDims,
169176
vectorizeGatherAccesses);
170177
} else if (auto gatherOp = dyn_cast<IREE::LinalgExt::GatherOp>(op)) {
171-
(void)IREE::VectorExt::vectorizeLinalgExtGatherToTransferGather(rewriter,
172-
gatherOp);
178+
(void)IREE::VectorExt::vectorizeLinalgExtGatherToTransferGather(
179+
rewriter, gatherOp, vectorSizes);
173180
} else {
174181
FailureOr<linalg::VectorizationResult> result = linalg::vectorize(
175182
rewriter, op, vectorSizes, scalableVecDims, vectorizeGatherAccesses);
@@ -221,6 +228,7 @@ void GenericVectorizationPass::runOnOperation() {
221228
if (enableVectorMasking) {
222229
vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
223230
vectorizationPatterns);
231+
IREE::VectorExt::populateVectorMaskLoweringPatterns(vectorizationPatterns);
224232
vectorizationPatterns.add<linalg::LinalgCopyVTRForwardingPattern,
225233
linalg::LinalgCopyVTWForwardingPattern>(
226234
funcOp.getContext(), /*benefit=*/2);

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ iree_td_library(
3434
deps = [
3535
"@llvm-project//mlir:BuiltinDialectTdFiles",
3636
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
37+
"@llvm-project//mlir:MaskableOpInterfaceTdFiles",
3738
"@llvm-project//mlir:OpBaseTdFiles",
3839
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
3940
"@llvm-project//mlir:VectorInterfacesTdFiles",

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.h"
1111
#include "mlir/Bytecode/BytecodeImplementation.h"
1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
1314
#include "mlir/IR/Dialect.h"
1415
#include "mlir/IR/OpDefinition.h"
1516
#include "mlir/Interfaces/SideEffectInterfaces.h"

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,12 @@ void TransferGatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
655655
ctx);
656656
}
657657

658+
// MaskableOpInterface methods.
659+
660+
Type TransferGatherOp::getExpectedMaskType() {
661+
return vector::inferTransferOpMaskType(getVectorType(), getPermutationMap());
662+
}
663+
658664
// clang-format off
659665
#define GET_OP_CLASSES
660666
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.cpp.inc" // IWYU pragma: keep

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "llvm/ADT/MapVector.h"
1313
#include "mlir/Bytecode/BytecodeImplementation.h"
1414
#include "mlir/Bytecode/BytecodeOpInterface.h"
15+
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
1516
#include "mlir/IR/Attributes.h"
1617
#include "mlir/IR/Builders.h"
1718
#include "mlir/IR/BuiltinTypes.h"

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
include "mlir/Interfaces/VectorInterfaces.td"
1111
include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtBase.td"
1212
include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.td"
13+
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
1314
include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtInterfaces.td"
1415
include "mlir/IR/OpBase.td"
1516
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -118,6 +119,7 @@ def IREEVectorExt_TransferGatherOp : IREEVectorExt_PureOp<"transfer_gather", [
118119
DeclareOpInterfaceMethods<VectorTransferOpInterface>,
119120
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
120121
DeclareOpInterfaceMethods<ConditionallySpeculatable>,
122+
DeclareOpInterfaceMethods<MaskableOpInterface>,
121123
AttrSizedOperandSegments
122124
]> {
123125
let arguments = (ins AnyShaped:$base,

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/Transforms.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,27 @@ LogicalResult vectorizeGatherLikeGenericToTransferGather(
1919
ArrayRef<int64_t> vectorSizes = {}, ArrayRef<bool> scalableVecDims = {},
2020
bool vectorizeNDExtract = false);
2121

22+
/// Vectorizes iree_linalg_ext.gather to iree_vector_ext.transfer_gather.
23+
/// Currently, this pattern only works when the index_depth and batch rank of
24+
/// the gather is 1.
25+
///
26+
/// %gather = iree_linalg_ext.gather dimension_map=[0] ins(%source, %indices)
27+
/// outs(%output)
28+
///
29+
/// vectorizes to:
30+
///
31+
/// %indices_vec = vector.transfer_read %indices
32+
/// %gather_vec = iree_vector_ext.gather %source[...][%indices_vec...]
33+
/// %gather = vector.transfer_write %gather_vec, %output
2234
LogicalResult
2335
vectorizeLinalgExtGatherToTransferGather(RewriterBase &rewriter,
24-
IREE::LinalgExt::GatherOp gatherOp);
36+
IREE::LinalgExt::GatherOp gatherOp,
37+
ArrayRef<int64_t> vectorSizes = {});
2538

2639
void populateVectorTransferGatherLoweringPatterns(RewritePatternSet &patterns);
2740

41+
void populateVectorMaskLoweringPatterns(RewritePatternSet &patterns);
42+
2843
}; // namespace mlir::iree_compiler::IREE::VectorExt
2944

3045
#endif // IREE_COMPILER_CODEGEN_DIALECT_VECTOR_EXT_TRANSFORMS_TRANSFORMS_H_

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp

Lines changed: 70 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66

77
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
88
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
9+
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Transforms.h"
910
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1011
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1112
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1213
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
1314
#include "mlir/Dialect/Vector/IR/VectorOps.h"
15+
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1416
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
1517
#include "mlir/IR/Builders.h"
1618
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -494,9 +496,17 @@ LogicalResult vectorizeGatherLikeGenericToTransferGather(
494496
return success();
495497
}
496498

499+
Value maskOperation(RewriterBase &rewriter, Operation *op, Value mask) {
500+
Value maskedOp =
501+
cast<vector::MaskOp>(mlir::vector::maskOperation(rewriter, op, mask))
502+
.getResult(0);
503+
return maskedOp;
504+
}
505+
497506
LogicalResult
498507
vectorizeLinalgExtGatherToTransferGather(RewriterBase &rewriter,
499-
IREE::LinalgExt::GatherOp gatherOp) {
508+
IREE::LinalgExt::GatherOp gatherOp,
509+
ArrayRef<int64_t> vectorSizes) {
500510

501511
// TODO: need to split the innermost dim of `indices` into `indexDepth`
502512
// vectors so that each independent index can be passed to the
@@ -520,23 +530,29 @@ vectorizeLinalgExtGatherToTransferGather(RewriterBase &rewriter,
520530
ShapedType gatherTy = gatherOp.getOutputType();
521531
ShapedType sourceTy = gatherOp.getSourceType();
522532

523-
auto gatherVectorTy =
524-
VectorType::get(gatherTy.getShape(), gatherTy.getElementType());
533+
if (vectorSizes.empty()) {
534+
vectorSizes = gatherTy.getShape();
535+
}
536+
537+
auto gatherVectorTy = VectorType::get(vectorSizes, gatherTy.getElementType());
525538
// Rank-reduced to remove the innermost unit dim.
526-
auto indicesVecTy =
527-
VectorType::get(indicesTy.getShape().take_front(gatherOp.getBatchRank()),
528-
rewriter.getIndexType());
539+
auto indicesVecTy = VectorType::get(
540+
vectorSizes.take_front(gatherOp.getBatchRank()), rewriter.getIndexType());
529541

530-
// Read `indices` tensor via `vector.transfer_read` and cast from int to
531-
// index.
532542
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
533-
Value indicesVec = rewriter.create<vector::TransferReadOp>(
543+
auto indicesVecRead = rewriter.create<vector::TransferReadOp>(
534544
loc, indicesVecTy.clone(indicesTy.getElementType()),
535545
gatherOp.getIndices(), SmallVector<Value>(indicesTy.getRank(), zero));
546+
VectorType indicesMaskType = indicesVecTy.clone(rewriter.getI1Type());
547+
SmallVector<OpFoldResult> gatherDims =
548+
tensor::getMixedSizes(rewriter, loc, gatherOp.getOutput());
549+
Value indicesMask = rewriter.create<vector::CreateMaskOp>(
550+
loc, indicesMaskType,
551+
ArrayRef(gatherDims).take_front(gatherOp.getBatchRank()));
552+
Value indicesVec = maskOperation(rewriter, indicesVecRead, indicesMask);
536553
indicesVec =
537554
rewriter.create<arith::IndexCastOp>(loc, indicesVecTy, indicesVec);
538555

539-
// Create transfer_gather op
540556
SmallVector<Value> baseIndices(sourceTy.getRank(), zero);
541557
SmallVector<bool> indexed(sourceTy.getRank(), false);
542558
indexed[0] = true;
@@ -547,21 +563,59 @@ vectorizeLinalgExtGatherToTransferGather(RewriterBase &rewriter,
547563
rewriter.getMultiDimIdentityMap(sourceTy.getRank()).getMajorSubMap(1)));
548564
Value padding = rewriter.create<arith::ConstantOp>(
549565
loc, rewriter.getZeroAttr(gatherTy.getElementType()));
550-
551566
auto transferGatherOp = rewriter.create<IREE::VectorExt::TransferGatherOp>(
552567
loc, gatherVectorTy, gatherOp.getSource(), baseIndices,
553568
ValueRange{indicesVec}, rewriter.getBoolArrayAttr(indexed), indexedMaps,
554569
rewriter.getMultiDimIdentityMap(gatherTy.getRank()), padding,
555570
/*mask=*/Value(), inBounds);
556571

557-
// Write back into tensor.
558-
auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, gatherTy.getShape(),
559-
gatherTy.getElementType());
572+
VectorType gatherMaskType = gatherVectorTy.clone(rewriter.getI1Type());
573+
Value gatherMask =
574+
rewriter.create<vector::CreateMaskOp>(loc, gatherMaskType, gatherDims);
575+
Value maskedGather = maskOperation(rewriter, transferGatherOp, gatherMask);
560576
SmallVector<Value> writeIndices(gatherTy.getRank(), zero);
561577
auto writeOp = rewriter.create<vector::TransferWriteOp>(
562-
loc, transferGatherOp.getResult(), emptyOp, writeIndices);
563-
rewriter.replaceOp(gatherOp, writeOp);
578+
loc, maskedGather, gatherOp.getOutput(), writeIndices);
579+
Value maskedWrite = maskOperation(rewriter, writeOp, gatherMask);
580+
581+
rewriter.replaceOp(gatherOp, maskedWrite);
564582
return success();
565583
}
566584

585+
/// Lowers vector.mask %mask { iree_vector_ext.transfer_gather }
586+
/// into
587+
/// iree_vector_ext.transfer_gather %mask
588+
///
589+
/// Ideally, the mask should have just been put on transfer_gather directly,
590+
/// but this is done this way to match upstream vector.transfer_read masking.
591+
struct MaskedTransferGatherOpPattern : public OpRewritePattern<vector::MaskOp> {
592+
public:
593+
using OpRewritePattern::OpRewritePattern;
594+
595+
LogicalResult matchAndRewrite(vector::MaskOp maskOp,
596+
PatternRewriter &rewriter) const override {
597+
auto gatherOp = dyn_cast<TransferGatherOp>(maskOp.getMaskableOp());
598+
if (!gatherOp) {
599+
return failure();
600+
}
601+
// TODO: The 'vector.mask' passthru is a vector and 'transfer_gather'
602+
// expects a scalar. We could only lower one to the other for cases where
603+
// the passthru is a broadcast of a scalar.
604+
if (maskOp.hasPassthru()) {
605+
return rewriter.notifyMatchFailure(
606+
maskOp, "can't lower passthru to transfer_gather");
607+
}
608+
rewriter.replaceOpWithNewOp<TransferGatherOp>(
609+
maskOp, gatherOp.getVectorType(), gatherOp.getBase(),
610+
gatherOp.getIndices(), gatherOp.getIndexVecs(), gatherOp.getIndexed(),
611+
gatherOp.getIndexedMaps(), gatherOp.getPermutationMap(),
612+
gatherOp.getPadding(), maskOp.getMask(), gatherOp.getInBounds());
613+
return success();
614+
}
615+
};
616+
617+
void populateVectorMaskLoweringPatterns(RewritePatternSet &patterns) {
618+
patterns.add<MaskedTransferGatherOpPattern>(patterns.getContext());
619+
}
620+
567621
} // namespace mlir::iree_compiler::IREE::VectorExt

compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,30 @@ func.func @linalg_ext_gather_unit_dim(%source : tensor<1024x128xi32>, %indices :
146146
// CHECK: %[[CAST:.+]] = arith.index_cast %[[READ]]
147147
// CHECK: %[[GATHER:.+]] = iree_vector_ext.transfer_gather %[[ARG0]]
148148
// CHECK-SAME: [%[[C0]], %[[C0]]][%[[CAST]]: vector<10xindex>, None]
149+
150+
// -----
151+
152+
func.func @linalg_ext_gather_masked(%source : tensor<?x128xi32>, %indices : tensor<?x1xi32>) -> (tensor<?x128xi32>) {
153+
%c0 = arith.constant 0 : index
154+
%dim = tensor.dim %indices, %c0 : tensor<?x1xi32>
155+
%dim_ub = util.assume.int %dim[<umin = 1, umax = 12>] : index
156+
%empty = tensor.empty(%dim_ub) : tensor<?x128xi32>
157+
%result = iree_linalg_ext.gather dimension_map = [0]
158+
ins(%source, %indices : tensor<?x128xi32>, tensor<?x1xi32>)
159+
outs(%empty: tensor<?x128xi32>) -> tensor<?x128xi32>
160+
return %result : tensor<?x128xi32>
161+
}
162+
163+
// CHECK-LABEL: @linalg_ext_gather_masked
164+
// CHECK-SAME: %[[SOURCE:[a-zA-Z0-9]+]]
165+
// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]]
166+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
167+
// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
168+
// CHECK: %[[DIM:.+]] = tensor.dim %[[INDICES]], %[[C0]]
169+
// CHECK: %[[DIM_UB:.+]] = util.assume.int %[[DIM]]
170+
// CHECK: %[[INDICES_MASK:.+]] = vector.create_mask %[[DIM_UB]]
171+
// CHECK: vector.transfer_read %[[INDICES]]
172+
// CHECK-SAME: %[[INDICES_MASK]]
173+
// CHECK: %[[MASK:.+]] = vector.create_mask %[[DIM_UB]], %[[C128]]
174+
// CHECK: iree_vector_ext.transfer_gather %[[SOURCE]]
175+
// CHECK-SAME: %[[MASK]]

compiler/src/iree/compiler/Codegen/Utils/Utils.cpp

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,37 +1863,44 @@ std::optional<VectorizationTileSizes> inferSizesFromIR(linalg::UnPackOp op) {
18631863
return result;
18641864
}
18651865

1866+
std::optional<VectorizationTileSizes> static inferSizesFromMixedSizes(
1867+
SmallVector<OpFoldResult> shape) {
1868+
VectorizationTileSizes result;
1869+
for (OpFoldResult dim : shape) {
1870+
LLVM_DEBUG(llvm::dbgs() << "Dim #" << dim << ": ");
1871+
FailureOr<int64_t> maybeDimBound =
1872+
ValueBoundsConstraintSet::computeConstantBound(
1873+
presburger::BoundType::UB, dim,
1874+
/*stopCondition=*/nullptr, /*closedUB=*/true);
1875+
if (failed(maybeDimBound)) {
1876+
LLVM_DEBUG(llvm::dbgs() << "failed\n");
1877+
return std::nullopt;
1878+
}
1879+
1880+
LLVM_DEBUG(llvm::dbgs() << maybeDimBound.value() << "\n");
1881+
result.vectorSizes.push_back(maybeDimBound.value());
1882+
result.destShape.push_back(maybeDimBound.value());
1883+
}
1884+
return result;
1885+
}
1886+
18661887
std::optional<VectorizationTileSizes> inferSizesFromIR(Value val) {
18671888
if (!val.getDefiningOp())
18681889
return std::nullopt;
18691890

18701891
std::optional<VectorizationTileSizes> result;
1892+
LLVM_DEBUG(llvm::dbgs() << "Inferring sizes for:\n" << val << "\n");
18711893
TypeSwitch<Operation *, void>(val.getDefiningOp())
18721894
.Case<linalg::LinalgOp>(
18731895
[&](auto op) { result = inferSizesFromIR(op, cast<OpResult>(val)); })
18741896
.Case<linalg::PackOp>([&](auto op) { result = inferSizesFromIR(op); })
1875-
.Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp op) {
1897+
.Case<tensor::ExtractSliceOp, tensor::EmptyOp>([&](auto op) {
18761898
// tensor::ExtractSliceOp is not vectorizable, so only `destShape` has
18771899
// the values.
1878-
result = VectorizationTileSizes();
1879-
LLVM_DEBUG(llvm::dbgs() << "Inferring sizes for:\n" << op << "\n");
1880-
int64_t destRank = op.getResult().getType().getRank();
1881-
for (int dim = 0; dim < destRank; ++dim) {
1882-
LLVM_DEBUG(llvm::dbgs() << "Dim #" << dim << ": ");
1883-
FailureOr<int64_t> maybeDimBound =
1884-
ValueBoundsConstraintSet::computeConstantBound(
1885-
presburger::BoundType::UB, {op, dim},
1886-
/*stopCondition=*/nullptr, /*closedUB=*/true);
1887-
if (failed(maybeDimBound)) {
1888-
LLVM_DEBUG(llvm::dbgs() << "failed\n");
1889-
result = std::nullopt;
1890-
return;
1891-
}
1892-
LLVM_DEBUG(llvm::dbgs() << maybeDimBound.value() << "\n");
1893-
result->destShape.push_back(maybeDimBound.value());
1894-
}
1900+
result = inferSizesFromMixedSizes(op.getMixedSizes());
18951901
})
18961902
.Default([&](Operation *) {});
1903+
18971904
return result;
18981905
}
18991906

0 commit comments

Comments
 (0)