@@ -186,9 +186,6 @@ class CreateNdDescToXeVMPattern
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
-    if (rank != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
     // If source is a memref, we need to extract the aligned pointer as index.
@@ -199,8 +196,19 @@ class CreateNdDescToXeVMPattern
       }
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      // Cast index to i64.
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
     } else {
       baseAddr = adaptor.getSource();
+      if (baseAddr.getType() != i64Ty) {
+        // Pointer type may be i32. Cast to i64 if needed.
+        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+      }
+    }
+    // 1D tensor descriptor is just the base address.
+    if (rank == 1) {
+      rewriter.replaceOp(op, baseAddr);
+      return success();
     }
     // Utility for creating offset values from op fold result.
     auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -215,13 +223,6 @@ class CreateNdDescToXeVMPattern
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
-    if (sourceMemrefTy) {
-      // Cast index to i64.
-      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else if (baseAddr.getType() != i64Ty) {
-      // Pointer type may be i32. Cast to i64 if needed.
-      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    }
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
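
Note: with the rank-2 guard removed above, a rank-1 xegpu.create_nd_tdesc now
lowers to nothing but its base address. A minimal sketch of the resulting IR
for an assumed static-memref source (SSA names hypothetical, op syntax
abbreviated), not part of the patch:

  %tdesc = xegpu.create_nd_tdesc %src : memref<64xf32> -> !xegpu.tensor_desc<16xf32>
  // ...becomes, after this pattern (the op is replaced by the i64 address):
  %idx  = memref.extract_aligned_pointer_as_index %src : memref<64xf32> -> index
  %addr = arith.index_castui %idx : index to i64
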
@@ -257,108 +258,175 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                   ConversionPatternRewriter &rewriter) const override {
     auto mixedOffsets = op.getMixedOffsets();
     int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
     auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
-    if (tdescTy.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+    auto tileRank = tdescTy.getRank();
+    if (opOffsetsSize != tileRank)
+      return rewriter.notifyMatchFailure(
+          op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
     if (elemBitSize % 8 != 0)
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
 
-    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
-    Value basePtr = vector::ExtractOp::create(
-        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
-    Value baseShapeW = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
-    Value baseShapeH = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets are provided by the op.
-    // convert them to i32.
-    Value offsetW =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetW);
-    Value offsetH =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    // Convert base pointer (i64) to LLVM pointer type.
-    Value basePtrLLVM =
-        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
-    // Compute element byte size and surface width in bytes.
-    Value elemByteSize = arith::ConstantIntOp::create(
-        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
-    Value surfaceW =
-        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
-    // Get tile sizes and vblocks from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(1);
-    auto tileH = tdescTy.getDimSize(0);
-    int32_t vblocks = tdescTy.getArrayLength();
-    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      Value src = adaptor.getValue();
-      // If store value is a scalar, get value from op instead of adaptor.
-      // Adaptor might have optimized away single element vector
-      if (src.getType().isIntOrFloat()) {
-        src = op.getValue();
-      }
-      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
-      if (!srcVecTy)
-        return rewriter.notifyMatchFailure(
-            op, "Expected store value to be a vector type.");
-      // Get flat vector type of integer type with matching element bit size.
-      VectorType newSrcVecTy =
-          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
-      if (srcVecTy != newSrcVecTy)
-        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, src,
-          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
-      rewriter.eraseOp(op);
-    } else {
-      auto loadCacheControl =
-          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
-        xevm::BlockPrefetch2dOp::create(
+    if (tileRank == 2) {
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+      Value payLoadAsI64 =
+          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+      Value basePtr =
+          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+                                    static_cast<int>(NdTdescOffset::BasePtr));
+      Value baseShapeW = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+      Value baseShapeH = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+      // Offsets are provided by the op.
+      // Convert them to i32.
+      Value offsetW =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetW);
+      Value offsetH =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetH);
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value basePtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // Compute width in bytes.
+      Value surfaceW =
+          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+      // Get tile width from the tensor descriptor type.
+      auto tileW = tdescTy.getDimSize(tileRank - 1);
+      // Get tile height from the tensor descriptor type.
+      auto tileH = tdescTy.getDimSize(0);
+      // Get vblocks from the tensor descriptor type.
+      int32_t vblocks = tdescTy.getArrayLength();
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector.
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        xevm::BlockStore2dOp::create(
             rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-            offsetH, elemBitSize, tileW, tileH, vblocks,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+            offsetH, elemBitSize, tileW, tileH, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
         rewriter.eraseOp(op);
       } else {
-        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
-        const bool vnni = op.getPacked().value_or(false);
-        auto transposeValue = op.getTranspose();
-        bool transpose =
-            transposeValue.has_value() && transposeValue.value()[0] == 1;
-        VectorType loadedTy = encodeVectorTypeTo(
-            dstVecTy, vnni ? rewriter.getI32Type()
-                           : rewriter.getIntegerType(elemBitSize));
-
-        Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            transpose, vnni,
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+          xevm::BlockPrefetch2dOp::create(
+              rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
+              offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          rewriter.eraseOp(op);
+        } else {
+          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+          const bool vnni = op.getPacked().value_or(false);
+          auto transposeValue = op.getTranspose();
+          bool transpose =
+              transposeValue.has_value() && transposeValue.value()[0] == 1;
+          VectorType loadedTy = encodeVectorTypeTo(
+              dstVecTy, vnni ? rewriter.getI32Type()
+                             : rewriter.getIntegerType(elemBitSize));
+
+          Value resultFlatVec = xevm::BlockLoad2dOp::create(
+              rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+              surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              transpose, vnni,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          resultFlatVec = vector::BitCastOp::create(
+              rewriter, loc,
+              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+              resultFlatVec);
+          rewriter.replaceOp(op, resultFlatVec);
+        }
+      }
+    } else {
+      // 1D tensor descriptor.
+      // `tdesc` represents the base address as i64.
+      // Offset is in number of elements; multiply by element byte size.
+      // Compute byte offset:
+      //   byteOffset = offset * elementByteSize
+      Value offset =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                               rewriter.getI64Type(), offset);
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
+      Value byteOffset =
+          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
+      // Final address = basePtr + byteOffset.
+      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+          loc, tdesc,
+          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+                                          byteOffset));
+      // Convert final address (i64) to LLVM pointer type.
+      Value finalPtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If store value is a scalar, get value from op instead of adaptor.
+        // Adaptor might have optimized away single element vector.
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get flat vector type of integer type with matching element bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+            op, finalPtrLLVM, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        VectorType resTy = cast<VectorType>(op.getValue().getType());
+        VectorType loadedTy =
+            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+        Value load = xevm::BlockLoadOp::create(
+            rewriter, loc, loadedTy, finalPtrLLVM,
             xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
-        resultFlatVec = vector::BitCastOp::create(
-            rewriter, loc,
-            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
-            resultFlatVec);
-        rewriter.replaceOp(op, resultFlatVec);
+        if (loadedTy != resTy)
+          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+        rewriter.replaceOp(op, load);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+                "descriptor rank == 1");
+      }
     }
   }
   return success();
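
Note: in the new 1-D branch, the converted descriptor is already a raw i64
address, so the element offset is scaled to bytes and folded into the pointer
before the block access. A worked sketch for an assumed f32 load at element
offset 10 (SSA names hypothetical; the trailing xevm op shown schematically):

  %off      = arith.constant 10 : i64
  %elem_sz  = arith.constant 4 : i64            // f32: 32 bits / 8
  %byte_off = arith.muli %off, %elem_sz : i64   // 10 * 4 = 40 bytes
  %addr     = arith.addi %tdesc, %byte_off : i64
  %ptr      = llvm.inttoptr %addr : i64 to !llvm.ptr<1>
  // ...followed by the xevm block load (or store) through %ptr.
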
@@ -929,7 +997,10 @@ struct ConvertXeGPUToXeVMPass
       return VectorType::get(sum, elemType);
     });
     typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+      // Scattered descriptors are not supported in XeVM lowering.
       if (type.isScattered())
+        return {};
+      if (type.getRank() == 1)
         return IntegerType::get(&getContext(), 64);
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
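
Note: restating the conversion above, the descriptor type mapping becomes
(sketch; the scattered case yields no converted type, so such ops fail to
legalize here):

  rank-1 !xegpu.tensor_desc<...>  -> i64            // base address only
  rank-2 !xegpu.tensor_desc<...>  -> vector<8xi32>  // full block payload
  scattered tensor_desc           -> (none)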