@@ -292,19 +292,25 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
292292 // Get address space from tensor descriptor memory space.
293293 auto ptrTypeLLVM = LLVM::LLVMPointerType::get (
294294 ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
295- if (tdescTy.getRank () == 2 ) {
295+ // Compute element byte size.
296+ Value elemByteSize = arith::ConstantIntOp::create (
297+ rewriter, loc, rewriter.getI32Type (), elemBitSize / 8 );
298+ auto tileRank = tdescTy.getRank ();
299+ // Get tile width from the tensor descriptor type.
300+ auto tileW = tdescTy.getDimSize (tileRank - 1 );
301+ if (tileRank == 2 ) {
296302 // Convert base pointer (i64) to LLVM pointer type.
297303 Value basePtrLLVM =
298304 LLVM::IntToPtrOp::create (rewriter, loc, ptrTypeLLVM, basePtr);
299- // Compute element byte size and surface width in bytes.
305+ // Compute width in bytes.
300306 Value elemByteSize = arith::ConstantIntOp::create (
301307 rewriter, loc, rewriter.getI32Type (), elemBitSize / 8 );
302308 Value surfaceW =
303309 arith::MulIOp::create (rewriter, loc, baseShapeW, elemByteSize);
304310
305- // Get tile sizes and vblocks from the tensor descriptor type.
306- auto tileW = tdescTy.getDimSize (1 );
311+ // Get tile height from the tensor descriptor type.
307312 auto tileH = tdescTy.getDimSize (0 );
313+ // Get vblocks from the tensor descriptor type.
308314 int32_t vblocks = tdescTy.getArrayLength ();
309315 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
310316 Value src = adaptor.getValue ();
@@ -360,6 +366,65 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
360366 rewriter.replaceOp (op, resultFlatVec);
361367 }
362368 }
369+ } else {
370+ // Get address from base address and offsets.
371+ // Offset in number of elements, need to multiply by element byte size.
372+ // Compute linear offset.
373+ // linearOffset = offsetH * baseShapeW + offsetW
374+ Value offsetHInElems =
375+ rewriter.createOrFold <arith::MulIOp>(loc, offsetH, baseShapeW);
376+ Value linearOffset =
377+ rewriter.createOrFold <arith::AddIOp>(loc, offsetHInElems, offsetW);
378+ // Then compute byte offset by multiplying with element byte size.
379+ // byteOffset = linearOffset * elemByteSize
380+ Value byteOffset =
381+ rewriter.createOrFold <arith::MulIOp>(loc, linearOffset, elemByteSize);
382+ // Final address = basePtr + byteOffset
383+ Value finalAddrI64 = rewriter.createOrFold <arith::AddIOp>(
384+ loc, basePtr,
385+ getValueOrCreateCastToIndexLike (rewriter, loc, rewriter.getI64Type (),
386+ byteOffset));
387+ // Convert final address (i64) to LLVM pointer type.
388+ Value finalPtrLLVM =
389+ LLVM::IntToPtrOp::create (rewriter, loc, ptrTypeLLVM, finalAddrI64);
390+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
391+ Value src = adaptor.getValue ();
392+ // If store value is a scalar, get value from op instead of adaptor.
393+ // Adaptor might have optimized away single element vector
394+ if (src.getType ().isIntOrFloat ()) {
395+ src = op.getValue ();
396+ }
397+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType ());
398+ if (!srcVecTy)
399+ return rewriter.notifyMatchFailure (
400+ op, " Expected store value to be a vector type." );
401+ // Get flat vector type of integer type with matching element bit size.
402+ VectorType newSrcVecTy =
403+ encodeVectorTypeTo (srcVecTy, rewriter.getIntegerType (elemBitSize));
404+ if (srcVecTy != newSrcVecTy)
405+ src = vector::BitCastOp::create (rewriter, loc, newSrcVecTy, src);
406+ auto storeCacheControl =
407+ translateStoreXeGPUCacheHint (op.getL1Hint (), op.getL3Hint ());
408+ rewriter.replaceOpWithNewOp <xevm::BlockStoreOp>(
409+ op, finalPtrLLVM, src,
410+ xevm::StoreCacheControlAttr::get (ctxt, storeCacheControl));
411+ } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
412+ auto loadCacheControl =
413+ translateLoadXeGPUCacheHint (op.getL1Hint (), op.getL3Hint ());
414+ VectorType resTy = cast<VectorType>(op.getValue ().getType ());
415+ VectorType loadedTy =
416+ encodeVectorTypeTo (resTy, rewriter.getIntegerType (elemBitSize));
417+ Value load = xevm::BlockLoadOp::create (
418+ rewriter, loc, loadedTy, finalPtrLLVM,
419+ xevm::LoadCacheControlAttr::get (ctxt, loadCacheControl));
420+ if (loadedTy != resTy)
421+ load = vector::BitCastOp::create (rewriter, loc, resTy, load);
422+ rewriter.replaceOp (op, load);
423+ } else {
424+ return rewriter.notifyMatchFailure (
425+ op, " Unsupported operation: xegpu.prefetch_nd with tensor "
426+ " descriptor rank == 1" );
427+ }
363428 }
364429 return success ();
365430 }
0 commit comments