187 changes: 126 additions & 61 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -264,8 +264,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {

auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
if (tdescTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
if (elemBitSize % 8 != 0)
@@ -294,71 +292,138 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// Compute element byte size and surface width in bytes.
// Compute element byte size.
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
Value surfaceW =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);

// Get tile sizes and vblocks from the tensor descriptor type.
auto tileW = tdescTy.getDimSize(1);
auto tileH = tdescTy.getDimSize(0);
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
Value src = adaptor.getValue();
// If store value is a scalar, get value from op instead of adaptor.
// Adaptor might have optimized away single element vector
if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
// Get flat vector type of integer type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
auto loadCacheControl =
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
auto tileRank = tdescTy.getRank();
// Get tile width from the tensor descriptor type.
auto tileW = tdescTy.getDimSize(tileRank - 1);
if (tileRank == 2) {
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// Compute width in bytes.
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
Value surfaceW =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);

// Get tile height from the tensor descriptor type.
auto tileH = tdescTy.getDimSize(0);
// Get vblocks from the tensor descriptor type.
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
Value src = adaptor.getValue();
// If store value is a scalar, get value from op instead of adaptor.
// Adaptor might have optimized away single element vector
if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
// Get flat vector type of integer type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
offsetH, elemBitSize, tileW, tileH, vblocks,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
const bool vnni = op.getPacked().value_or(false);
auto transposeValue = op.getTranspose();
bool transpose =
transposeValue.has_value() && transposeValue.value()[0] == 1;
VectorType loadedTy = encodeVectorTypeTo(
dstVecTy, vnni ? rewriter.getI32Type()
: rewriter.getIntegerType(elemBitSize));

Value resultFlatVec = xevm::BlockLoad2dOp::create(
rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
transpose, vnni,
auto loadCacheControl =
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
rewriter.eraseOp(op);
} else {
VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
const bool vnni = op.getPacked().value_or(false);
auto transposeValue = op.getTranspose();
bool transpose =
transposeValue.has_value() && transposeValue.value()[0] == 1;
VectorType loadedTy = encodeVectorTypeTo(
dstVecTy, vnni ? rewriter.getI32Type()
: rewriter.getIntegerType(elemBitSize));

Value resultFlatVec = xevm::BlockLoad2dOp::create(
rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
transpose, vnni,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
resultFlatVec = vector::BitCastOp::create(
rewriter, loc,
encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
resultFlatVec);
rewriter.replaceOp(op, resultFlatVec);
}
}
} else {
// Get address from base address and offsets.
// Offset in number of elements, need to multiply by element byte size.
// Compute linear offset.
// linearOffset = offsetH * baseShapeW + offsetW
Value offsetHInElems =
rewriter.createOrFold<arith::MulIOp>(loc, offsetH, baseShapeW);
Value linearOffset =
rewriter.createOrFold<arith::AddIOp>(loc, offsetHInElems, offsetW);
// Then compute byte offset by multiplying with element byte size.
// byteOffset = linearOffset * elemByteSize
Value byteOffset =
rewriter.createOrFold<arith::MulIOp>(loc, linearOffset, elemByteSize);
// Final address = basePtr + byteOffset
Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
loc, basePtr,
getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
byteOffset));
// Convert base pointer (i64) to LLVM pointer type.
Value finalPtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
Value src = adaptor.getValue();
// If store value is a scalar, get value from op instead of adaptor.
// Adaptor might have optimized away single element vector
if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
// Get flat vector type of integer type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
op, finalPtrLLVM, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
} else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
auto loadCacheControl =
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
VectorType resTy = cast<VectorType>(op.getValue().getType());
VectorType loadedTy =
encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
Value load = xevm::BlockLoadOp::create(
rewriter, loc, loadedTy, finalPtrLLVM,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
resultFlatVec = vector::BitCastOp::create(
rewriter, loc,
encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
resultFlatVec);
rewriter.replaceOp(op, resultFlatVec);
if (loadedTy != resTy)
load = vector::BitCastOp::create(rewriter, loc, resTy, load);
rewriter.replaceOp(op, load);
} else {
return rewriter.notifyMatchFailure(
op, "Unsupported operation: xegpu.prefetch_nd with tensor "
"descriptor rank == 1");
}
}
return success();
57 changes: 57 additions & 0 deletions mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
@@ -0,0 +1,57 @@
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s

gpu.module @load_store_check {
// CHECK-LABEL: @load_store(
// CHECK-SAME: %[[SRC:.*]]: memref<8x64xf32, 1>, %[[DST:.*]]: memref<8x32xf32, 1>
gpu.func @load_store(%src: memref<8x64xf32, 1>, %dst: memref<8x32xf32, 1>) kernel {
// CHECK: %[[C512:.*]] = arith.constant 512 : i64
// CHECK: %[[C32:.*]] = arith.constant 32 : i32
// CHECK: %[[C384:.*]] = arith.constant 384 : i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
// CHECK: %[[C8:.*]] = arith.constant 8 : i32
// CHECK: %[[C64:.*]] = arith.constant 64 : i32
// CHECK: %[[C0:.*]] = arith.constant 0 : i32

// CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<8x64xf32, 1> to memref<8x64xf32>
%srcce = memref.memory_space_cast %src : memref<8x64xf32, 1> to memref<8x64xf32>
// CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<8x32xf32, 1> to memref<8x32xf32>
%dstte = memref.memory_space_cast %dst : memref<8x32xf32, 1> to memref<8x32xf32>

// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<8x64xf32> -> index
// CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VEC1:.*]] = vector.insert %[[INTPTR_I64]], %[[CST]] [0] : i64 into vector<4xi64>
// CHECK: %[[VEC2:.*]] = vector.bitcast %[[VEC1]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VEC3:.*]] = vector.insert %[[C64]], %[[VEC2]] [2] : i32 into vector<8xi32>
// CHECK: %[[VEC4:.*]] = vector.insert %[[C8]], %[[VEC3]] [3] : i32 into vector<8xi32>
// CHECK: %[[VEC5:.*]] = vector.insert %[[C0]], %[[VEC4]] [4] : i32 into vector<8xi32>
// CHECK: %[[VEC6:.*]] = vector.insert %[[C0]], %[[VEC5]] [5] : i32 into vector<8xi32>
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x64xf32> -> !xegpu.tensor_desc<32xf32>
// CHECK: %[[VEC7:.*]] = vector.bitcast %[[VEC6]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[EXTR:.*]] = vector.extract %[[VEC7]][0] : i64 from vector<4xi64>
// CHECK: %[[ADDR:.*]] = arith.addi %[[EXTR]], %[[C384]] : i64
// CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
// CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}>
// CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32>
%loaded = xegpu.load_nd %src_tdesc[1, 32] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
: !xegpu.tensor_desc<32xf32> -> vector<2xf32>

// CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<8x32xf32> -> index
// CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
// CHECK: %[[VEC1_1:.*]] = vector.insert %[[INTPTR1_I64]], %[[CST]] [0] : i64 into vector<4xi64>
// CHECK: %[[VEC2_1:.*]] = vector.bitcast %[[VEC1_1]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VEC3_1:.*]] = vector.insert %[[C32]], %[[VEC2_1]] [2] : i32 into vector<8xi32>
// CHECK: %[[VEC4_1:.*]] = vector.insert %[[C8]], %[[VEC3_1]] [3] : i32 into vector<8xi32>
// CHECK: %[[VEC5_1:.*]] = vector.insert %[[C0]], %[[VEC4_1]] [4] : i32 into vector<8xi32>
// CHECK: %[[VEC6_1:.*]] = vector.insert %[[C0]], %[[VEC5_1]] [5] : i32 into vector<8xi32>
%dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
Contributor: The memref should be 1d only - user needs to view the 2d memref as 1d.

Contributor (Author): That part isn't clear from the xegpu op description.

Contributor: @Jianhui-Li what is the motivation for fixing the memref to 1D?

Contributor: The operation doesn't take care of out-of-bounds access, so a 2d memref doesn't restrict the operation from reading out-of-bounds data; it is better for the user to flatten it to 1d.
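For illustration, a minimal sketch of the flattening the reviewer suggests, assuming memref.collapse_shape is legal on the cast memref at this point (the %flat name is hypothetical):

```mlir
// Collapse the 2d source into a 1d memref so the descriptor addresses a flat buffer.
%flat = memref.collapse_shape %srcce [[0, 1]] : memref<8x64xf32> into memref<512xf32>
// Create the 1d block descriptor over the flattened memref.
%src_tdesc = xegpu.create_nd_tdesc %flat : memref<512xf32> -> !xegpu.tensor_desc<32xf32>
```

The load/store offsets would then be the flattened indices, e.g. [96] (= 1 * 64 + 32) instead of [1, 32].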

// CHECK: %[[VEC7_1:.*]] = vector.bitcast %[[VEC6_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[EXTR1:.*]] = vector.extract %[[VEC7_1]][0] : i64 from vector<4xi64>
// CHECK: %[[ADDR1:.*]] = arith.addi %[[EXTR1]], %[[C512]] : i64
// CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>
// CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}>
// CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>)
xegpu.store_nd %loaded, %dst_tdesc[4, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
Contributor: The offsets should be 1d only, and the vector size (2) should match the tensor_desc size (32) - the IR verifier needs to be enhanced to capture this.

Contributor (Author): This is SIMT-level xegpu code, so the value to write should already have been distributed. Why should the size match the descriptor size? Isn't the descriptor size for the entire subgroup, while the operand type size is per lane?

Contributor: I see. For SIMT level, vector size doesn't have to match with tensor_desc size.
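As a concrete check against the test in this file: the descriptor covers 32 f32 elements for the whole subgroup, so assuming a subgroup size of 16 lanes, each lane holds 32 / 16 = 2 elements, which is why the per-lane operand is vector<2xf32> while the descriptor is !xegpu.tensor_desc<32xf32>.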

: vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
gpu.return
}
}
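For reference, the address constants in the CHECK lines above follow the 1d path's byteOffset = (offsetH * baseShapeW + offsetW) * elemByteSize computation from the conversion pattern: the load at offsets [1, 32] on the 8x64 f32 source gives (1 * 64 + 32) * 4 = 384 (%[[C384]]), and the store at offsets [4, 0] on the 8x32 f32 destination gives (4 * 32 + 0) * 4 = 512 (%[[C512]]).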