@@ -2281,115 +2281,6 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
22812281// ----------------------------------------------------------------------------//
22822282// Misc. vectorization patterns.
22832283// ----------------------------------------------------------------------------//
2284-
2285- // / Helper function that retrieves the value of an IntegerAttr.
2286- static int64_t getIntFromAttr (Attribute attr) {
2287- return cast<IntegerAttr>(attr).getInt ();
2288- }
2289-
2290- // / Given an ArrayRef of OpFoldResults, return a vector of Values.
2291- // / IntegerAttrs are converted to ConstantIndexOps. Other attribute types are
2292- // / not supported.
2293- static SmallVector<Value> ofrToIndexValues (RewriterBase &rewriter, Location loc,
2294- ArrayRef<OpFoldResult> ofrs) {
2295- SmallVector<Value> result;
2296- for (auto o : ofrs) {
2297- if (auto val = llvm::dyn_cast_if_present<Value>(o)) {
2298- result.push_back (val);
2299- } else {
2300- result.push_back (rewriter.create <arith::ConstantIndexOp>(
2301- loc, getIntFromAttr (o.template get <Attribute>())));
2302- }
2303- }
2304- return result;
2305- }
2306-
2307- // / Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and
2308- // / InsertSliceOp. For now, only constant padding values are supported.
2309- // / If there is enough static type information, TransferReadOps and
2310- // / TransferWriteOps may be generated instead of InsertSliceOps.
2311- struct GenericPadOpVectorizationPattern : public GeneralizePadOpPattern {
2312- GenericPadOpVectorizationPattern (MLIRContext *context,
2313- PatternBenefit benefit = 1 )
2314- : GeneralizePadOpPattern(context, tryVectorizeCopy, benefit) {}
2315- // / Vectorize the copying of a tensor::PadOp's source. This is possible if
2316- // / each dimension size is statically know in the source type or the result
2317- // / type (or both).
2318- static LogicalResult tryVectorizeCopy (RewriterBase &rewriter,
2319- tensor::PadOp padOp, Value dest) {
2320- auto sourceType = padOp.getSourceType ();
2321- auto resultType = padOp.getResultType ();
2322- if (!VectorType::isValidElementType (sourceType.getElementType ()))
2323- return failure ();
2324-
2325- // Copy cannot be vectorized if pad value is non-constant and source shape
2326- // is dynamic. In case of a dynamic source shape, padding must be appended
2327- // by TransferReadOp, but TransferReadOp supports only constant padding.
2328- auto padValue = padOp.getConstantPaddingValue ();
2329- if (!padValue) {
2330- if (!sourceType.hasStaticShape ())
2331- return failure ();
2332- // Create dummy padding value.
2333- auto elemType = sourceType.getElementType ();
2334- padValue = rewriter.create <arith::ConstantOp>(
2335- padOp.getLoc (), elemType, rewriter.getZeroAttr (elemType));
2336- }
2337-
2338- SmallVector<int64_t > vecShape;
2339- SmallVector<bool > readInBounds;
2340- SmallVector<bool > writeInBounds;
2341- for (unsigned i = 0 ; i < sourceType.getRank (); ++i) {
2342- if (!sourceType.isDynamicDim (i)) {
2343- vecShape.push_back (sourceType.getDimSize (i));
2344- // Source shape is statically known: Neither read nor write are
2345- // out-of- bounds.
2346- readInBounds.push_back (true );
2347- writeInBounds.push_back (true );
2348- } else if (!resultType.isDynamicDim (i)) {
2349- // Source shape is not statically known, but result shape is.
2350- // Vectorize with size of result shape. This may be larger than the
2351- // source size.
2352- vecShape.push_back (resultType.getDimSize (i));
2353- // Read may be out-of-bounds because the result size could be larger
2354- // than the source size.
2355- readInBounds.push_back (false );
2356- // Write is out-of-bounds if low padding > 0.
2357- writeInBounds.push_back (
2358- getConstantIntValue (padOp.getMixedLowPad ()[i]) ==
2359- static_cast <int64_t >(0 ));
2360- } else {
2361- // Neither source nor result dim of padOp is static. Cannot vectorize
2362- // the copy.
2363- return failure ();
2364- }
2365- }
2366- auto vecType = VectorType::get (vecShape, sourceType.getElementType ());
2367-
2368- // Generate TransferReadOp.
2369- SmallVector<Value> readIndices (
2370- vecType.getRank (),
2371- rewriter.create <arith::ConstantIndexOp>(padOp.getLoc (), 0 ));
2372- auto read = rewriter.create <vector::TransferReadOp>(
2373- padOp.getLoc (), vecType, padOp.getSource (), readIndices, padValue,
2374- ArrayRef<bool >{readInBounds});
2375-
2376- // If `dest` is a FillOp and the TransferWriteOp would overwrite the
2377- // entire tensor, write directly to the FillOp's operand.
2378- if (llvm::equal (vecShape, resultType.getShape ()) &&
2379- llvm::all_of (writeInBounds, [](bool b) { return b; }))
2380- if (auto fill = dest.getDefiningOp <FillOp>())
2381- dest = fill.output ();
2382-
2383- // Generate TransferWriteOp.
2384- auto writeIndices =
2385- ofrToIndexValues (rewriter, padOp.getLoc (), padOp.getMixedLowPad ());
2386- rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
2387- padOp, read, dest, writeIndices, ArrayRef<bool >{writeInBounds});
2388-
2389- return success ();
2390- }
2391- };
2392-
23932284// / Base pattern for rewriting tensor::PadOps whose result is consumed by a
23942285// / given operation type OpTy.
23952286template <typename OpTy>
@@ -2623,6 +2514,163 @@ struct PadOpVectorizationWithTransferWritePattern
26232514 }
26242515};
26252516
2517+ // / Returns the effective Pad value for the input op, provided it's a scalar.
2518+ // /
2519+ // / Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
2520+ // / this Op performs padding, retrieve the padding value provided that it's
2521+ // / a scalar and static/fixed for all the padded values. Returns an empty value
2522+ // / otherwise.
2523+ static Value getStaticPadVal (Operation *op) {
2524+ if (!op)
2525+ return {};
2526+
2527+ // 1. vector.broadcast (f32 -> vector <...xf32>) - return the value that's
2528+ // being broadcast, provided that it's a scalar.
2529+ if (auto bcast = llvm::dyn_cast<vector::BroadcastOp>(op)) {
2530+ auto source = bcast.getSource ();
2531+ if (llvm::dyn_cast<VectorType>(source.getType ()))
2532+ return {};
2533+
2534+ return source;
2535+ }
2536+
2537+ // 2. linalg.fill - use the scalar input value that used to fill the output
2538+ // tensor.
2539+ if (auto fill = llvm::dyn_cast<linalg::FillOp>(op)) {
2540+ return fill.getInputs ()[0 ];
2541+ }
2542+
2543+ // 3. tensor.generateOp - can't guarantee the value is fixed without
2544+ // analysing, bail out.
2545+ if (auto generate = llvm::dyn_cast<tensor::GenerateOp>(op)) {
2546+ return {};
2547+ }
2548+
2549+ // 4. vector.transfer_write - inspect the input vector that's written from. If
2550+ // if contains a single value that has been broadcast (e.g. via
2551+ // vector.broadcast), extract it, fail otherwise.
2552+ if (auto xferWrite = llvm::dyn_cast<vector::TransferWriteOp>(op))
2553+ return getStaticPadVal (xferWrite.getVector ().getDefiningOp ());
2554+
2555+ // 5. tensor.insert_slice - inspect the destination tensor. If it's larger
2556+ // than the input tensor, then, provided it's constant, we'll extract the
2557+ // value that was used to generate it (via e.g. linalg.fill), fail otherwise.
2558+ // TODO: Clarify the semantics when the input tensor is larger than the
2559+ // destination.
2560+ if (auto slice = llvm::dyn_cast<tensor::InsertSliceOp>(op))
2561+ return getStaticPadVal (slice.getDest ().getDefiningOp ());
2562+
2563+ return {};
2564+ }
2565+
2566+ // / Rewrite tensor.insert.slice as a vector.transfer_read +
2567+ // / vector.transfer_write pair. The vector size is inferred from the static
2568+ // / dims in the input and output tensors. If a dim is dynamic in both the input
2569+ // / and output tensors, bails out.
2570+ // /
2571+ // / Before:
2572+ // / !t_in_type = tensor<1x2x3xf32>
2573+ // / !t_out_type = tensor<9x8x7x1x2x3xf32>
2574+ // / !v_type = vector<1x2x3xf32>
2575+ // / %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2576+ // / into !t_out_type
2577+ // / After:
2578+ // / %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2579+ // / %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2580+ // /
2581+ // / TODO: Support masking
2582+ struct InsertSliceVectorizePattern
2583+ : public OpRewritePattern<tensor::InsertSliceOp> {
2584+ using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2585+
2586+ LogicalResult matchAndRewrite (tensor::InsertSliceOp sliceOp,
2587+ PatternRewriter &rewriter) const final {
2588+ auto sourceType = sliceOp.getSource ().getType ();
2589+ if (!VectorType::isValidElementType (sourceType.getElementType ()))
2590+ return failure ();
2591+
2592+ auto resultType = sliceOp.getResultType ();
2593+
2594+ // 1. Get the pad value.
2595+ // TransferReadOp requires a scalar padding value. Note that:
2596+ // * for in-bounds access, the value is actually irrelevant.
2597+ // There are 2 cases in which xfer.read accesses are known to be in-bounds:
2598+ // 1. The source shape is static (output vector sizes would be based on
2599+ // the source shape and hence all memory accesses would be in-bounds),
2600+ // 2. Masking is used (output vector sizes would be user-provided, in which
2601+ // case it is assumed that all memory accesses are in-bounds). This
2602+ // remains a TODO.
2603+ //
2604+ // When the value is not known and not needed, use 0. Otherwise, bail out.
2605+ Value padValue = getStaticPadVal (sliceOp);
2606+ bool isOutOfBoundsRead = !sourceType.hasStaticShape ();
2607+
2608+ if (!padValue && isOutOfBoundsRead) {
2609+ LDBG (" Failed to get a pad value for out-of-bounds read access\n " );
2610+ return failure ();
2611+ }
2612+
2613+ if (!padValue) {
2614+ auto elemType = sourceType.getElementType ();
2615+ padValue = rewriter.create <arith::ConstantOp>(
2616+ sliceOp.getLoc (), elemType, rewriter.getZeroAttr (elemType));
2617+ }
2618+
2619+ // 2. Get the vector shape and in-bounds attributes
2620+ SmallVector<int64_t > vecShape;
2621+ SmallVector<bool > readInBounds;
2622+ SmallVector<bool > writeInBounds;
2623+ size_t rankDiff = resultType.getRank () - sourceType.getRank ();
2624+ for (unsigned i = 0 ; i < sourceType.getRank (); ++i) {
2625+ if (!sourceType.isDynamicDim (i)) {
2626+ vecShape.push_back (sourceType.getDimSize (i));
2627+ // Source shape is statically known: Neither read nor write are
2628+ // out-of-bounds.
2629+ readInBounds.push_back (true );
2630+ writeInBounds.push_back (true );
2631+ } else if (!resultType.isDynamicDim (i)) {
2632+ // Source shape is not statically known, but result shape is.
2633+ // Vectorize with size of result shape. This may be larger than the
2634+ // source size.
2635+ // FIXME: Using rankDiff implies that the source tensor is inserted at
2636+ // the end of the destination tensor. However, that's not required.
2637+ vecShape.push_back (resultType.getDimSize (rankDiff + i));
2638+ // Read may be out-of-bounds because the result size could be larger
2639+ // than the source size.
2640+ readInBounds.push_back (false );
2641+ // Write will in-bounds provided that the corresponding write idx is 0.
2642+ // To keep this logic simple, conservatively mark as out-of-bounds.
2643+ writeInBounds.push_back (false );
2644+ } else {
2645+ // Neither source nor result dim of padOp is static. Cannot vectorize
2646+ // the copy.
2647+ // TODO: Add support for masking
2648+ return failure ();
2649+ }
2650+ }
2651+ auto vecType = VectorType::get (vecShape, sourceType.getElementType ());
2652+
2653+ // 3. Generate TransferReadOp.
2654+ SmallVector<Value> readIndices (
2655+ vecType.getRank (),
2656+ rewriter.create <arith::ConstantIndexOp>(sliceOp.getLoc (), 0 ));
2657+ auto read = rewriter.create <vector::TransferReadOp>(
2658+ sliceOp.getLoc (), vecType, sliceOp.getSource (), readIndices, padValue,
2659+ ArrayRef<bool >{readInBounds});
2660+
2661+ // 4. Generate TransferWriteOp.
2662+ auto writeIndices = getValueOrCreateConstantIndexOp (
2663+ rewriter, sliceOp.getLoc (), sliceOp.getMixedOffsets ());
2664+
2665+ // 5. Finalize
2666+ rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
2667+ sliceOp, read, sliceOp.getDest (), writeIndices,
2668+ ArrayRef<bool >{writeInBounds});
2669+
2670+ return success ();
2671+ }
2672+ };
2673+
26262674// / Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
26272675// / ```
26282676// / %0 = tensor.pad %src ... : tensor<?x?xf32> to tensor<17x5xf32>
@@ -2699,8 +2747,8 @@ struct PadOpVectorizationWithInsertSlicePattern
26992747 // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at
27002748 // specified offsets. Write is fully in-bounds because a InsertSliceOp's
27012749 // source must fit into the destination at the specified offsets.
2702- auto writeIndices =
2703- ofrToIndexValues ( rewriter, padOp.getLoc (), insertOp.getMixedOffsets ());
2750+ auto writeIndices = getValueOrCreateConstantIndexOp (
2751+ rewriter, padOp.getLoc (), insertOp.getMixedOffsets ());
27042752 SmallVector<bool > inBounds (vecRank, true );
27052753 rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
27062754 insertOp, read, insertOp.getDest (), writeIndices,
@@ -2710,13 +2758,18 @@ struct PadOpVectorizationWithInsertSlicePattern
27102758 }
27112759};
27122760
2761+ void mlir::linalg::populateInsertSliceVectorizationPatterns (
2762+ RewritePatternSet &patterns) {
2763+ patterns.add <InsertSliceVectorizePattern>(patterns.getContext ());
2764+ }
2765+
27132766void mlir::linalg::populatePadOpVectorizationPatterns (
27142767 RewritePatternSet &patterns, PatternBenefit baseBenefit) {
27152768 // TODO: The following pattern implements "decomposition" and
27162769 // optional "vectorization". Seperate "decomposition" into a sepereate
27172770 // pre-processing pattern group.
2718- patterns.add <GenericPadOpVectorizationPattern >(patterns.getContext (),
2719- baseBenefit);
2771+ patterns.add <GeneralizePadOpPattern >(patterns.getContext (), baseBenefit);
2772+
27202773 // Try these specialized patterns first before resorting to the generic one.
27212774 patterns.add <PadOpVectorizationWithTransferReadPattern,
27222775 PadOpVectorizationWithTransferWritePattern,
0 commit comments