Skip to content

Commit 88136f9

Browse files
authored
[mlir][vector] Canonicalize gathers/scatters with trivial offsets (llvm#117939)
Canonicalize gathers/scatters with contiguous (i.e. [0, 1, 2, ...]) offsets into vector masked load/store ops.
1 parent 0ee037b commit 88136f9

File tree

2 files changed

+193
-2
lines changed

2 files changed

+193
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5184,6 +5184,23 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
51845184
return llvm::to_vector<4>(getVectorType().getShape());
51855185
}
51865186

5187+
/// Cheeck if `indexVec` is constant 1D vec of consecutive values [0, 1, 2, ...]
5188+
static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
5189+
auto vecType = dyn_cast<VectorType>(indexVec.getType());
5190+
if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
5191+
return failure();
5192+
5193+
if (indexVec.getDefiningOp<StepOp>())
5194+
return success();
5195+
5196+
DenseIntElementsAttr elements;
5197+
if (!matchPattern(indexVec, m_Constant(&elements)))
5198+
return failure();
5199+
5200+
return success(
5201+
llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
5202+
}
5203+
51875204
namespace {
51885205
class GatherFolder final : public OpRewritePattern<GatherOp> {
51895206
public:
@@ -5202,11 +5219,28 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
52025219
llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
52035220
}
52045221
};
5222+
5223+
/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
5224+
/// maskedload. Only 1D fixed vectors are supported for now.
5225+
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
5226+
public:
5227+
using OpRewritePattern::OpRewritePattern;
5228+
LogicalResult matchAndRewrite(GatherOp op,
5229+
PatternRewriter &rewriter) const override {
5230+
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5231+
return failure();
5232+
5233+
rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
5234+
op.getIndices(), op.getMask(),
5235+
op.getPassThru());
5236+
return success();
5237+
}
5238+
};
52055239
} // namespace
52065240

52075241
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // GatherFolder handles degenerate (all-true / all-false) 1-D mask formats;
  // FoldContiguousGather rewrites gathers whose index vector is a zero-based
  // iota into contiguous vector.maskedload ops.
  results.add<GatherFolder, FoldContiguousGather>(context);
}
52115245

52125246
//===----------------------------------------------------------------------===//
@@ -5248,11 +5282,27 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
52485282
llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
52495283
}
52505284
};
5285+
5286+
/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
5287+
/// maskedstore. Only 1D fixed vectors are supported for now.
5288+
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
5289+
public:
5290+
using OpRewritePattern::OpRewritePattern;
5291+
LogicalResult matchAndRewrite(ScatterOp op,
5292+
PatternRewriter &rewriter) const override {
5293+
if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
5294+
return failure();
5295+
5296+
rewriter.replaceOpWithNewOp<MaskedStoreOp>(
5297+
op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
5298+
return success();
5299+
}
5300+
};
52515301
} // namespace
52525302

52535303
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // ScatterFolder handles degenerate (all-true / all-false) 1-D mask formats;
  // FoldContiguousScatter rewrites scatters whose index vector is a zero-based
  // iota into contiguous vector.maskedstore ops.
  results.add<ScatterFolder, FoldContiguousScatter>(context);
}
52575307

52585308
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2838,3 +2838,144 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s
28382838
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
28392839
return %1 : vector<1x1x2x1x1x1xi32>
28402840
}
2841+
2842+
// -----

// A gather whose index vector is the dense constant [0, 1, ..., 15]
// canonicalizes to a contiguous vector.maskedload.
// CHECK-LABEL: @contiguous_gather
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: return %[[R]]
func.func @contiguous_gather(%base: memref<?xf32>,
                             %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}
2857+
2858+
// -----

// Negative test: the offsets are consecutive but start at 1, not 0, so the
// gather must not be folded.
// CHECK-LABEL: @contiguous_gather_non_zero_start(
// TODO: Non-zero start is not supported yet.
// CHECK: %[[R:.*]] = vector.gather
// CHECK: return %[[R]]
func.func @contiguous_gather_non_zero_start(%base: memref<?xf32>,
                                            %mask: vector<16xi1>,
                                            %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<16xi32>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}
2873+
2874+
// -----

// Negative test: the index vector is rank 2; only 1-D index vectors are
// recognized as contiguous sequences.
// CHECK-LABEL: @contiguous_gather_2d(
// TODO: Only 1D vectors are supported.
// CHECK: %[[R:.*]] = vector.gather
// CHECK: return %[[R]]
func.func @contiguous_gather_2d(%base: memref<?x?xf32>,
                                %mask: vector<4x4xi1>, %passthru: vector<4x4xf32>) -> vector<4x4xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : vector<4x4xi32>
  %1 = vector.gather %base[%c0, %c0][%indices], %mask, %passthru :
    memref<?x?xf32>, vector<4x4xi32>, vector<4x4xi1>, vector<4x4xf32> into vector<4x4xf32>
  return %1 : vector<4x4xf32>
}
2888+
2889+
// -----

// With an all-true constant mask the gather folds all the way to a plain
// vector.load (the intermediate maskedload is folded further).
// CHECK-LABEL: @contiguous_gather_const_mask
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[PASSTHRU:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.load %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
// CHECK: return %[[R]]
func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
                                        %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
  %mask = arith.constant dense<true> : vector<16xi1>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}
2905+
2906+
// -----

// vector.step produces a zero-based iota, so the gather folds to a masked
// load without needing a dense constant.
// CHECK-LABEL: @contiguous_gather_step
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
// CHECK: return %[[R]]
func.func @contiguous_gather_step(%base: memref<?xf32>,
                                  %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = vector.step : vector<16xindex>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}
2921+
2922+
// -----

// Negative test: a splat (all-zero) index vector is not a consecutive
// sequence, so the gather is left untouched.
// CHECK-LABEL: @gather_broadcast(
// TODO: Broadcast is not supported yet
// CHECK: %[[R:.*]] = vector.gather
// CHECK: return %[[R]]
func.func @gather_broadcast(%base: memref<?xf32>,
                            %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<0> : vector<16xi32>
  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %1 : vector<16xf32>
}
2936+
2937+
// -----

// Scatter counterpart of @contiguous_gather: consecutive zero-based offsets
// fold to a contiguous vector.maskedstore.
// CHECK-LABEL: @contiguous_scatter
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
func.func @contiguous_scatter(%base: memref<?xf32>,
                              %mask: vector<16xi1>, %value: vector<16xf32>) {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
  vector.scatter %base[%c0][%indices], %mask, %value :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
  return
}
2951+
2952+
// -----

// An all-true mask lets the scatter fold further, to a plain vector.store.
// CHECK-LABEL: @contiguous_scatter_const_mask
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.store %[[VALUE]], %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
                                         %value: vector<16xf32>) {
  %c0 = arith.constant 0 : index
  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
  %mask = vector.constant_mask [16] : vector<16xi1>
  vector.scatter %base[%c0][%indices], %mask, %value :
    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
  return
}
2967+
2968+
// -----

// vector.step indices (a zero-based iota) fold the scatter to a masked store.
// CHECK-LABEL: @contiguous_scatter_step
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
func.func @contiguous_scatter_step(%base: memref<?xf32>,
                                   %mask: vector<16xi1>, %value: vector<16xf32>) {
  %c0 = arith.constant 0 : index
  %indices = vector.step : vector<16xindex>
  vector.scatter %base[%c0][%indices], %mask, %value :
    memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
  return
}

0 commit comments

Comments
 (0)