Skip to content

Commit fd3186e

Browse files
committed
Add handler for 1D block load_nd/store_nd and test case.
1 parent 319205c commit fd3186e

File tree

1 file changed

+69
-4
lines changed

1 file changed

+69
-4
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,19 +292,25 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
292292
// Get address space from tensor descriptor memory space.
293293
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
294294
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
295-
if (tdescTy.getRank() == 2) {
295+
// Compute element byte size.
296+
Value elemByteSize = arith::ConstantIntOp::create(
297+
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
298+
auto tileRank = tdescTy.getRank();
299+
// Get tile width from the tensor descriptor type.
300+
auto tileW = tdescTy.getDimSize(tileRank - 1);
301+
if (tileRank == 2) {
296302
// Convert base pointer (i64) to LLVM pointer type.
297303
Value basePtrLLVM =
298304
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
299-
// Compute element byte size and surface width in bytes.
305+
// Compute width in bytes.
300306
Value elemByteSize = arith::ConstantIntOp::create(
301307
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
302308
Value surfaceW =
303309
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
304310

305-
// Get tile sizes and vblocks from the tensor descriptor type.
306-
auto tileW = tdescTy.getDimSize(1);
311+
// Get tile height from the tensor descriptor type.
307312
auto tileH = tdescTy.getDimSize(0);
313+
// Get vblocks from the tensor descriptor type.
308314
int32_t vblocks = tdescTy.getArrayLength();
309315
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
310316
Value src = adaptor.getValue();
@@ -360,6 +366,65 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
360366
rewriter.replaceOp(op, resultFlatVec);
361367
}
362368
}
369+
} else {
370+
// Get address from base address and offsets.
371+
// Offset in number of elements, need to multiply by element byte size.
372+
// Compute linear offset.
373+
// linearOffset = offsetH * baseShapeW + offsetW
374+
Value offsetHInElems =
375+
rewriter.createOrFold<arith::MulIOp>(loc, offsetH, baseShapeW);
376+
Value linearOffset =
377+
rewriter.createOrFold<arith::AddIOp>(loc, offsetHInElems, offsetW);
378+
// Then compute byte offset by multiplying with element byte size.
379+
// byteOffset = linearOffset * elemByteSize
380+
Value byteOffset =
381+
rewriter.createOrFold<arith::MulIOp>(loc, linearOffset, elemByteSize);
382+
// Final address = basePtr + byteOffset
383+
Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
384+
loc, basePtr,
385+
getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
386+
byteOffset));
387+
// Convert base pointer (i64) to LLVM pointer type.
388+
Value finalPtrLLVM =
389+
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
390+
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
391+
Value src = adaptor.getValue();
392+
// If store value is a scalar, get value from op instead of adaptor.
393+
// Adaptor might have optimized away single element vector
394+
if (src.getType().isIntOrFloat()) {
395+
src = op.getValue();
396+
}
397+
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
398+
if (!srcVecTy)
399+
return rewriter.notifyMatchFailure(
400+
op, "Expected store value to be a vector type.");
401+
// Get flat vector type of integer type with matching element bit size.
402+
VectorType newSrcVecTy =
403+
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
404+
if (srcVecTy != newSrcVecTy)
405+
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
406+
auto storeCacheControl =
407+
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
408+
rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
409+
op, finalPtrLLVM, src,
410+
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
411+
} else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
412+
auto loadCacheControl =
413+
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
414+
VectorType resTy = cast<VectorType>(op.getValue().getType());
415+
VectorType loadedTy =
416+
encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
417+
Value load = xevm::BlockLoadOp::create(
418+
rewriter, loc, loadedTy, finalPtrLLVM,
419+
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
420+
if (loadedTy != resTy)
421+
load = vector::BitCastOp::create(rewriter, loc, resTy, load);
422+
rewriter.replaceOp(op, load);
423+
} else {
424+
return rewriter.notifyMatchFailure(
425+
op, "Unsupported operation: xegpu.prefetch_nd with tensor "
426+
"descriptor rank == 1");
427+
}
363428
}
364429
return success();
365430
}

0 commit comments

Comments
 (0)