@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
5050
5151// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
5252enum class NdTdescOffset : uint32_t {
53- BasePtr = 0 , // Base pointer (i64)
54- BaseShapeW = 2 , // Base shape width (i32)
55- BaseShapeH = 3 , // Base shape height (i32)
56- TensorOffsetW = 4 , // Tensor offset W (i32)
57- TensorOffsetH = 5 // Tensor offset H (i32)
53+ BasePtr = 0 , // Base pointer (i64)
54+ BaseShapeW = 2 , // Base shape width (i32)
55+ BaseShapeH = 3 , // Base shape height (i32)
56+ BasePitch = 4 , // Base pitch (i32)
5857};
5958
6059static int32_t getNumericXeVMAddrSpace (xegpu::MemorySpace xeGpuMemspace) {
@@ -179,11 +178,10 @@ class CreateNdDescToXeVMPattern
179178 Value baseAddr;
180179 Value baseShapeW;
181180 Value baseShapeH;
182- Value offsetW;
183- Value offsetH;
184181
185182 // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
186183 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes ();
184+ SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides ();
187185 // Descriptor shape is expected to be 2D.
188186 int64_t rank = mixedSizes.size ();
189187 auto sourceTy = source.getType ();
@@ -216,12 +214,11 @@ class CreateNdDescToXeVMPattern
216214 val = getValueOrCreateCastToIndexLike (rewriter, loc, payloadElemTy, val);
217215 return val;
218216 };
219- // Offsets are not supported (0 is used).
220- offsetW = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
221- offsetH = arith::ConstantIntOp::create (rewriter, loc, payloadElemTy, 0 );
222217 // Get shape values from op fold results.
223218 baseShapeW = createOffset (mixedSizes, 1 );
224219 baseShapeH = createOffset (mixedSizes, 0 );
220+ // Get pitch value from op fold results.
221+ Value basePitch = createOffset (mixedStrides, 0 );
225222 // Populate payload.
226223 Value payLoadAsI64 =
227224 vector::BitCastOp::create (rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +232,9 @@ class CreateNdDescToXeVMPattern
235232 payload =
236233 vector::InsertOp::create (rewriter, loc, baseShapeH, payload,
237234 static_cast <int >(NdTdescOffset::BaseShapeH));
238- payload = vector::InsertOp::create (
239- rewriter, loc, offsetW, payload,
240- static_cast <int >(NdTdescOffset::TensorOffsetW));
241- payload = vector::InsertOp::create (
242- rewriter, loc, offsetH, payload,
243- static_cast <int >(NdTdescOffset::TensorOffsetH));
235+ payload =
236+ vector::InsertOp::create (rewriter, loc, basePitch, payload,
237+ static_cast <int >(NdTdescOffset::BasePitch));
244238 rewriter.replaceOp (op, payload);
245239 return success ();
246240 }
@@ -289,6 +283,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
289283 rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeW));
290284 Value baseShapeH = vector::ExtractOp::create (
291285 rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeH));
286+ Value basePitch = vector::ExtractOp::create (
287+ rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BasePitch));
292288 // Offsets are provided by the op.
293289 // convert them to i32.
294290 Value offsetW =
@@ -303,8 +299,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
303299 Value basePtrLLVM =
304300 LLVM::IntToPtrOp::create (rewriter, loc, ptrTypeLLVM, basePtr);
305301 // Compute width in bytes.
306- Value surfaceW =
302+ Value baseWidthByte =
307303 arith::MulIOp::create (rewriter, loc, baseShapeW, elemByteSize);
304+ // Compute pitch in bytes.
305+ Value basePitchByte =
306+ arith::MulIOp::create (rewriter, loc, basePitch, elemByteSize);
308307
309308 // Get tile width from the tensor descriptor type.
310309 auto tileW = tdescTy.getDimSize (tileRank - 1 );
@@ -331,18 +330,18 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
331330 auto storeCacheControl =
332331 translateStoreXeGPUCacheHint (op.getL1Hint (), op.getL3Hint ());
333332 xevm::BlockStore2dOp::create (
334- rewriter, loc, basePtrLLVM, surfaceW , baseShapeH, surfaceW, offsetW ,
335- offsetH, elemBitSize, tileW, tileH, src,
333+ rewriter, loc, basePtrLLVM, baseWidthByte , baseShapeH,
334+ basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
336335 xevm::StoreCacheControlAttr::get (ctxt, storeCacheControl));
337336 rewriter.eraseOp (op);
338337 } else {
339338 auto loadCacheControl =
340339 translateLoadXeGPUCacheHint (op.getL1Hint (), op.getL3Hint ());
341340 if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
342341 xevm::BlockPrefetch2dOp::create (
343- rewriter, loc, basePtrLLVM, surfaceW , baseShapeH, surfaceW ,
344- offsetW, offsetH, elemBitSize, tileW, tileH, vblocks ,
345- xevm::LoadCacheControlAttr::get (ctxt, loadCacheControl));
342+ rewriter, loc, basePtrLLVM, baseWidthByte , baseShapeH,
343+ basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
344+ vblocks, xevm::LoadCacheControlAttr::get (ctxt, loadCacheControl));
346345 rewriter.eraseOp (op);
347346 } else {
348347 VectorType dstVecTy = cast<VectorType>(op.getValue ().getType ());
@@ -355,9 +354,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
355354 : rewriter.getIntegerType (elemBitSize));
356355
357356 Value resultFlatVec = xevm::BlockLoad2dOp::create (
358- rewriter, loc, loadedTy, basePtrLLVM, surfaceW , baseShapeH,
359- surfaceW , offsetW, offsetH, elemBitSize, tileW, tileH, vblocks ,
360- transpose, vnni,
357+ rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte , baseShapeH,
358+ basePitchByte , offsetW, offsetH, elemBitSize, tileW, tileH,
359+ vblocks, transpose, vnni,
361360 xevm::LoadCacheControlAttr::get (ctxt, loadCacheControl));
362361 resultFlatVec = vector::BitCastOp::create (
363362 rewriter, loc,
0 commit comments