@@ -180,26 +180,31 @@ static void adjustStridesForPermutation(AffineMap permMap,
   strides = applyPermutation(strides, perms64);
 }
 
-// Computes memory strides for vector transfer operations, handling both
-// static and dynamic memrefs while applying permutation transformations
-// for XeGPU lowering.
-static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
-                                         PatternRewriter &rewriter) {
+// Computes memory strides and a memref offset for vector transfer operations,
+// handling both static and dynamic memrefs while applying permutation
+// transformations for XeGPU lowering.
+static std::pair<SmallVector<Value>, Value>
+computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
   AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
+  Value offsetVal = nullptr;
   if (memrefType.hasStaticShape()) {
     int64_t offset;
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
-      return {};
+      return {{}, offsetVal};
     // Wrap static strides as MLIR values
     for (int64_t s : intStrides)
       strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
-  } else {
+    if (!ShapedType::isDynamic(offset))
+      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
+  }
+
+  if (strides.empty() || !offsetVal) {
     // For dynamic shape memref, use memref.extract_strided_metadata to get
     // stride values
     unsigned rank = memrefType.getRank();
@@ -220,11 +225,16 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 
     auto meta = memref::ExtractStridedMetadataOp::create(
         rewriter, loc, resultTypes, baseMemref);
-    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (strides.empty())
+      strides.append(meta.getStrides().begin(), meta.getStrides().end());
+
+    if (!offsetVal)
+      offsetVal = meta.getOffset();
   }
   // Adjust strides according to the permutation map (e.g., for transpose)
   adjustStridesForPermutation(permMap, strides);
-  return strides;
+  return {strides, offsetVal};
 }
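Note the two paths are no longer mutually exclusive: a memref with a static shape
but a dynamic offset takes the static branch for the strides, then falls through to
memref.extract_strided_metadata just for the offset. For a 2-D source, that op
produces the base buffer, offset, sizes, and strides in one shot; an illustrative
sketch (SSA names invented, not from this patch):

    %bb, %off, %szs:2, %strs:2 = memref.extract_strided_metadata %src
        : memref<?x?xf32> -> memref<f32>, index, index, index, index, index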
 
 // This function computes the vectors of localOffsets for scattered loads/stores.
@@ -254,10 +264,10 @@ static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 // %23 = arith.add %20, %21
 // %local_offsets = arith.add %22, %23
 // %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
-// %offsets = orig_offset + local_offsets
+// %offsets = memref_offset + orig_offset + local_offsets
 static Value computeOffsets(VectorTransferOpInterface xferOp,
-                            PatternRewriter &rewriter,
-                            ArrayRef<Value> strides) {
+                            PatternRewriter &rewriter, ArrayRef<Value> strides,
+                            Value baseOffset) {
   Location loc = xferOp.getLoc();
   VectorType vectorType = xferOp.getVectorType();
   SmallVector<Value> indices(xferOp.getIndices().begin(),
@@ -315,51 +325,30 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
         arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
 
   // Compute base offset from transfer read indices
-  Value baseOffset = nullptr;
-  if (!indices.empty()) {
-    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
-    for (size_t i = 0; i < indices.size(); ++i) {
-      Value strideVal = strides[i];
-      Value offsetContrib =
-          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
-      baseOffset =
-          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
-    }
-    // Broadcast base offset to match vector shape
-    Value bcastBase = vector::BroadcastOp::create(
-        rewriter, loc, fullIndexVectorType, baseOffset);
-    localOffsets =
-        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+  for (size_t i = 0; i < indices.size(); ++i) {
+    Value strideVal = strides[i];
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
   }
+  // Broadcast base offset to match vector shape
+  Value bcastBase = vector::BroadcastOp::create(
+      rewriter, loc, fullIndexVectorType, baseOffset);
+  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
   return localOffsets;
 }
 
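Since the caller now seeds baseOffset with the memref offset from computeMemrefMeta,
the old nullptr guard and zero constant are gone and the accumulation loop runs
unconditionally. For a 2-D transfer with indices (%i, %j) and strides (%s0, %s1),
the emitted index math amounts to roughly the following (a sketch with invented
names; vector<8xindex> stands in for fullIndexVectorType):

    %t0   = arith.muli %i, %s0 : index
    %b0   = arith.addi %memref_off, %t0 : index
    %t1   = arith.muli %j, %s1 : index
    %base = arith.addi %b0, %t1 : index
    %bc   = vector.broadcast %base : index to vector<8xindex>
    %offs = arith.addi %bc, %local_offsets : vector<8xindex>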
-// Collapse memref shape to 1D
-static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
-                                PatternRewriter &rewriter) {
+// Convert memref to i64 base pointer
+static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+                              PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
-
-  Value baseMemref = xferOp.getBase();
-  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
-  Type elementType = memrefType.getElementType();
-
-  // Compute the total number of elements in the memref
-  MemRefType flatMemrefType;
-  if (memrefType.hasStaticShape()) {
-    auto totalElements = memrefType.getNumElements();
-    flatMemrefType = MemRefType::get({totalElements}, elementType);
-  } else {
-    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
-  }
-
-  SmallVector<ReassociationIndices> reassociation;
-  ReassociationIndices allDims =
-      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
-  reassociation.push_back(allDims);
-
-  auto collapseOp = memref::CollapseShapeOp::create(
-      rewriter, loc, flatMemrefType, baseMemref, reassociation);
-  return collapseOp;
+  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
+                      rewriter, loc, xferOp.getBase())
+                      .getResult();
+  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
+                                    indexPtr)
+      .getResult();
 }
 
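Replacing memref.collapse_shape with a raw pointer likely sidesteps collapse_shape's
layout restrictions (it needs a collapsible, i.e. contiguous, source), which a
scattered access doesn't care about: it only needs a base address plus the explicit
per-lane offsets computed above. The helper's output is simply (sketch, names
assumed):

    %p_idx = memref.extract_aligned_pointer_as_index %src : memref<?x?xf32> -> index
    %p_i64 = arith.index_cast %p_idx : index to i64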
 static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
@@ -372,13 +361,14 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(readOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(readOp, rewriter);
-  if (strides.empty())
+  auto meta = computeMemrefMeta(readOp, rewriter);
+  if (meta.first.empty())
     return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(readOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(readOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(readOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
@@ -405,11 +395,14 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
 
-  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+  auto meta = computeMemrefMeta(writeOp, rewriter);
+  if (meta.first.empty())
+    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");
 
-  Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+  Value localOffsets =
+      computeOffsets(writeOp, rewriter, meta.first, meta.second);
 
-  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);
 
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
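Both lowerings now follow the same recipe: computeMemrefMeta supplies the strides
plus the memref offset, computeOffsets folds them with the transfer indices into
the per-lane offset vector, and memrefToIndexPtr supplies the i64 base address;
the write path also gains the empty-strides bail-out the read path already had.
The unchanged all-true mask for, say, an 8-lane transfer is just (sketch):

    %mask = vector.constant_mask [8] : vector<8xi1>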