@@ -298,16 +298,156 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
   }
 };
 
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type ("LLVM-legal" type). This is done by
+/// collapsing trailing dimensions so we obtain a vector type with a single
+/// scalable dimension in the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a
+    // transfer to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // either from the mask operand or from the original `vector.create_mask`
+    // operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not-supported");
+
+    // General permutation maps are not supported. The issue is with
+    // transpose, broadcast, and other forms of non-identity mappings in the
+    // minor dimensions, which are impossible to represent after collapsing
+    // (at least because the resulting "collapsed" maps would have a smaller
+    // number of dimension indices).
+    // TODO: We have not yet had the need for it, but some forms of
+    // permutation maps with identity in the minor dimensions could be
+    // supported, for example `(i, j, k, p) -> (j, i, k, p)` where we need to
+    // collapse only `k` and `p`.
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension. This transformation aims to transform an LLVM-illegal type
+    // into an LLVM-legal one, and one-dimensional vectors are already
+    // LLVM-legal, even if scalable. A value of a vector type with more than
+    // one scalable dimension is impossible to represent using a vector type
+    // with no scalable dimensions or a single one. For example a
+    // `vector<[4]x[4]xi8>` would have `4 * 4 * vscale * vscale` elements and
+    // this quantity is impossible to represent as `N` or `N * vscale` (where
+    // `N` is a constant).
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || origVT.getNumScalableDims() != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension. Nothing to do if the single scalable dimension is already
+    // the last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
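+    // E.g. for the `vector<2x[4]x2x8xi8>` example above, the scalable
+    // dimension is at index 1 out of 4, so `numCollapseDims` is 3.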
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The dimensions to collapse (excluding the scalable one) of the vector
+    // and the memref must match. A dynamic memref dimension is considered
+    // non-matching. The transfers from the dimensions to collapse must be
+    // in-bounds (it follows that the corresponding indices would be zero).
+    // This guarantees that the operation transfers a contiguous block and
+    // that no padding is necessary.
+    if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                     origVT.getShape().take_back(numCollapseDims - 1)))
+      return rewriter.notifyMatchFailure(
+          readOp, "memref and vector dimensions do not match");
+
+    SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+    if (!llvm::all_of(
+            ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+            [](bool v) { return v; }))
+      return rewriter.notifyMatchFailure(
+          readOp, "out-of-bounds transfer from a dimension to collapse");
+
+    // Collapse the trailing dimensions of the memref.
+    SmallVector<ReassociationIndices> reassoc;
+    for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+      reassoc.push_back({i});
+    for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
+         ++i)
+      reassoc.back().push_back(i);
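+    // E.g. for the `memref<?x?x2x8xi8>` example above this produces the
+    // reassociation [[0], [1, 2, 3]] used by the `memref.collapse_shape`
+    // shown in the documentation comment.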
+    if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
+      return failure();
+    Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
+        readOp.getLoc(), readOp.getBase(), reassoc);
+
+    // Get a vector type with collapsed trailing dimensions.
+    SmallVector<int64_t> shape(origVT.getShape());
+    for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
+      shape[origVRank - numCollapseDims] *= shape[i];
+    shape.pop_back_n(numCollapseDims - 1);
+    auto collapsedVT =
+        VectorType::get(shape, origVT.getElementType(),
+                        origScalableDims.drop_back(numCollapseDims - 1));
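+    // E.g. `vector<2x[4]x2x8xi8>` collapses to `vector<2x[64]xi8>`
+    // (4 * 2 * 8 == 64), keeping the scalable flag on the last dimension.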
+
+    // Drop the extra (zero) indices.
+    auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
+
+    // Create the new `transfer_read`.
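+    // For the example above it reads `vector<2x[64]xi8>` from the collapsed
+    // `memref<?x?xi8>` at indices `[%i, %j]`.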
+    auto newReadOp = rewriter.create<vector::TransferReadOp>(
+        readOp.getLoc(), collapsedVT, collapsedMem, indices,
+        ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
+
+    // Cast back to the original vector type.
+    auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
+                                                            origVT, newReadOp);
+
+    rewriter.replaceOp(readOp, toOrigShape);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<RelaxScalableVectorAllocaAlignment,
-               LegalizeSVEMaskAllocation<memref::AllocaOp>,
-               LegalizeSVEMaskAllocation<memref::AllocOp>,
-               LegalizeSVEMaskTypeCastConversion,
-               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
-      patterns.getContext());
+  patterns
+      .add<RelaxScalableVectorAllocaAlignment,
+           LegalizeSVEMaskAllocation<memref::AllocaOp>,
+           LegalizeSVEMaskAllocation<memref::AllocOp>,
+           LegalizeSVEMaskTypeCastConversion, LegalizeSVEMaskStoreConversion,
+           LegalizeSVEMaskLoadConversion, LegalizeTransferRead>(
+          patterns.getContext());
 }
 
 namespace {