@@ -97,57 +97,23 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
-                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
-                   Operation::operand_range offsets) {
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+                                                Location loc,
+                                                xegpu::TensorDescType descType,
+                                                TypedValue<MemRefType> src) {
   MemRefType srcTy = src.getType();
   auto [strides, offset] = srcTy.getStridesAndOffset();
 
   xegpu::CreateNdDescOp ndDesc;
   if (srcTy.hasStaticShape()) {
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           getAsOpFoldResult(offsets));
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   } else {
     // In case of any dynamic shapes, source's shape and strides have to be
     // explicitly provided.
-    SmallVector<Value> sourceDims;
-    unsigned srcRank = srcTy.getRank();
-    for (unsigned i = 0; i < srcRank; ++i)
-      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
-    SmallVector<int64_t> constOffsets;
-    SmallVector<Value> dynOffsets;
-    for (Value offset : offsets) {
-      std::optional<int64_t> staticVal = getConstantIntValue(offset);
-      if (!staticVal)
-        dynOffsets.push_back(offset);
-      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
-    }
-
-    SmallVector<Value> dynShapes;
-    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
-      if (shape == ShapedType::kDynamic)
-        dynShapes.push_back(sourceDims[idx]);
-    }
-
-    // Compute strides in reverse order.
-    SmallVector<Value> dynStrides;
-    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
-    // Last stride is guaranteed to be static and unit.
-    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
-      accStride =
-          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
-      if (strides[i] == ShapedType::kDynamic)
-        dynStrides.push_back(accStride);
-    }
-    std::reverse(dynStrides.begin(), dynStrides.end());
-
-    ndDesc = xegpu::CreateNdDescOp::create(
-        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
-        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
-        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
-        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+    auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+                                           meta.getConstifiedMixedSizes(),
+                                           meta.getConstifiedMixedStrides());
   }
 
   return ndDesc;
@@ -392,6 +358,62 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
       .getResult();
 }
 
+// Collapses the shape of an nD memref to the target rank while applying
+// offsets for the collapsed dimensions. Returns the new memref value and the
+// remaining offsets for the last targetRank dimensions. For example:
+//   input:  %memref = memref<2x4x8x32xf32>, offsets = [%i0, %i1, %i2, %i3]
+//   output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+                                    Value memref,
+                                    SmallVector<OpFoldResult> offsets,
+                                    int64_t targetRank) {
+  auto memrefType = cast<MemRefType>(memref.getType());
+  unsigned rank = memrefType.getRank();
+
+  if (rank <= targetRank)
+    return {memref, offsets};
+
+  int64_t numCombinedDims = rank - targetRank;
+  SmallVector<OpFoldResult> subviewOffsets;
+  SmallVector<OpFoldResult> subviewSizes;
+  SmallVector<OpFoldResult> subviewStrides;
+
+  // For the combined dimensions: use the provided offsets, size=1, stride=1.
+  for (unsigned i = 0; i < numCombinedDims; ++i) {
+    subviewOffsets.push_back(offsets[i]);
+    subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  // For the last targetRank dimensions: offset=0, use the full size, stride=1.
+  SmallVector<int64_t> resultShape;
+  auto originalShape = memrefType.getShape();
+  auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+  for (unsigned i = numCombinedDims; i < rank; ++i) {
+    subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+    if (ShapedType::isDynamic(originalShape[i])) {
+      subviewSizes.push_back(meta.getSizes()[i]);
+      resultShape.push_back(ShapedType::kDynamic);
+    } else {
+      subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+      resultShape.push_back(originalShape[i]);
+    }
+    subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+  }
+
+  auto resultType = memref::SubViewOp::inferRankReducedResultType(
+      resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+  auto subviewOp =
+      memref::SubViewOp::create(rewriter, loc, resultType, memref,
+                                subviewOffsets, subviewSizes, subviewStrides);
+
+  // Return the remaining offsets for the last targetRank dimensions.
+  SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+                                       offsets.end());
+  return {subviewOp.getResult(), newOffsets};
+}
+
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
@@ -523,18 +545,19 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
         descShape, elementType, /*array_length=*/1,
         /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
 
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
-                           readOp.getIndices());
-
     DenseI64ArrayAttr transposeAttr =
         !isTransposeLoad ? nullptr
                          : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                   ArrayRef<int64_t>{1, 0});
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+        vecTy.getRank());
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
                                           /*packed=*/nullptr, transposeAttr,
                                           /*l1_hint=*/hint,
                                           /*l2_hint=*/hint, /*l3_hint=*/hint);
@@ -575,21 +598,23 @@ struct TransferWriteLowering
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, writeOp.getBase(),
+        getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc =
-        createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
-                           writeOp.getIndices());
-
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
-    auto storeOp =
-        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
-                                 /*l1_hint=*/hint,
-                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+    auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+                                            ndDesc, indices,
+                                            /*l1_hint=*/hint,
+                                            /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(writeOp, storeOp);
 
     return success();
@@ -674,19 +699,24 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
 
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+        vecTy.getRank());
 
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
         boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
 
-    // By default, no specific caching policy is assigned.
-    xegpu::CachePolicyAttr hint = nullptr;
-    auto loadNdOp = xegpu::LoadNdOp::create(
-        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
-        /*l1_hint=*/hint,
-        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+    auto loadNdOp =
+        xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+                                /*packed=*/nullptr, /*transpose=*/nullptr,
+                                /*l1_hint=*/hint,
+                                /*l2_hint=*/hint, /*l3_hint=*/hint);
     rewriter.replaceOp(loadOp, loadNdOp);
 
     return success();
@@ -708,18 +738,24 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
     // Boundary check is available only for block instructions.
     bool boundaryCheck = vecTy.getRank() > 1;
 
+    auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+        rewriter, loc, storeOp.getBase(),
+        getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
     xegpu::CachePolicyAttr hint = nullptr;
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
     auto storeNdOp =
-        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
                                  /*l1_hint=*/hint,
                                  /*l2_hint=*/hint, /*l3_hint=*/hint);
+
     rewriter.replaceOp(storeOp, storeNdOp);
 
     return success();
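
For illustration only: the IR below is a hand-written sketch, not taken from this patch or its tests, and the exact xegpu assembly syntax may differ. It shows what the new convertMemrefAndOffsetsToTargetRank helper is expected to do for the example in its doc comment, collapsing the two leading dimensions of a 4-D memref into a rank-reducing subview before the 2-D descriptor is created:

    // %src : memref<2x4x8x32xf32>, offsets [%i0, %i1, %i2, %i3], targetRank = 2
    %sub = memref.subview %src[%i0, %i1, 0, 0] [1, 1, 8, 32] [1, 1, 1, 1]
        : memref<2x4x8x32xf32> to memref<8x32xf32, strided<[32, 1], offset: ?>>
    // The nD descriptor is then built from %sub (for dynamic shapes the sizes
    // and strides come from memref.extract_strided_metadata), and the remaining
    // offsets [%i2, %i3] are passed directly to xegpu.load_nd / xegpu.store_nd.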