diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 696d1e0f9b1e6..dde6846b672ce 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5184,6 +5184,23 @@ std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
+/// Check if `indexVec` is a constant 1D vector of consecutive values
+/// [0, 1, 2, ...].
+static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
+  auto vecType = dyn_cast<VectorType>(indexVec.getType());
+  if (!vecType || vecType.getRank() != 1 || vecType.isScalable())
+    return failure();
+
+  if (indexVec.getDefiningOp<StepOp>())
+    return success();
+
+  DenseIntElementsAttr elements;
+  if (!matchPattern(indexVec, m_Constant(&elements)))
+    return failure();
+
+  return success(
+      llvm::equal(elements, llvm::seq<int64_t>(0, vecType.getNumElements())));
+}
+
 namespace {
 class GatherFolder final : public OpRewritePattern<GatherOp> {
 public:
@@ -5202,11 +5219,28 @@ class GatherFolder final : public OpRewritePattern<GatherOp> {
     llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
   }
 };
+
+/// Fold gathers with consecutive offsets [0, 1, 2, ...] into contiguous
+/// maskedload. Only 1D fixed vectors are supported for now.
+class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(),
+                                              op.getIndices(), op.getMask(),
+                                              op.getPassThru());
+    return success();
+  }
+};
 } // namespace
 
 void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<GatherFolder>(context);
+  results.add<GatherFolder, FoldContiguousGather>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5248,11 +5282,27 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
     llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
   }
 };
+
+/// Fold scatters with consecutive offsets [0, 1, 2, ...] into contiguous
+/// maskedstore. Only 1D fixed vectors are supported for now.
+class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(ScatterOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(isZeroBasedContiguousSeq(op.getIndexVec())))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<MaskedStoreOp>(
+        op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore());
+    return success();
+  }
+};
 } // namespace
 
 void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ScatterFolder>(context);
+  results.add<ScatterFolder, FoldContiguousScatter>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 89af0f7332f5c..0eebb6e8d612d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2838,3 +2838,144 @@ func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_s
   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
   return %1 : vector<1x1x2x1x1x1xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK: return %[[R]]
+func.func @contiguous_gather(%base: memref<?xf32>,
+                             %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather_non_zero_start(
+// TODO: Non-zero start is not supported yet.
+// CHECK: %[[R:.*]] = vector.gather
+// CHECK: return %[[R]]
+func.func @contiguous_gather_non_zero_start(%base: memref<?xf32>,
+                                            %mask: vector<16xi1>,
+                                            %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]> : vector<16xi32>
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather_2d(
+// TODO: Only 1D vectors are supported.
+// CHECK: %[[R:.*]] = vector.gather
+// CHECK: return %[[R]]
+func.func @contiguous_gather_2d(%base: memref<?x?xf32>,
+                                %mask: vector<4x4xi1>, %passthru: vector<4x4xf32>) -> vector<4x4xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : vector<4x4xi32>
+  %1 = vector.gather %base[%c0, %c0][%indices], %mask, %passthru :
+    memref<?x?xf32>, vector<4x4xi32>, vector<4x4xi1>, vector<4x4xf32> into vector<4x4xf32>
+  return %1 : vector<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather_const_mask
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[PASSTHRU:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[R:.*]] = vector.load %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
+// CHECK: return %[[R]]
+func.func @contiguous_gather_const_mask(%base: memref<?xf32>,
+                                        %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  %mask = arith.constant dense<true> : vector<16xi1>
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_gather_step
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[PASSTHRU:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[R:.*]] = vector.maskedload %[[BASE]][%[[C0]]], %[[MASK]], %[[PASSTHRU]] : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+// CHECK: return %[[R]]
+func.func @contiguous_gather_step(%base: memref<?xf32>,
+                                  %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = vector.step : vector<16xindex>
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+    memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @gather_broadcast(
+// TODO: Broadcast is not supported yet
+// CHECK: %[[R:.*]] = vector.gather
+// CHECK: return %[[R]]
+func.func @gather_broadcast(%base: memref<?xf32>,
+                            %mask: vector<16xi1>, %passthru: vector<16xf32>) -> vector<16xf32> {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<0> : vector<16xi32>
+  %1 = vector.gather %base[%c0][%indices], %mask, %passthru :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+  return %1 : vector<16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+func.func @contiguous_scatter(%base: memref<?xf32>,
+                              %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  vector.scatter %base[%c0][%indices], %mask, %value :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter_const_mask
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[VALUE:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: vector.store %[[VALUE]], %[[BASE]][%[[C0]]] : memref<?xf32>, vector<16xf32>
+func.func @contiguous_scatter_const_mask(%base: memref<?xf32>,
+                                         %value: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  %indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
+  %mask = vector.constant_mask [16] : vector<16xi1>
+  vector.scatter %base[%c0][%indices], %mask, %value :
+    memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_scatter_step
+// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: vector.maskedstore %[[BASE]][%[[C0]]], %[[MASK]], %[[VALUE]] : memref<?xf32>, vector<16xi1>, vector<16xf32>
+func.func @contiguous_scatter_step(%base: memref<?xf32>,
+                                   %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = arith.constant 0 : index
+  %indices = vector.step : vector<16xindex>
+  vector.scatter %base[%c0][%indices], %mask, %value :
+    memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
+  return
+}
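
Note (not part of the patch): because the new patterns are registered via the ops' getCanonicalizationPatterns hooks, they fire under the standard canonicalizer, which is how the tests above are run. The sketch below shows how a pass could also apply them directly; the applyContiguousFolds helper name and anchor-op choice are illustrative assumptions, while getCanonicalizationPatterns and applyPatternsAndFoldGreedily are existing MLIR APIs.

// Minimal usage sketch, assuming `root` is an op the greedy driver may
// rewrite under (e.g. a func.func). Helper name is hypothetical.
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult applyContiguousFolds(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  // Collects GatherFolder/FoldContiguousGather and
  // ScatterFolder/FoldContiguousScatter, per the results.add<...> calls above.
  vector::GatherOp::getCanonicalizationPatterns(patterns, ctx);
  vector::ScatterOp::getCanonicalizationPatterns(patterns, ctx);
  // Applies the patterns greedily until a fixpoint; -canonicalize does the
  // same across all registered ops.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}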