259 changes: 165 additions & 94 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -186,9 +186,6 @@ class CreateNdDescToXeVMPattern
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
// Descriptor shape is expected to be 1D or 2D.
int64_t rank = mixedSizes.size();
if (rank != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");

auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -199,8 +196,19 @@ class CreateNdDescToXeVMPattern
}
baseAddr =
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
// Cast index to i64.
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
} else {
baseAddr = adaptor.getSource();
if (baseAddr.getType() != i64Ty) {
// Pointer type may be i32. Cast to i64 if needed.
baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
}
}
// 1D tensor descriptor is just the base address.
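// E.g. !xegpu.tensor_desc<64xf16> lowers to only the i64 base address;
// no shape/offset payload is built for rank-1 descriptors.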
if (rank == 1) {
rewriter.replaceOp(op, baseAddr);
return success();
}
// Utility for creating offset values from op fold result.
auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -215,13 +223,6 @@ class CreateNdDescToXeVMPattern
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
if (sourceMemrefTy) {
// Cast index to i64.
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
} else if (baseAddr.getType() != i64Ty) {
// Pointer type may be i32. Cast to i64 if needed.
baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
}
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
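A minimal sketch of the new rank-1 creation path (illustrative only; the xegpu op syntax is assumed from existing XeGPU tests rather than taken from this patch, and %src is a hypothetical input):

    // A 1D descriptor carries no shape/offset payload; it lowers to the
    // bare base address.
    %tdesc = xegpu.create_nd_tdesc %src : memref<64xf16> -> !xegpu.tensor_desc<64xf16>
    // After conversion, roughly:
    //   %intptr = memref.extract_aligned_pointer_as_index %src : memref<64xf16> -> index
    //   %addr   = arith.index_castui %intptr : index to i64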
@@ -257,108 +258,175 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
ConversionPatternRewriter &rewriter) const override {
auto mixedOffsets = op.getMixedOffsets();
int64_t opOffsetsSize = mixedOffsets.size();
if (opOffsetsSize != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();

auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
if (tdescTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
auto tileRank = tdescTy.getRank();
if (opOffsetsSize != tileRank)
return rewriter.notifyMatchFailure(
op, "Expected offset rank to match descriptor rank.");
auto elemType = tdescTy.getElementType();
auto elemBitSize = elemType.getIntOrFloatBitWidth();
if (elemBitSize % 8 != 0)
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");

VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
Value basePtr = vector::ExtractOp::create(
rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
Value baseShapeW = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
// Offsets are provided by the op; convert them to i32.
Value offsetW =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offsetW);
Value offsetH =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offsetH);
// Get address space from tensor descriptor memory space.
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// Compute element byte size and surface width in bytes.
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
Value surfaceW =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);

// Get tile sizes and vblocks from the tensor descriptor type.
auto tileW = tdescTy.getDimSize(1);
auto tileH = tdescTy.getDimSize(0);
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
Value src = adaptor.getValue();
// If the store value is a scalar, get it from the op instead of the adaptor;
// the adaptor may have optimized away a single-element vector.
if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
// Get a flat integer vector type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
auto loadCacheControl =
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
if (tileRank == 2) {
// Compute element byte size.
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
Value basePtr =
vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
static_cast<int>(NdTdescOffset::BasePtr));
Value baseShapeW = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
// Offsets are provided by the op; convert them to i32.
Value offsetW =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offsetW);
Value offsetH =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offsetH);
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// Compute width in bytes.
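// E.g. a row of 32 f32 elements gives surfaceW = 32 * 4 = 128 bytes.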
Value surfaceW =
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);

// Get tile sizes and vblocks from the tensor descriptor type.
auto tileW = tdescTy.getDimSize(tileRank - 1);
auto tileH = tdescTy.getDimSize(0);
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
Value src = adaptor.getValue();
// If the store value is a scalar, get it from the op instead of the adaptor;
// the adaptor may have optimized away a single-element vector.
if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
// Get a flat integer vector type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
offsetH, elemBitSize, tileW, tileH, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
rewriter.eraseOp(op);
} else {
VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
const bool vnni = op.getPacked().value_or(false);
auto transposeValue = op.getTranspose();
bool transpose =
transposeValue.has_value() && transposeValue.value()[0] == 1;
VectorType loadedTy = encodeVectorTypeTo(
dstVecTy, vnni ? rewriter.getI32Type()
: rewriter.getIntegerType(elemBitSize));

Value resultFlatVec = xevm::BlockLoad2dOp::create(
rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
transpose, vnni,
auto loadCacheControl =
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
xevm::BlockPrefetch2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
rewriter.eraseOp(op);
} else {
VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
const bool vnni = op.getPacked().value_or(false);
auto transposeValue = op.getTranspose();
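// E.g. transpose = [1, 0] requests a transposed 2D block load.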
bool transpose =
transposeValue.has_value() && transposeValue.value()[0] == 1;
VectorType loadedTy = encodeVectorTypeTo(
dstVecTy, vnni ? rewriter.getI32Type()
: rewriter.getIntegerType(elemBitSize));

Value resultFlatVec = xevm::BlockLoad2dOp::create(
rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
transpose, vnni,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
resultFlatVec = vector::BitCastOp::create(
rewriter, loc,
encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
resultFlatVec);
rewriter.replaceOp(op, resultFlatVec);
}
}
} else {
// 1D tensor descriptor: `tdesc` is the base address as i64.
// The op's offset is in elements, so scale it by the element byte size:
//   byteOffset = offset * elemByteSize
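// E.g. for f16 (elemBitSize == 16) at element offset 8:
//   byteOffset = 8 * 2 = 16, finalAddr = basePtr + 16.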
Value offset =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
offset = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI64Type(), offset);
// Compute element byte size.
Value elemByteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
Value byteOffset =
rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
// Final address = basePtr + byteOffset
Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
loc, tdesc,
getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
byteOffset));
// Convert base pointer (i64) to LLVM pointer type.
Value finalPtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
Value src = adaptor.getValue();
// If the store value is a scalar, get it from the op instead of the adaptor;
// the adaptor may have optimized away a single-element vector.
if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
// Get a flat integer vector type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
auto storeCacheControl =
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
op, finalPtrLLVM, src,
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
} else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
auto loadCacheControl =
translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
VectorType resTy = cast<VectorType>(op.getValue().getType());
VectorType loadedTy =
encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
Value load = xevm::BlockLoadOp::create(
rewriter, loc, loadedTy, finalPtrLLVM,
xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
if (loadedTy != resTy)
load = vector::BitCastOp::create(rewriter, loc, resTy, load);
rewriter.replaceOp(op, load);
} else {
return rewriter.notifyMatchFailure(
op, "Unsupported operation: xegpu.prefetch_nd with tensor "
"descriptor rank == 1");
}
}
return success();
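For reference, a sketch of the rank-1 load path above (illustrative only; the xevm assembly spellings are assumed from the C++ builder names, and %tdesc/%off are hypothetical values):

    // %tdesc is the i64 base address; the element offset is scaled to bytes.
    //   %off64 = arith.index_castui %off : index to i64
    //   %bytes = arith.constant 2 : i64              // f16: elemBitSize / 8
    //   %boff  = arith.muli %off64, %bytes : i64
    //   %addr  = arith.addi %tdesc, %boff : i64
    //   %ptr   = llvm.inttoptr %addr : i64 to !llvm.ptr<1>
    //   %raw   = xevm.blockload %ptr : (!llvm.ptr<1>) -> vector<16xi16>
    //   %res   = vector.bitcast %raw : vector<16xi16> to vector<16xf16>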
@@ -927,7 +995,10 @@ struct ConvertXeGPUToXeVMPass
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
// Scattered descriptors are not supported in XeVM lowering.
if (type.isScattered())
return {};
if (type.getRank() == 1)
return IntegerType::get(&getContext(), 64);
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
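In summary, the tensor descriptor type conversion now maps (a sketch using the element types from the tests):

    // !xegpu.tensor_desc<16xf16>   -> i64            (rank 1: bare base address)
    // !xegpu.tensor_desc<8x16xf16> -> vector<8xi32>  (rank 2: full payload)
    // scattered tensor_desc        -> no conversion (rejected by this lowering)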
4 changes: 2 additions & 2 deletions mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -29,13 +29,13 @@ gpu.module @create_nd_tdesc

// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
// CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
@@ -53,11 +53,11 @@ gpu.module @create_nd_tdesc {
%BLOCK_DMODEL = arith.constant 16 : index
// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
// CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
// CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
// CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
// CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
// CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
// CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>