@@ -358,6 +358,63 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
       .getResult();
 }
 
+// Collapses the shape of an n-D memref to the target rank while applying
+// offsets for the collapsed dimensions. Returns the new memref value and the
+// remaining offsets for the last targetRank dimensions. For example:
+//   input:  %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
+//           targetRank=2
+//   output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>,
+//           returned offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+                                    Value memref,
+                                    SmallVector<OpFoldResult> offsets,
+                                    int64_t targetRank) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  unsigned rank = memrefType.getRank();
+
+  if (rank <= targetRank)
+    return {memref, offsets};
+
+  int64_t numCombinedDims = rank - targetRank;
+  SmallVector<OpFoldResult> subviewOffsets;
+  SmallVector<OpFoldResult> subviewSizes;
+  SmallVector<OpFoldResult> subviewStrides;
+
+  // For the combined dimensions: use the provided offsets, size=1, stride=1.
+  for (unsigned i = 0; i < numCombinedDims; ++i) {
+    subviewOffsets.push_back(offsets[i]);
+    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  // For the last targetRank dimensions: offset=0, full size, stride=1.
+  SmallVector<int64_t> resultShape;
+  auto originalShape = memrefType.getShape();
+  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+  for (unsigned i = numCombinedDims; i < rank; ++i) {
+    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+    if (ShapedType::isDynamic(originalShape[i])) {
+      subviewSizes.push_back(meta.getSizes()[i]);
+      resultShape.push_back(ShapedType::kDynamic);
+    } else {
+      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+      resultShape.push_back(originalShape[i]);
+    }
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  auto resultType = memref::SubViewOp::inferRankReducedResultType(
+      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+  auto subviewOp =
+      memref::SubViewOp::create(rewriter, loc, resultType, memref,
+                                subviewOffsets, subviewSizes, subviewStrides);
+
+  // Return the remaining offsets for the last targetRank dimensions.
+  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+                                       offsets.end());
+  return {subviewOp.getResult(), newOffsets};
+}
+
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
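To make the helper's effect concrete, here is a rough, hand-written sketch of the IR it would produce for the example in the comment above (the exact strided layout is whatever memref::SubViewOp::inferRankReducedResultType computes, and the memref.extract_strided_metadata op is only consumed when a trailing size is dynamic):

    // Collapse %src = memref<2x4x8x32xf32> down to rank 2: %i0 and %i1 are
    // folded into a rank-reducing subview, while %i2 and %i3 are handed back
    // to the caller as the remaining offsets.
    %sv = memref.subview %src[%i0, %i1, 0, 0] [1, 1, 8, 32] [1, 1, 1, 1]
        : memref<2x4x8x32xf32>
        to memref<8x32xf32, strided<[32, 1], offset: ?>>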
@@ -493,17 +550,18 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, readOp.getBase(),
+        getAsOpFoldResult(readOp.getIndices()), vecTy.getRank());
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
-    auto loadOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
-        /*packed=*/nullptr, transposeAttr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                          /*packed=*/nullptr, transposeAttr,
+                                          /*l1_hint=*/hint,
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(readOp, loadOp);
 
     return success();
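Putting the pieces together, a read such as the following (a hand-written sketch; the exact XeGPU assembly syntax and attributes may differ):

    %v = vector.transfer_read %src[%i0, %i1, %i2, %i3], %pad
        {in_bounds = [true, true]}
        : memref<2x4x8x32xf32>, vector<8x32xf32>

would now lower to a subview over the leading dimensions, a descriptor built on the collapsed source, and a load that takes the remaining offsets directly:

    %sv = memref.subview %src[%i0, %i1, 0, 0] [1, 1, 8, 32] [1, 1, 1, 1] ...
    %desc = xegpu.create_nd_tdesc %sv
        : memref<8x32xf32, strided<[32, 1], offset: ?>>
        -> !xegpu.tensor_desc<8x32xf32>
    %v = xegpu.load_nd %desc[%i2, %i3]
        : !xegpu.tensor_desc<8x32xf32> -> vector<8x32xf32>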
@@ -541,21 +599,23 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, writeOp.getBase(),
+        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
 
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 getAsOpFoldResult(writeOp.getIndices()),
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+                                            ndDesc, indices,
+                                            /*l1_hint=*/hint,
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -643,17 +703,21 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, loadOp.getBase(),
+        getAsOpFoldResult(loadOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
-        /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+    auto loadNdOp =
+        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                /*packed=*/nullptr, /*transpose=*/nullptr,
+                                /*l1_hint=*/hint,
+                                /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -675,19 +739,23 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, storeOp.getBase(),
+        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
 
-    auto storeNdOp = xegpu::StoreNdOp::create(
-        rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    auto storeNdOp =
+        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
+                                 /*l1_hint=*/hint,
+                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
 
     rewriter.replaceOp(storeOp, storeNdOp);
 
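The plain vector.load/vector.store paths follow the same recipe as the transfer ops. As a final hand-written sketch (again, exact XeGPU syntax may differ), a store like

    vector.store %vec, %src[%i0, %i1, %i2, %i3]
        : memref<2x4x8x32xf32>, vector<8x32xf32>

becomes a subview plus a descriptor-based store that consumes the trailing offsets:

    %sv = memref.subview %src[%i0, %i1, 0, 0] [1, 1, 8, 32] [1, 1, 1, 1] ...
    %desc = xegpu.create_nd_tdesc %sv
        : memref<8x32xf32, strided<[32, 1], offset: ?>>
        -> !xegpu.tensor_desc<8x32xf32>
    xegpu.store_nd %vec, %desc[%i2, %i3]
        : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32>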