@@ -59,6 +59,37 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
5959 ArrayRef<bool > inputVecScalableFlags = {},
6060 bool flatten1DDepthwiseConv = false );
6161
62+ // / Vectorize tensor::InsertSliceOp with:
63+ // / * vector::TransferReadOp + vector::TransferWriteOp
64+ // / The vector sizes are either:
65+ // / * user-provided in `inputVectorSizes`, or
66+ // / * inferred from the static dims in the input and output tensors.
67+ // / Bails out if:
68+ // / * vector sizes are not user-provided, and
69+ // / * at least one dim is dynamic (in both the input and output tensors).
70+ // /
71+ // / Before:
72+ // / !t_in_type = tensor<1x2x3xf32>
73+ // / !t_out_type = tensor<9x8x7x1x2x3xf32>
74+ // / !v_type = vector<1x2x3xf32>
75+ // / %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
76+ // / into !t_out_type
77+ // / After:
78+ // / %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
79+ // / %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
80+ static LogicalResult
81+ vectorizeAsInsertSliceOp (RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
82+ ArrayRef<int64_t > inputVectorSizes,
83+ SmallVectorImpl<Value> &newResults);
84+
85+ // / Returns the effective Pad value for the input op, provided it's a scalar.
86+ // /
87+ // / Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
88+ // / this Op performs padding, retrieve the padding value provided that it's
89+ // / a scalar and static/fixed for all the padded values. Returns an empty value
90+ // / otherwise.
91+ static Value getStaticPadVal (Operation *op);
92+
6293// / Return the unique instance of OpType in `block` if it is indeed unique.
6394// / Return null if none or more than 1 instances exist.
6495template <typename OpType>
@@ -1557,6 +1588,7 @@ static LogicalResult
15571588vectorizeAsTensorPackOp (RewriterBase &rewriter, tensor::PackOp packOp,
15581589 ArrayRef<int64_t > inputVectorSizes,
15591590 SmallVectorImpl<Value> &newResults) {
1591+ // TODO: Introduce a parent class that will handle the insertion point update.
15601592 OpBuilder::InsertionGuard g (rewriter);
15611593 rewriter.setInsertionPoint (packOp);
15621594
@@ -1633,6 +1665,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
16331665 ArrayRef<int64_t > inputVectorSizes,
16341666 SmallVectorImpl<Value> &newResults) {
16351667
1668+ // TODO: Introduce a parent class that will handle the insertion point update.
16361669 OpBuilder::InsertionGuard g (rewriter);
16371670 rewriter.setInsertionPoint (unpackOp);
16381671
@@ -1763,7 +1796,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
17631796 auto padValue = padOp.getConstantPaddingValue ();
17641797 Location loc = padOp.getLoc ();
17651798
1766- // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1799+ // TODO: Introduce a parent class that will handle the insertion point update.
17671800 OpBuilder::InsertionGuard g (rewriter);
17681801 rewriter.setInsertionPoint (padOp);
17691802
@@ -1874,6 +1907,38 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
18741907 return success ();
18751908}
18761909
1910+ static LogicalResult
1911+ vectorizeInsertSliceOpPrecondition (tensor::InsertSliceOp sliceOp,
1912+ ArrayRef<int64_t > inputVectorSizes) {
1913+
1914+ TypedValue<RankedTensorType> source = sliceOp.getSource ();
1915+ auto sourceType = source.getType ();
1916+ if (!VectorType::isValidElementType (sourceType.getElementType ()))
1917+ return failure ();
1918+
1919+ // Get the pad value.
1920+ // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
1921+ // scalar padding value. Note that:
1922+ // * for in-bounds accesses,
1923+ // the value is actually irrelevant. There are 2 cases in which xfer.read
1924+ // accesses are known to be in-bounds:
1925+ // 1. The source shape is static (output vector sizes would be based on
1926+ // the source shape and hence all memory accesses would be in-bounds),
1927+ // 2. Masking is used, i.e. the output vector sizes are user-provided. In
1928+ // this case it is safe to assume that all memory accesses are in-bounds.
1929+ //
1930+ // When the value is not known and not needed, use 0. Otherwise, bail out.
1931+ Value padValue = getStaticPadVal (sliceOp);
1932+ bool isOutOfBoundsRead =
1933+ !sourceType.hasStaticShape () && inputVectorSizes.empty ();
1934+
1935+ if (!padValue && isOutOfBoundsRead) {
1936+ LDBG (" Failed to get a pad value for out-of-bounds read access\n " );
1937+ return failure ();
1938+ }
1939+ return success ();
1940+ }
1941+
18771942static LogicalResult vectorizeLinalgOpPrecondition (
18781943 LinalgOp linalgOp, ArrayRef<int64_t > inputVectorSizes,
18791944 bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -2144,6 +2209,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
21442209 .Case <tensor::UnPackOp>([&](auto unpackOp) {
21452210 return vectorizeUnPackOpPrecondition (unpackOp, inputVectorSizes);
21462211 })
2212+ .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
2213+ return vectorizeInsertSliceOpPrecondition (sliceOp, inputVectorSizes);
2214+ })
21472215 .Default ([](auto ) { return failure (); });
21482216}
21492217
@@ -2163,8 +2231,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
21632231}
21642232
21652233bool mlir::linalg::hasVectorizationImpl (Operation *op) {
2166- return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
2167- op);
2234+ return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
2235+ tensor::InsertSliceOp>( op);
21682236}
21692237
21702238// / Emit a suitable vector form for an operation. If provided,
@@ -2244,6 +2312,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
22442312 return vectorizeAsTensorPackOp (rewriter, packOp, inputVectorSizes,
22452313 results);
22462314 })
2315+ .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
2316+ return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
2317+ results);
2318+ })
22472319 .Case <tensor::UnPackOp>([&](auto unpackOp) {
22482320 return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
22492321 inputVectorSizes, results);
@@ -2540,6 +2612,9 @@ struct PadOpVectorizationWithTransferWritePattern
25402612// / this Op performs padding, retrieve the padding value provided that it's
25412613// / a scalar and static/fixed for all the padded values. Returns an empty value
25422614// / otherwise.
2615+ // /
2616+ // / TODO: This is used twice (when checking vectorization pre-conditions and
2617+ // / when vectorizing). Cache results instead of re-running.
25432618static Value getStaticPadVal (Operation *op) {
25442619 if (!op)
25452620 return {};
@@ -2583,113 +2658,118 @@ static Value getStaticPadVal(Operation *op) {
25832658 return {};
25842659}
25852660
2586- // / Rewrite tensor.insert.slice as a vector.transfer_read +
2587- // / vector.transfer_write pair. The vector size is inferred from the static
2588- // / dims in the input and output tensors. If a dim is dynamic in both the input
2589- // / and output tensors, bails out.
2590- // /
2591- // / Before:
2592- // / !t_in_type = tensor<1x2x3xf32>
2593- // / !t_out_type = tensor<9x8x7x1x2x3xf32>
2594- // / !v_type = vector<1x2x3xf32>
2595- // / %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2596- // / into !t_out_type
2597- // / After:
2598- // / %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2599- // / %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2600- // /
2601- // / TODO: Support masking
2602- struct InsertSliceVectorizePattern
2603- : public OpRewritePattern<tensor::InsertSliceOp> {
2604- using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2661+ static LogicalResult
2662+ vectorizeAsInsertSliceOp (RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
2663+ ArrayRef<int64_t > inputVectorSizes,
2664+ SmallVectorImpl<Value> &newResults) {
2665+ // TODO: Introduce a parent class that will handle the insertion point update.
2666+ OpBuilder::InsertionGuard g (rewriter);
2667+ rewriter.setInsertionPoint (sliceOp);
26052668
2606- LogicalResult matchAndRewrite (tensor::InsertSliceOp sliceOp,
2607- PatternRewriter &rewriter) const final {
2608- auto sourceType = sliceOp.getSource ().getType ();
2609- if (!VectorType::isValidElementType (sourceType.getElementType ()))
2610- return failure ();
2669+ TypedValue<RankedTensorType> source = sliceOp.getSource ();
2670+ auto sourceType = source.getType ();
2671+ auto resultType = sliceOp.getResultType ();
26112672
2612- auto resultType = sliceOp.getResultType ();
2613-
2614- // 1. Get the pad value.
2615- // TransferReadOp requires a scalar padding value. Note that:
2616- // * for in-bounds access, the value is actually irrelevant.
2617- // There are 2 cases in which xfer.read accesses are known to be in-bounds:
2618- // 1. The source shape is static (output vector sizes would be based on
2619- // the source shape and hence all memory accesses would be in-bounds),
2620- // 2. Masking is used (output vector sizes would be user-provided, in which
2621- // case it is assumed that all memory accesses are in-bounds). This
2622- // remains a TODO.
2623- //
2624- // When the value is not known and not needed, use 0. Otherwise, bail out.
2625- Value padValue = getStaticPadVal (sliceOp);
2626- bool isOutOfBoundsRead = !sourceType.hasStaticShape ();
2627-
2628- if (!padValue && isOutOfBoundsRead) {
2629- LDBG (" Failed to get a pad value for out-of-bounds read access\n " );
2673+ Value padValue = getStaticPadVal (sliceOp);
2674+
2675+ if (!padValue) {
2676+ auto elemType = sourceType.getElementType ();
2677+ padValue = rewriter.create <arith::ConstantOp>(
2678+ sliceOp.getLoc (), elemType, rewriter.getZeroAttr (elemType));
2679+ }
2680+
2681+ // 2. Get the vector shape and in-bounds attributes
2682+ SmallVector<int64_t > vecShape;
2683+ SmallVector<bool > readInBounds;
2684+ SmallVector<bool > writeInBounds;
2685+ size_t rankDiff = resultType.getRank () - sourceType.getRank ();
2686+ for (int64_t i = 0 , end = sourceType.getRank (); i < end; ++i) {
2687+ if (!inputVectorSizes.empty ()) {
2688+ vecShape.push_back (inputVectorSizes[i]);
2689+ readInBounds.push_back (false );
2690+ writeInBounds.push_back (false );
2691+ } else if (!sourceType.isDynamicDim (i)) {
2692+ vecShape.push_back (sourceType.getDimSize (i));
2693+ // Source shape is statically known: Neither read nor write are
2694+ // out-of-bounds.
2695+ readInBounds.push_back (true );
2696+ writeInBounds.push_back (true );
2697+ } else if (!resultType.isDynamicDim (i)) {
2698+ // Source shape is not statically known, but result shape is.
2699+ // Vectorize with size of result shape. This may be larger than the
2700+ // source size.
2701+ // FIXME: Using rankDiff implies that the source tensor is inserted at
2702+ // the end of the destination tensor. However, that's not required.
2703+ vecShape.push_back (resultType.getDimSize (rankDiff + i));
2704+ // Read may be out-of-bounds because the result size could be larger
2705+ // than the source size.
2706+ readInBounds.push_back (false );
2707+ // Write will be in-bounds provided that the corresponding write idx is 0.
2708+ // To keep this logic simple, conservatively mark as out-of-bounds.
2709+ writeInBounds.push_back (false );
2710+ } else {
2711+ // Neither source nor result dim of padOp is static. Cannot vectorize
2712+ // the copy.
2713+ // TODO: Add support for masking
26302714 return failure ();
26312715 }
2716+ }
2717+ auto vecType = VectorType::get (vecShape, sourceType.getElementType ());
26322718
2633- if (!padValue) {
2634- auto elemType = sourceType.getElementType ();
2635- padValue = rewriter.create <arith::ConstantOp>(
2636- sliceOp.getLoc (), elemType, rewriter.getZeroAttr (elemType));
2637- }
2719+ // 3. Generate TransferReadOp.
2720+ SmallVector<Value> readIndices (
2721+ vecType.getRank (),
2722+ rewriter.create <arith::ConstantIndexOp>(sliceOp.getLoc (), 0 ));
2723+ Operation *read = rewriter.create <vector::TransferReadOp>(
2724+ sliceOp.getLoc (), vecType, source, readIndices, padValue,
2725+ ArrayRef<bool >{readInBounds});
26382726
2639- // 2. Get the vector shape and in-bounds attributes
2640- SmallVector<int64_t > vecShape;
2641- SmallVector<bool > readInBounds;
2642- SmallVector<bool > writeInBounds;
2643- size_t rankDiff = resultType.getRank () - sourceType.getRank ();
2644- for (unsigned i = 0 ; i < sourceType.getRank (); ++i) {
2645- if (!sourceType.isDynamicDim (i)) {
2646- vecShape.push_back (sourceType.getDimSize (i));
2647- // Source shape is statically known: Neither read nor write are
2648- // out-of-bounds.
2649- readInBounds.push_back (true );
2650- writeInBounds.push_back (true );
2651- } else if (!resultType.isDynamicDim (i)) {
2652- // Source shape is not statically known, but result shape is.
2653- // Vectorize with size of result shape. This may be larger than the
2654- // source size.
2655- // FIXME: Using rankDiff implies that the source tensor is inserted at
2656- // the end of the destination tensor. However, that's not required.
2657- vecShape.push_back (resultType.getDimSize (rankDiff + i));
2658- // Read may be out-of-bounds because the result size could be larger
2659- // than the source size.
2660- readInBounds.push_back (false );
2661- // Write will in-bounds provided that the corresponding write idx is 0.
2662- // To keep this logic simple, conservatively mark as out-of-bounds.
2663- writeInBounds.push_back (false );
2664- } else {
2665- // Neither source nor result dim of padOp is static. Cannot vectorize
2666- // the copy.
2667- // TODO: Add support for masking
2668- return failure ();
2669- }
2727+ // If vector sizes are user provided, make sure to mask xfer_read.
2728+ if (!inputVectorSizes.empty ()) {
2729+ auto *srcDefOp = source.getDefiningOp ();
2730+ if (!srcDefOp) {
2731+ LDBG (" Unable to get the defining Op of " << sliceOp);
2732+ return failure ();
26702733 }
2671- auto vecType = VectorType::get (vecShape, sourceType.getElementType ());
26722734
2673- // 3. Generate TransferReadOp.
2674- SmallVector<Value> readIndices (
2675- vecType.getRank (),
2676- rewriter.create <arith::ConstantIndexOp>(sliceOp.getLoc (), 0 ));
2677- auto read = rewriter.create <vector::TransferReadOp>(
2678- sliceOp.getLoc (), vecType, sliceOp.getSource (), readIndices, padValue,
2679- ArrayRef<bool >{readInBounds});
2735+ ReifiedRankedShapedTypeDims reifiedSrcSizes;
2736+ LogicalResult status =
2737+ cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes (
2738+ rewriter, reifiedSrcSizes);
2739+ if (status.failed ()) {
2740+ LDBG (" Unable to reify result shapes of " << sliceOp);
2741+ return failure ();
2742+ }
26802743
2681- // 4. Generate TransferWriteOp.
2682- auto writeIndices = getValueOrCreateConstantIndexOp (
2683- rewriter, sliceOp.getLoc (), sliceOp.getMixedOffsets ());
2744+ // Create the mask
2745+ SmallVector<int64_t > readMaskShape (
2746+ sliceOp.getSource ().getType ().getShape ());
2747+ auto readMaskType = VectorType::get (inputVectorSizes, rewriter.getI1Type ());
2748+ Value maskOp = rewriter.create <vector::CreateMaskOp>(
2749+ sliceOp.getLoc (), readMaskType, reifiedSrcSizes[0 ]);
26842750
2685- // 5. Finalize
2686- rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
2687- sliceOp, read, sliceOp.getDest (), writeIndices,
2688- ArrayRef<bool >{writeInBounds});
2751+ // Mask the xfer_read Op
2752+ read = mlir::vector::maskOperation (rewriter, read, maskOp);
2753+ }
26892754
2690- return success ();
2755+ // 4. Generate TransferWriteOp.
2756+ if (!inputVectorSizes.empty () &&
2757+ ShapedType::isDynamicShape (resultType.getShape ())) {
2758+ LDBG (" TODO: Masking of xfer_write when vectorising " << sliceOp);
2759+ return failure ();
26912760 }
2692- };
2761+
2762+ auto writeIndices = getValueOrCreateConstantIndexOp (
2763+ rewriter, sliceOp.getLoc (), sliceOp.getMixedOffsets ());
2764+
2765+ // 5. Finalize
2766+ Operation *write = rewriter.create <vector::TransferWriteOp>(
2767+ sliceOp.getLoc (), read->getResult (0 ), sliceOp.getDest (), writeIndices,
2768+ ArrayRef<bool >{writeInBounds});
2769+ newResults.push_back (write->getResult (0 ));
2770+
2771+ return success ();
2772+ }
26932773
26942774// / Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
26952775// / ```
@@ -2778,11 +2858,6 @@ struct PadOpVectorizationWithInsertSlicePattern
27782858 }
27792859};
27802860
2781- void mlir::linalg::populateInsertSliceVectorizationPatterns (
2782- RewritePatternSet &patterns) {
2783- patterns.add <InsertSliceVectorizePattern>(patterns.getContext ());
2784- }
2785-
27862861void mlir::linalg::populatePadOpVectorizationPatterns (
27872862 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
27882863 patterns.add <PadOpVectorizationWithTransferReadPattern,
0 commit comments