170 changes: 34 additions & 136 deletions mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
   matchAndRewrite(xegpu::CreateNdDescOp op,
                   xegpu::CreateNdDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.size() != 0)
+      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
     auto loc = op.getLoc();
     auto source = op.getSource();
     // Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     if (rank != 2)
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets can be either 2D or not provided (0 is used).
-    if (mixedOffsets.size() == 2) {
-      offsetW = createOffset(mixedOffsets, 1);
-      offsetH = createOffset(mixedOffsets, 0);
-    } else if (mixedOffsets.size() == 0) {
-      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    } else {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    }
+    // Offsets are not supported (0 is used).
+    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
   }
 };
 
-class UpdateNdOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
-                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto mixedOffsets = op.getMixedOffsets();
-    // Only 2D offsets are supported for now.
-    if (mixedOffsets.size() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto payload = adaptor.getTensorDesc();
-    // Utility for updating payload offset values from op fold result.
-    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
-      Value offset =
-          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
-      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                               rewriter.getI32Type(), offset);
-      Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
-      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
-                                      payloadPos);
-    };
-    // Update offsets in the payload.
-    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, payload);
-    return success();
-  }
-};
-
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto mixedOffsets = op.getMixedOffsets();
+    int64_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
@@ -311,32 +276,16 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets provided in two ways:
-    // 1. Offsets are extracted from the tensor descriptor.
-    // 2. (Mixed) offsets which are provided by the op.
-    Value offsetW;
-    Value offsetH;
-    auto mixedOffsets = op.getMixedOffsets();
-    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 0 && opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    if (opOffsetsSize) {
-      // If mixed offsets are provided by the op convert them to i32.
-      offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetW);
-      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetH);
-    } else {
-      // If offsets are not available, we need to extract them from the tensor
-      // descriptor.
-      offsetW = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
-      offsetH = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    }
+    // Offsets are provided by the op.
+    // Convert them to i32.
+    Value offsetW =
+        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetW);
+    Value offsetH =
+        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
@@ -422,54 +371,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
   return newAddr;
 }
 
-class CreateDescToXeVMPattern
-    : public OpConversionPattern<xegpu::CreateDescOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Offsets are provided as scalar i64 by type converter.
-    auto offsets = adaptor.getOffsets();
-    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
-    // But type converter will convert them to integer types.
-    Value addr = adaptor.getSource();
-    // ui32 or i32 are passed as i32 so they need to be casted to i64.
-    if (addr.getType() != rewriter.getI64Type())
-      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
-    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
-    rewriter.replaceOp(op, laneAddr);
-    return success();
-  }
-};
-
-class UpdateOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateOffsetOp op,
-                  xegpu::UpdateOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Scatter descriptor is provided as scalar i64 by type converter.
-    // Offsets are provided as scalar i64 by type converter.
-    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
-                                adaptor.getOffsets(), eBw / 8);
-    rewriter.replaceOp(op, newOffset);
-    return success();
-  }
-};
-
 template <typename OpType,
           typename = std::enable_if_t<llvm::is_one_of<
               OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
@@ -478,6 +379,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Value offset = adaptor.getOffsets();
+    if (!offset)
+      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
@@ -527,21 +431,16 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                           basePtrI64);
     }
-    Value offsets = adaptor.getOffsets();
     Value mask = adaptor.getMask();
-    if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())) {
-        // Offset needs be scalar. Single element vector is converted to scalar
-        // by type converter.
-        return rewriter.notifyMatchFailure(op,
-                                           "Expected offsets to be a scalar.");
-      } else {
-        // If offsets are provided, we add them to the base pointer.
-        // Offsets are in number of elements, we need to multiply by
-        // element byte size.
-        basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
-      }
+    if (dyn_cast<VectorType>(offset.getType())) {
+      // The offset needs to be a scalar; a single-element vector is converted
+      // to a scalar by the type converter.
+      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
+    } else {
+      // If an offset is provided, we add it to the base pointer.
+      // The offset is in number of elements, so we need to multiply by the
+      // element byte size.
+      basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
 //===----------------------------------------------------------------------===//
 void mlir::populateXeGPUToXeVMConversionPatterns(
     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+  patterns.add<CreateNdDescToXeVMPattern,
                LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
                LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
                LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
       typeConverter, patterns.getContext());
-  patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
-               AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
                LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
                LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
       typeConverter, patterns.getContext());
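Taken together, the .cpp changes move offset handling onto the access ops: CreateNdDescToXeVMPattern now rejects offsets at descriptor creation, the nd load/store/prefetch pattern requires explicit 2D offsets, the scattered load/store pattern requires an explicit scalar offset operand, and the UpdateNdOffset, CreateDesc, and UpdateOffset patterns are removed. A minimal sketch of input IR in the form this lowering now expects; the module and function names, constants, and shapes are illustrative assumptions, not taken from this patch:

// Sketch: offsets appear on the access op, not on create_nd_tdesc.
gpu.module @offset_sketch {
  gpu.func @load_example(%src: memref<16x32xf32>) kernel {
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    // No offsets on descriptor creation.
    %tdesc = xegpu.create_nd_tdesc %src : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
    // 2D offsets are now required on the access op itself.
    %v = xegpu.load_nd %tdesc[%c0, %c8] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
    gpu.return
  }
}

The same convention applies to xegpu.store_nd and xegpu.prefetch_nd, and the gather/scatter path likewise expects the offset as an operand of xegpu.load / xegpu.store rather than baked into a descriptor.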
32 changes: 0 additions & 32 deletions mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -43,38 +43,6 @@ gpu.module @create_nd_tdesc {
     // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
     // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
     %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-    // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-    // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
-    // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
-    // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-    // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
-    // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
-    // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
-    // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
-    // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
-    // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-    // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
-    // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
-    // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
-    // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
-    // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
-    // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-    %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-    // CHECK: %[[C8:.*]] = arith.constant 8 : index
-    %c8 = arith.constant 8 : index
-    // CHECK: %[[C16:.*]] = arith.constant 16 : index
-    %c16 = arith.constant 16 : index
-    // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
-    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
-    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
-    // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
-    // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
-    // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
-    // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
-    // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
-    %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
     gpu.return
   }
 }
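For reference, these are the two op forms exercised by the deleted checks, copied verbatim from the removed test code; with their patterns gone, inputs like these should now fail to legalize under this pass rather than lower to payload updates:

%src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>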