mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (189 changes: 83 additions & 106 deletions)
@@ -48,15 +48,6 @@ namespace {
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};

// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
BasePtr = 0, // Base pointer (i64)
BaseShapeW = 2, // Base shape width (i32)
BaseShapeH = 3, // Base shape height (i32)
TensorOffsetW = 4, // Tensor offset W (i32)
TensorOffsetH = 5 // Tensor offset H (i32)
};

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
switch (xeGpuMemspace) {
case xegpu::MemorySpace::Global:
@@ -151,6 +142,22 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
}
}

// Compute the product of sizes in the range [lo, hi) from the sizes array.
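// e.g., for sizes {4, 8, 16}, the range [0, 2) yields the i64 value 4 * 8 = 32.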
static Value getProductOfSizes(ConversionPatternRewriter &rewriter,
Location loc, ArrayRef<OpFoldResult> sizes,
size_t lo, size_t hi) {
Value product =
arith::ConstantIntOp::create(rewriter, loc, rewriter.getI64Type(), 1);
for (size_t idx = lo; idx < hi; idx++) {
OpFoldResult ofr = sizes[idx];
Value sizeVal = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
sizeVal = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI64Type(), sizeVal);
product = rewriter.createOrFold<arith::MulIOp>(loc, product, sizeVal);
}
return product;
}

class CreateNdDescToXeVMPattern
: public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern::OpConversionPattern;
@@ -162,86 +169,14 @@ class CreateNdDescToXeVMPattern
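// This pattern now lowers CreateNdDescOp to just its i64 base address; the nd
// load/store/prefetch patterns below recover shape, stride, and offset
// information directly from the producing CreateNdDescOp.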
if (mixedOffsets.size() != 0)
return rewriter.notifyMatchFailure(op, "Offsets not supported.");
auto loc = op.getLoc();
auto source = op.getSource();
// Op is lowered to a code sequence that populates payload.
// Payload is an 8xi32 vector. Offsets to individual fields are defined in
// NdTdescOffset enum.
Type payloadElemTy = rewriter.getI32Type();
VectorType payloadTy = VectorType::get(8, payloadElemTy);
Type i64Ty = rewriter.getI64Type();
// 4xi64 view is used for inserting the base pointer.
VectorType payloadI64Ty = VectorType::get(4, i64Ty);
// Initialize payload to zero.
Value payload = arith::ConstantOp::create(
rewriter, loc,
DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

Value baseAddr;
Value baseShapeW;
Value baseShapeH;
Value offsetW;
Value offsetH;

// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
if (rank != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");

auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
// Pointer type is passed as i32 or i64 by type converter.
if (sourceMemrefTy) {
if (!sourceMemrefTy.hasStaticShape()) {
return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
}
baseAddr =
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
} else {
baseAddr = adaptor.getSource();
}
// Utility for creating offset values from op fold result.
auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
unsigned idx) -> Value {
Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
// Offsets are not supported (0 is used).
offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
// Get shape values from op fold results.
baseShapeW = createOffset(mixedSizes, 1);
baseShapeH = createOffset(mixedSizes, 0);
if (sourceMemrefTy) {
// Cast index to i64.
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
} else if (baseAddr.getType() != i64Ty) {
Value baseAddr = adaptor.getSource();
Type i64Ty = rewriter.getI64Type();
if (baseAddr.getType() != i64Ty) {
// Pointer type may be i32. Cast to i64 if needed.
baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
}
// Populate payload.
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
payLoadAsI64 =
vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
static_cast<int>(NdTdescOffset::BasePtr));
payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
payload =
vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
static_cast<int>(NdTdescOffset::BaseShapeW));
payload =
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
static_cast<int>(NdTdescOffset::BaseShapeH));
payload = vector::InsertOp::create(
rewriter, loc, offsetW, payload,
static_cast<int>(NdTdescOffset::TensorOffsetW));
payload = vector::InsertOp::create(
rewriter, loc, offsetH, payload,
static_cast<int>(NdTdescOffset::TensorOffsetH));
rewriter.replaceOp(op, payload);
rewriter.replaceOp(op, baseAddr);
return success();
}
};
@@ -255,14 +190,24 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto tdVal = op.getTensorDesc();
xegpu::CreateNdDescOp descOp =
tdVal.template getDefiningOp<xegpu::CreateNdDescOp>();
if (!descOp)
return rewriter.notifyMatchFailure(
op, "Expected tensor descriptor to be created by CreateNdDescOp.");
auto mixedStrides = descOp.getMixedStrides();
auto mixedOffsets = op.getMixedOffsets();
int64_t opOffsetsSize = mixedOffsets.size();
if (opOffsetsSize != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto mixedSizes = descOp.getMixedSizes();
size_t opOffsetsSize = mixedOffsets.size();
if (opOffsetsSize != mixedStrides.size())
return rewriter.notifyMatchFailure(
op, "Offsets size should match base memory rank.");
if (opOffsetsSize < 2)
return rewriter.notifyMatchFailure(op, "Expected at least 2D offset.");
auto loc = op.getLoc();
auto ctxt = rewriter.getContext();

auto tdesc = adaptor.getTensorDesc();
auto tdescTy = op.getTensorDescType();
if (tdescTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
@@ -272,23 +217,58 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
return rewriter.notifyMatchFailure(
op, "Expected element type bit width to be multiple of 8.");

VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
Value payLoadAsI64 =
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
Value basePtr = vector::ExtractOp::create(
rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
Value baseShapeW = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
Value baseShapeH = vector::ExtractOp::create(
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
Value basePtr = adaptor.getTensorDesc();
// Utility for creating offset values from op fold result.
Type payloadElemTy = rewriter.getIntegerType(32);
auto createOffset = [&](OpFoldResult ofr) -> Value {
Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
return val;
};
auto srcRank = mixedSizes.size();
// Get shape values from op fold results.
Value baseShapeW = createOffset(mixedSizes[srcRank - 1]);
Value baseShapeH;
if (srcRank == 2) {
baseShapeH = createOffset(mixedSizes[0]);
} else {
// Generate compute chain for height (product of sizes of all but the last
// dimension).
baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
baseShapeH = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy,
baseShapeH);
}
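// e.g., a 4x8x16 source is treated as a 32x16 2D block: W = 16 (innermost
// dim), H = 4 * 8 (product of the leading dims).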
// Offsets are provided by the op.
// Convert them to i32.
Value offsetW =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
// Offset computation assumes base memory layout is row major.
Value offsetW = getValueOrCreateConstantIntOp(
rewriter, loc, mixedOffsets[opOffsetsSize - 1]);
offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offsetW);
Value offsetH =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
Value offsetH;
if (opOffsetsSize == 2)
offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
else {
offsetH = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value tmpStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
// offsetH requires linearizing the leading offsets using row-major strides
// derived from the sizes.
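// e.g., for sizes [S0, S1, S2] and offsets [o0, o1, o2] this folds to
// offsetH = o1 + o0 * S1, the row index into the collapsed (S0*S1) x S2 view.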
for (size_t idx = 0; idx < opOffsetsSize - 1; idx++) {
size_t revIdx = opOffsetsSize - 2 - idx;
Value offsetVal =
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[revIdx]);
offsetVal = getValueOrCreateCastToIndexLike(
rewriter, loc, rewriter.getIndexType(), offsetVal);
Value mul =
rewriter.createOrFold<arith::MulIOp>(loc, tmpStride, offsetVal);
Value dimSize =
getValueOrCreateConstantIntOp(rewriter, loc, mixedSizes[revIdx]);
dimSize = getValueOrCreateCastToIndexLike(
rewriter, loc, rewriter.getIndexType(), dimSize);
tmpStride =
rewriter.createOrFold<arith::MulIOp>(loc, tmpStride, dimSize);
offsetH = rewriter.createOrFold<arith::AddIOp>(loc, offsetH, mul);
}
}
offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offsetH);
// Get address space from tensor descriptor memory space.
@@ -927,10 +907,7 @@ struct ConvertXeGPUToXeVMPass
return VectorType::get(sum, elemType);
});
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
if (type.isScattered())
return IntegerType::get(&getContext(), 64);
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
return IntegerType::get(&getContext(), 64);
});
// Convert MemDescType into flattened MemRefType for SLM
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (48 changes: 14 additions & 34 deletions)
@@ -4,45 +4,25 @@ gpu.module @create_nd_tdesc {
// CHECK-LABEL: gpu.func @create_nd_tdesc
// CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
// CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
// CHECK-SAME: %[[ARG8:.*]]: memref<?x?xf16>) kernel {
gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
// CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
// CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
// CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
// CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
%stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
// Optimized away
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>

// CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
// CHECK-NEXT: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
%srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>

// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
// CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
// CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
// CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
// CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
// CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
// Optimized away
%src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK-NEXT: %c1 = arith.constant 1 : index
%c1 = arith.constant 1 : index
// CHECK-NEXT: %c64 = arith.constant 64 : index
%size_x = arith.constant 64 : index
// CHECK-NEXT: %c16 = arith.constant 16 : index
%BLOCK_DMODEL = arith.constant 16 : index
// Optimized away
%dyn_tdesc = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
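// The lowered base-address computation has no users in this test, so nothing
// is emitted for it before the return below.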
// CHECK-NEXT: gpu.return
gpu.return
}
}