Skip to content

Commit 236343e

Browse files
committed
Temp save.
1 parent d88d676 commit 236343e

File tree

4 files changed

+99
-81
lines changed

4 files changed

+99
-81
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 86 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,13 @@ using namespace mlir;
3939

4040
namespace {
4141

42-
enum class NdDescI32Layout : uint32_t {
43-
BasePtr = 0,
44-
BaseShapeW = 2,
45-
BaseShapeH = 3,
46-
TensorOffsetW = 4,
47-
TensorOffsetH = 5
42+
// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
43+
enum class NdTdescOffset : uint32_t {
44+
BasePtr = 0, // Base pointer (i64)
45+
BaseShapeW = 2, // Base shape width (i32)
46+
BaseShapeH = 3, // Base shape height (i32)
47+
TensorOffsetW = 4, // Tensor offset W (i32)
48+
TensorOffsetH = 5 // Tensor offset H (i32)
4849
};
4950

5051
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -57,6 +58,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
5758
llvm_unreachable("Unknown XeGPU memory space.");
5859
}
5960

61+
// Get same bitwidth flat vector type of new element type.
6062
static VectorType encodeVectorTypeTo(VectorType currentVecType,
6163
Type toElemType) {
6264
auto elemType = currentVecType.getElementType();
@@ -221,20 +223,20 @@ class CreateNdDescToXeVMPattern
221223
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
222224
payLoadAsI64 =
223225
vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
224-
static_cast<int>(NdDescI32Layout::BasePtr));
226+
static_cast<int>(NdTdescOffset::BasePtr));
225227
payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
226228
payload =
227229
vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
228-
static_cast<int>(NdDescI32Layout::BaseShapeW));
230+
static_cast<int>(NdTdescOffset::BaseShapeW));
229231
payload =
230232
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
231-
static_cast<int>(NdDescI32Layout::BaseShapeH));
233+
static_cast<int>(NdTdescOffset::BaseShapeH));
232234
payload = vector::InsertOp::create(
233235
rewriter, loc, offsetW, payload,
234-
static_cast<int>(NdDescI32Layout::TensorOffsetW));
236+
static_cast<int>(NdTdescOffset::TensorOffsetW));
235237
payload = vector::InsertOp::create(
236238
rewriter, loc, offsetH, payload,
237-
static_cast<int>(NdDescI32Layout::TensorOffsetH));
239+
static_cast<int>(NdTdescOffset::TensorOffsetH));
238240
rewriter.replaceOp(op, payload);
239241
return success();
240242
}
@@ -249,6 +251,7 @@ class UpdateNdOffsetToXeVMPattern
249251
ConversionPatternRewriter &rewriter) const override {
250252
auto loc = op.getLoc();
251253
auto mixedOffsets = op.getMixedOffsets();
254+
// Only 2D offsets are supported for now.
252255
if (mixedOffsets.size() != 2)
253256
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
254257
auto tdesc = adaptor.getTensorDesc();
@@ -264,9 +267,9 @@ class UpdateNdOffsetToXeVMPattern
264267
return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
265268
payloadPos);
266269
};
267-
auto val =
268-
updateOffset(0, static_cast<int>(NdDescI32Layout::TensorOffsetH));
269-
val = updateOffset(1, static_cast<int>(NdDescI32Layout::TensorOffsetW));
270+
// Update offsets in the payload.
271+
auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
272+
val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
270273
rewriter.replaceOp(op, val);
271274
return success();
272275
}
@@ -293,86 +296,74 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
293296
VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
294297
Value payLoadAsI64 =
295298
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
296-
Value basePtr =
297-
vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
298-
static_cast<int>(NdDescI32Layout::BasePtr));
299+
Value basePtr = vector::ExtractOp::create(
300+
rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
299301
Value baseShapeW = vector::ExtractOp::create(
300-
rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeW));
302+
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
301303
Value baseShapeH = vector::ExtractOp::create(
302-
rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeH));
303-
// Offsets can come from three sources:
304-
// 1. Constant offsets, which are provided by the op.
305-
// 2. Offsets as operands, which are provided by the op.
306-
// 3. Offsets extracted from the tensor descriptor.
304+
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
305+
// Offsets provided in two ways:
306+
// 1. Offsets are extracted from the tensor descriptor.
307+
// 2. (Mixed) offsets which are provided by the op.
307308
Value offsetW;
308309
Value offsetH;
309-
auto cOffsets = op.getConstOffsets();
310-
auto offsets = op.getOffsets();
311-
if (cOffsets) {
312-
offsetW = arith::ConstantIntOp::create(
313-
rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]);
314-
offsetH = arith::ConstantIntOp::create(
315-
rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]);
316-
} else if (offsets.size() != 0) {
317-
// offsets are provided as operands
318-
if (offsets[0].getType() != rewriter.getI32Type()) {
319-
if (offsets[0].getType() != rewriter.getIndexType()) {
320-
return rewriter.notifyMatchFailure(
321-
op, "Expected offsets to be of type i32 or index.");
322-
}
323-
offsetW = arith::IndexCastUIOp::create(
324-
rewriter, loc, rewriter.getI32Type(), offsets[0]);
325-
} else {
326-
offsetW = offsets[0];
327-
}
328-
if (offsets[1].getType() != rewriter.getI32Type()) {
329-
if (offsets[1].getType() != rewriter.getIndexType()) {
330-
return rewriter.notifyMatchFailure(
331-
op, "Expected offsets to be of type i32 or index.");
332-
}
333-
offsetH = arith::IndexCastUIOp::create(
334-
rewriter, loc, rewriter.getI32Type(), offsets[1]);
335-
} else {
336-
offsetH = offsets[1];
337-
}
310+
auto mixedOffsets = op.getMixedOffsets();
311+
int64_t opOffsetsSize = mixedOffsets.size();
312+
if (opOffsetsSize != 0 && opOffsetsSize != 2) {
313+
return rewriter.notifyMatchFailure(op,
314+
"Expected 2D offsets or no offsets.");
315+
}
316+
if (opOffsetsSize) {
317+
// If mixed offsets are provided by the op convert them to i32.
318+
offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
319+
offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
320+
rewriter.getI32Type(), offsetW);
321+
offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
322+
offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
323+
rewriter.getI32Type(), offsetH);
338324
} else {
339325
// If offsets are not available, we need to extract them from the tensor
340326
// descriptor.
341327
offsetW = vector::ExtractOp::create(
342-
rewriter, loc, tdesc,
343-
static_cast<int>(NdDescI32Layout::TensorOffsetW));
328+
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
344329
offsetH = vector::ExtractOp::create(
345-
rewriter, loc, tdesc,
346-
static_cast<int>(NdDescI32Layout::TensorOffsetH));
330+
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
347331
}
332+
// Get address space from tensor descriptor memory space.
348333
auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
349334
ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
335+
// Convert base pointer (i64) to LLVM pointer type.
350336
Value basePtrLLVM =
351337
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
338+
// Compute element byte size and surface width in bytes.
352339
auto elemType = tdescTy.getElementType();
353340
auto elemBitSize = elemType.getIntOrFloatBitWidth();
354341
Value elemByteSize = arith::ConstantIntOp::create(
355342
rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
356343
Value surfaceW =
357344
arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
358345

346+
// Get tile sizes and vblocks from the tensor descriptor type.
359347
auto tileW = tdescTy.getDimSize(1);
360348
auto tileH = tdescTy.getDimSize(0);
361349
int32_t vblocks = tdescTy.getArrayLength();
362350
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
363-
VectorType srcVecTy = cast<VectorType>(op.getValue().getType());
351+
VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
352+
if (!srcVecTy) {
353+
return rewriter.notifyMatchFailure(
354+
op, "Expected store value to be a vector type.");
355+
}
364356
auto storeCacheControl =
365357
translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
366-
VectorType srcFlatVecTy =
367-
VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType());
368-
Value srcFlatVec = op.getValue();
369-
srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy,
370-
rewriter.getIntegerType(elemBitSize));
371-
srcFlatVec =
372-
vector::BitCastOp::create(rewriter, loc, srcFlatVecTy, srcFlatVec);
358+
Value src = adaptor.getValue();
359+
// Get flat vector type of integer type with matching element bit size.
360+
VectorType newSrcVecTy =
361+
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
362+
if (srcVecTy != newSrcVecTy)
363+
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
373364
xevm::BlockStore2dOp::create(
374365
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
375-
offsetH, elemBitSize, tileW, tileH, srcFlatVec,
366+
offsetH, elemBitSize, tileW, tileH, src,
376367
xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
377368
rewriter.eraseOp(op);
378369
} else {
@@ -412,15 +403,14 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
412403

413404
// Add a builder that creates
414405
// offset * elemByteSize + baseAddr
415-
static auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
416-
Value baseAddr, Value offset,
417-
int64_t elemByteSize) -> Value {
406+
static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
407+
Value baseAddr, Value offset, int64_t elemByteSize) {
418408
Value byteSize = arith::ConstantIntOp::create(
419409
rewriter, loc, rewriter.getI64Type(), elemByteSize);
420410
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
421411
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
422412
return newAddr;
423-
};
413+
}
424414

425415
class CreateDescToXeVMPattern
426416
: public OpConversionPattern<xegpu::CreateDescOp> {
@@ -908,6 +898,10 @@ struct ConvertXeGPUToXeVMPass
908898
return IntegerType::get(&getContext(), 64);
909899
});
910900

901+
// LLVM type converter puts unrealized casts for the following cases:
902+
// add materialization casts to handle them.
903+
904+
// Materialization to convert memref to i64
911905
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
912906
ValueRange inputs,
913907
Location loc) -> Value {
@@ -924,6 +918,7 @@ struct ConvertXeGPUToXeVMPass
924918
return {};
925919
};
926920

921+
// Materialization to convert ui64 to i64
927922
auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
928923
ValueRange inputs,
929924
Location loc) -> Value {
@@ -940,6 +935,7 @@ struct ConvertXeGPUToXeVMPass
940935
return {};
941936
};
942937

938+
// Materialization to convert ui32 to i32
943939
auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
944940
ValueRange inputs,
945941
Location loc) -> Value {
@@ -956,9 +952,13 @@ struct ConvertXeGPUToXeVMPass
956952
return {};
957953
};
958954

959-
auto vector1DMaterializationCast = [](OpBuilder &builder, Type type,
960-
ValueRange inputs,
961-
Location loc) -> Value {
955+
// Materialization to convert
956+
// - single element 1D vector to scalar
957+
// - bitcast vector of same rank
958+
// - shape vector of different rank but same element type
959+
auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
960+
ValueRange inputs,
961+
Location loc) -> Value {
962962
if (inputs.size() != 1)
963963
return {};
964964
auto input = inputs.front();
@@ -971,18 +971,30 @@ struct ConvertXeGPUToXeVMPass
971971
cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
972972
.getResult();
973973
return cast;
974+
} else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
975+
// If the target type is a vector of same rank,
976+
// bitcast to the target type.
977+
if (targetVecTy.getRank() == vecTy.getRank())
978+
return vector::BitCastOp::create(builder, loc, targetVecTy, input)
979+
.getResult();
980+
else if (targetVecTy.getElementType() == vecTy.getElementType()) {
981+
// If the target type is a vector of different rank but same element
982+
// type, reshape to the target type.
983+
return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
984+
.getResult();
985+
}
974986
}
975987
}
976988
return {};
977989
};
978990
typeConverter.addSourceMaterialization(memrefMaterializationCast);
979991
typeConverter.addSourceMaterialization(ui64MaterializationCast);
980992
typeConverter.addSourceMaterialization(ui32MaterializationCast);
981-
typeConverter.addSourceMaterialization(vector1DMaterializationCast);
993+
typeConverter.addSourceMaterialization(vectorMaterializationCast);
982994
typeConverter.addTargetMaterialization(memrefMaterializationCast);
983995
typeConverter.addTargetMaterialization(ui32MaterializationCast);
984996
typeConverter.addTargetMaterialization(ui64MaterializationCast);
985-
typeConverter.addTargetMaterialization(vector1DMaterializationCast);
997+
typeConverter.addTargetMaterialization(vectorMaterializationCast);
986998
ConversionTarget target(getContext());
987999
target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
9881000
vector::VectorDialect, arith::ArithDialect,

mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ gpu.module @load_store_check {
2020
//CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
2121
//CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
2222
//CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
23-
//CHECK: %[[LD_TILE_W:.*]] = arith.constant 0 : i32
24-
//CHECK: %[[LD_TILE_H:.*]] = arith.constant 0 : i32
23+
//CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
24+
//CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
25+
//CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
26+
//CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
2527
//CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
2628
//CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
2729
//CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
@@ -54,8 +56,10 @@ gpu.module @load_store_check {
5456
//CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
5557
//CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
5658
//CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
57-
//CHECK: %[[TILE_W:.*]] = arith.constant 0 : i32
58-
//CHECK: %[[TILE_H:.*]] = arith.constant 0 : i32
59+
//CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
60+
//CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
61+
//CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
62+
//CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
5963
//CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
6064
//CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
6165
//CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32

mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ gpu.module @materializecast {
6666
%mask = arith.constant dense<1>: vector<1xi1>
6767
%offset = arith.constant dense<0> : vector<1xindex>
6868
%0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
69-
: memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1x8xf32>
69+
: memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
7070
gpu.return
7171
}
7272
}

mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ gpu.module @fence_check {
2020
//CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
2121
//CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
2222
//CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
23-
//CHECK: %[[PREF_TILE_W:.*]] = arith.constant 0 : i32
24-
//CHECK: %[[PREF_TILE_H:.*]] = arith.constant 0 : i32
23+
//CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64
24+
//CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32
25+
//CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
26+
//CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
2527
//CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
2628
//CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
2729
//CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32

0 commit comments

Comments
 (0)