@@ -186,9 +186,6 @@ class CreateNdDescToXeVMPattern
186186 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes ();
187187 // Descriptor shape is expected to be 2D.
188188 int64_t rank = mixedSizes.size ();
189- if (rank != 2 )
190- return rewriter.notifyMatchFailure (op, " Expected 2D shape." );
191-
192189 auto sourceTy = source.getType ();
193190 auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
194191 // If source is a memref, we need to extract the aligned pointer as index.
@@ -199,8 +196,19 @@ class CreateNdDescToXeVMPattern
199196 }
200197 baseAddr =
201198 memref::ExtractAlignedPointerAsIndexOp::create (rewriter, loc, source);
199+ // Cast index to i64.
200+ baseAddr = arith::IndexCastUIOp::create (rewriter, loc, i64Ty, baseAddr);
202201 } else {
203202 baseAddr = adaptor.getSource ();
203+ if (baseAddr.getType () != i64Ty) {
204+ // Pointer type may be i32. Cast to i64 if needed.
205+ baseAddr = arith::ExtUIOp::create (rewriter, loc, i64Ty, baseAddr);
206+ }
207+ }
208+ // 1D tensor descriptor is just the base address.
209+ if (rank == 1 ) {
210+ rewriter.replaceOp (op, baseAddr);
211+ return success ();
204212 }
205213 // Utility for creating offset values from op fold result.
206214 auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -215,13 +223,6 @@ class CreateNdDescToXeVMPattern
215223 // Get shape values from op fold results.
216224 baseShapeW = createOffset (mixedSizes, 1 );
217225 baseShapeH = createOffset (mixedSizes, 0 );
218- if (sourceMemrefTy) {
219- // Cast index to i64.
220- baseAddr = arith::IndexCastUIOp::create (rewriter, loc, i64Ty, baseAddr);
221- } else if (baseAddr.getType () != i64Ty) {
222- // Pointer type may be i32. Cast to i64 if needed.
223- baseAddr = arith::ExtUIOp::create (rewriter, loc, i64Ty, baseAddr);
224- }
225226 // Populate payload.
226227 Value payLoadAsI64 =
227228 vector::BitCastOp::create (rewriter, loc, payloadI64Ty, payload);
@@ -257,57 +258,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
257258 ConversionPatternRewriter &rewriter) const override {
258259 auto mixedOffsets = op.getMixedOffsets ();
259260 int64_t opOffsetsSize = mixedOffsets.size ();
260- if (opOffsetsSize != 2 )
261- return rewriter.notifyMatchFailure (op, " Expected 2D offsets." );
262261 auto loc = op.getLoc ();
263262 auto ctxt = rewriter.getContext ();
264263
265264 auto tdesc = adaptor.getTensorDesc ();
266265 auto tdescTy = op.getTensorDescType ();
266+ auto tileRank = tdescTy.getRank ();
267+ if (opOffsetsSize != tileRank)
268+ return rewriter.notifyMatchFailure (
269+ op, " Expected offset rank to match descriptor rank." );
267270 auto elemType = tdescTy.getElementType ();
268271 auto elemBitSize = elemType.getIntOrFloatBitWidth ();
269272 if (elemBitSize % 8 != 0 )
270273 return rewriter.notifyMatchFailure (
271274 op, " Expected element type bit width to be multiple of 8." );
272275
273- VectorType payloadI64Ty = VectorType::get (4 , rewriter.getI64Type ());
274- Value payLoadAsI64 =
275- vector::BitCastOp::create (rewriter, loc, payloadI64Ty, tdesc);
276- Value basePtr = vector::ExtractOp::create (
277- rewriter, loc, payLoadAsI64, static_cast <int >(NdTdescOffset::BasePtr));
278- Value baseShapeW = vector::ExtractOp::create (
279- rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeW));
280- Value baseShapeH = vector::ExtractOp::create (
281- rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeH));
282- // Offsets are provided by the op.
283- // convert them to i32.
284- Value offsetW =
285- getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[1 ]);
286- offsetW = getValueOrCreateCastToIndexLike (rewriter, loc,
287- rewriter.getI32Type (), offsetW);
288- Value offsetH =
289- getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[0 ]);
290- offsetH = getValueOrCreateCastToIndexLike (rewriter, loc,
291- rewriter.getI32Type (), offsetH);
292276 // Get address space from tensor descriptor memory space.
293277 auto ptrTypeLLVM = LLVM::LLVMPointerType::get (
294278 ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
295- // Compute element byte size.
296- Value elemByteSize = arith::ConstantIntOp::create (
297- rewriter, loc, rewriter.getI32Type (), elemBitSize / 8 );
298- auto tileRank = tdescTy.getRank ();
299- // Get tile width from the tensor descriptor type.
300- auto tileW = tdescTy.getDimSize (tileRank - 1 );
301279 if (tileRank == 2 ) {
280+ // Compute element byte size.
281+ Value elemByteSize = arith::ConstantIntOp::create (
282+ rewriter, loc, rewriter.getI32Type (), elemBitSize / 8 );
283+ VectorType payloadI64Ty = VectorType::get (4 , rewriter.getI64Type ());
284+ Value payLoadAsI64 =
285+ vector::BitCastOp::create (rewriter, loc, payloadI64Ty, tdesc);
286+ Value basePtr =
287+ vector::ExtractOp::create (rewriter, loc, payLoadAsI64,
288+ static_cast <int >(NdTdescOffset::BasePtr));
289+ Value baseShapeW = vector::ExtractOp::create (
290+ rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeW));
291+ Value baseShapeH = vector::ExtractOp::create (
292+ rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeH));
293+ // Offsets are provided by the op.
294+ // convert them to i32.
295+ Value offsetW =
296+ getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[1 ]);
297+ offsetW = getValueOrCreateCastToIndexLike (rewriter, loc,
298+ rewriter.getI32Type (), offsetW);
299+ Value offsetH =
300+ getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[0 ]);
301+ offsetH = getValueOrCreateCastToIndexLike (rewriter, loc,
302+ rewriter.getI32Type (), offsetH);
302303 // Convert base pointer (i64) to LLVM pointer type.
303304 Value basePtrLLVM =
304305 LLVM::IntToPtrOp::create (rewriter, loc, ptrTypeLLVM, basePtr);
305306 // Compute width in bytes.
306- Value elemByteSize = arith::ConstantIntOp::create (
307- rewriter, loc, rewriter.getI32Type (), elemBitSize / 8 );
308307 Value surfaceW =
309308 arith::MulIOp::create (rewriter, loc, baseShapeW, elemByteSize);
310309
310+ // Get tile width from the tensor descriptor type.
311+ auto tileW = tdescTy.getDimSize (tileRank - 1 );
311312 // Get tile height from the tensor descriptor type.
312313 auto tileH = tdescTy.getDimSize (0 );
313314 // Get vblocks from the tensor descriptor type.
@@ -367,21 +368,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
367368 }
368369 }
369370 } else {
370- // Get address from base address and offsets.
371+ // 1D tensor descriptor.
372+ // `tdesc` represents base address as i64
371373 // Offset in number of elements, need to multiply by element byte size.
372- // Compute linear offset.
373- // linearOffset = offsetH * baseShapeW + offsetW
374- Value offsetHInElems =
375- rewriter.createOrFold <arith::MulIOp>(loc, offsetH, baseShapeW);
376- Value linearOffset =
377- rewriter.createOrFold <arith::AddIOp>(loc, offsetHInElems, offsetW);
378- // Then compute byte offset by multiplying with element byte size.
379- // byteOffset = linearOffset * elemByteSize
374+ // Compute byte offset.
375+ // byteOffset = offset * elementByteSize
376+ Value offset =
377+ getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[0 ]);
378+ offset = getValueOrCreateCastToIndexLike (rewriter, loc,
379+ rewriter.getI64Type (), offset);
380+ // Compute element byte size.
381+ Value elemByteSize = arith::ConstantIntOp::create (
382+ rewriter, loc, rewriter.getI64Type (), elemBitSize / 8 );
380383 Value byteOffset =
381- rewriter.createOrFold <arith::MulIOp>(loc, linearOffset , elemByteSize);
384+ rewriter.createOrFold <arith::MulIOp>(loc, offset , elemByteSize);
382385 // Final address = basePtr + byteOffset
383386 Value finalAddrI64 = rewriter.createOrFold <arith::AddIOp>(
384- loc, basePtr ,
387+ loc, tdesc ,
385388 getValueOrCreateCastToIndexLike (rewriter, loc, rewriter.getI64Type (),
386389 byteOffset));
387390 // Convert base pointer (i64) to LLVM pointer type.
@@ -992,7 +995,10 @@ struct ConvertXeGPUToXeVMPass
992995 return VectorType::get (sum, elemType);
993996 });
994997 typeConverter.addConversion ([&](xegpu::TensorDescType type) -> Type {
998+ // Scattered descriptors are not supported in XeVM lowering.
995999 if (type.isScattered ())
1000+ return {};
1001+ if (type.getRank () == 1 )
9961002 return IntegerType::get (&getContext (), 64 );
9971003 auto i32Type = IntegerType::get (&getContext (), 32 );
9981004 return VectorType::get (8 , i32Type);
0 commit comments