Skip to content

Commit 9641898

Browse files
committed
[mlir][vector] Canonicalize gathers/scatters with trivial offsets
Canonicalize gathers/scatters with contiguous (i.e. [0, 1, 2, ...]) offsets into vector masked load/store ops.
1 parent 7e749d4 commit 9641898

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5184,6 +5184,19 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
51845184
return llvm::to_vector<4>(getVectorType().getShape());
51855185
}
51865186

5187+
/// Checks whether `val` is a constant index vector of the form
/// [0, 1, 2, ..., N-1].
///
/// Succeeds only when `val` has a fixed-length (non-scalable) rank-1 vector
/// type, is defined by a constant with a dense integer elements attribute, and
/// its elements compare equal, position by position, to the sequence starting
/// at 0 with one entry per vector element. Fails in every other case.
static LogicalResult isContiguousIndices(Value val) {
  auto indexVecType = dyn_cast<VectorType>(val.getType());
  if (!indexVecType)
    return failure();
  if (indexVecType.getRank() != 1 || indexVecType.isScalable())
    return failure();

  // Only constant index vectors can be inspected element by element.
  DenseIntElementsAttr constantIndices;
  if (!matchPattern(val, m_Constant(&constantIndices)))
    return failure();

  auto expectedSequence =
      llvm::seq<int64_t>(0, indexVecType.getNumElements());
  if (!llvm::equal(constantIndices, expectedSequence))
    return failure();
  return success();
}
5199+
51875200
namespace {
51885201
class GatherFolder final : public OpRewritePattern<GatherOp> {
51895202
public:
@@ -5202,11 +5215,26 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
52025215
llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
52035216
}
52045217
};
5218+
5219+
/// Canonicalizes a `vector.gather` whose index vector is the constant
/// sequence [0, 1, 2, ...] into a `vector.maskedload` that reuses the gather's
/// base, base indices, mask and pass-through operands.
///
/// The rewrite is restricted to memref bases: the replacement op is a
/// MaskedLoadOp, which is built directly on the gather's base operand.
class GatherTrivialIndices final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp op,
                                PatternRewriter &rewriter) const override {
    // vector.maskedload only accepts a memref base, whereas vector.gather may
    // also be given a ranked tensor; bail out instead of building an invalid
    // masked load.
    if (!isa<MemRefType>(op.getBase().getType()))
      return rewriter.notifyMatchFailure(op, "base must be of memref type");

    // Only contiguous, zero-based constant offsets make the gather equivalent
    // to a masked load starting at the base indices.
    if (failed(isContiguousIndices(op.getIndexVec())))
      return failure();

    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
                                              op.getIndices(), op.getMask(),
                                              op.getPassThru());
    return success();
  }
};
52055233
} // namespace
52065234

52075235
// Registers the canonicalization patterns for GatherOp into `results`.
// Diff hunk: the line marked `-` is the previous registration (GatherFolder
// only); the line marked `+` replaces it and additionally registers
// GatherTrivialIndices.
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
52085236
MLIRContext *context) {
5209-
results.add<GatherFolder>(context);
5237+
results.add<GatherFolder, GatherTrivialIndices>(context);
52105238
}
52115239

52125240
//===----------------------------------------------------------------------===//
@@ -5248,11 +5276,25 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
52485276
llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
52495277
}
52505278
};
5279+
5280+
/// Canonicalizes a `vector.scatter` whose index vector is the constant
/// sequence [0, 1, 2, ...] into a `vector.maskedstore` that reuses the
/// scatter's base, base indices, mask and stored value.
class ScatterTrivialIndices final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ScatterOp op,
                                PatternRewriter &rewriter) const override {
    // Contiguous, zero-based constant offsets make the scatter equivalent to
    // a masked store starting at the base indices.
    if (succeeded(isContiguousIndices(op.getIndexVec()))) {
      auto base = op.getBase();
      auto baseIndices = op.getIndices();
      auto mask = op.getMask();
      auto storedValue = op.getValueToStore();
      rewriter.replaceOpWithNewOp<MaskedStoreOp>(op, base, baseIndices, mask,
                                                 storedValue);
      return success();
    }
    return failure();
  }
};
52515293
} // namespace
52525294

52535295
// Registers the canonicalization patterns for ScatterOp into `results`.
// Diff hunk: the line marked `-` is the previous registration (ScatterFolder
// only); the line marked `+` replaces it and additionally registers
// ScatterTrivialIndices.
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
52545296
MLIRContext *context) {
5255-
results.add<ScatterFolder>(context);
5297+
results.add<ScatterFolder, ScatterTrivialIndices>(context);
52565298
}
52575299

52585300
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2838,3 +2838,34 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s
28382838
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
28392839
return %1 : vector<1x1x2x1x1x1xi32>
28402840
}
2841+
2842+
// -----
2843+
2844+
// CHECK-LABEL: @contiguous_gather
2845+
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
2846+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
2847+
// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
2848+
// CHECK: return %[[R]]
2849+
// Test input: a gather from %base at index %c0 whose index vector is the
// constant [0, 1, ..., 15]. The FileCheck directives above expect the
// canonicalizer to turn it into a vector.maskedload with the same base,
// index, mask and pass-through.
func.func @contiguous_gather(%base: memref<?xf32>,
2850+
%mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
2851+
%c0 = arith.constant 0 : index
2852+
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
2853+
%1 = vector.gather %base[%c0][%indices], %mask, %passthru :
2854+
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
2855+
return %1 : vector<16xf32>
2856+
}
2857+
2858+
// -----
2859+
2860+
// CHECK-LABEL: @contiguous_scatter
2861+
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
2862+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
2863+
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
2864+
// Test input: a scatter into %base at index %c0 whose index vector is the
// constant [0, 1, ..., 15]. The FileCheck directives above expect the
// canonicalizer to turn it into a vector.maskedstore with the same base,
// index, mask and stored value.
func.func @contiguous_scatter(%base: memref<?xf32>,
2865+
%mask: vector<16xi1>, %value: vector<16xf32>){
2866+
%c0 = arith.constant 0 : index
2867+
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
2868+
vector.scatter %base[%c0][%indices], %mask, %value :
2869+
memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
2870+
return
2871+
}

0 commit comments

Comments
 (0)