Skip to content

Commit 4e4cbd0

Browse files
committed
Replace the 8xi32 2D block-load tensor-descriptor payload with a plain i64 base pointer.
1 parent e510643 commit 4e4cbd0

File tree

5 files changed

+106
-110
lines changed

5 files changed

+106
-110
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 26 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,6 @@ namespace {
4848
static constexpr int32_t systolicDepth{8};
4949
static constexpr int32_t executionSize{16};
5050

51-
// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
52-
enum class NdTdescOffset : uint32_t {
53-
BasePtr = 0, // Base pointer (i64)
54-
BaseShapeW = 2, // Base shape width (i32)
55-
BaseShapeH = 3, // Base shape height (i32)
56-
TensorOffsetW = 4, // Tensor offset W (i32)
57-
TensorOffsetH = 5 // Tensor offset H (i32)
58-
};
59-
6051
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
6152
switch (xeGpuMemspace) {
6253
case xegpu::MemorySpace::Global:
@@ -177,92 +168,14 @@ class CreateNdDescToXeVMPattern
177168
if (mixedOffsets.size() != 0)
178169
return rewriter.notifyMatchFailure(op, "Offsets not supported.");
179170
auto loc = op.getLoc();
180-
auto source = op.getSource();
181-
// Op is lowered to a code sequence that populates payload.
182-
// Payload is a 8xi32 vector. Offset to individual fields are defined in
183-
// NdTdescOffset enum.
184-
Type payloadElemTy = rewriter.getI32Type();
185-
VectorType payloadTy = VectorType::get(8, payloadElemTy);
186-
Type i64Ty = rewriter.getI64Type();
187-
// 4xi64 view is used for inserting the base pointer.
188-
VectorType payloadI64Ty = VectorType::get(4, i64Ty);
189-
// Initialize payload to zero.
190-
Value payload = arith::ConstantOp::create(
191-
rewriter, loc,
192-
DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
193-
194-
Value baseAddr;
195-
Value baseShapeW;
196-
Value baseShapeH;
197-
Value offsetW;
198-
Value offsetH;
199171

200-
// Source can be a memref or a pointer (ui64, ui32, i64 or i32).
201-
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
202-
auto srcRank = mixedSizes.size();
203-
if (srcRank < 2)
204-
return rewriter.notifyMatchFailure(op, "Expected at least 2D source.");
205-
206-
auto sourceTy = source.getType();
207-
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
208-
// If source is a memref, we need to extract the aligned pointer as index.
209-
// Pointer type is passed as i32 or i64 by type converter.
210-
if (sourceMemrefTy) {
211-
if (!sourceMemrefTy.hasStaticShape()) {
212-
return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
213-
}
214-
baseAddr =
215-
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
216-
} else {
217-
baseAddr = adaptor.getSource();
218-
}
219-
// Utility for creating offset values from op fold result.
220-
auto createOffset = [&](OpFoldResult ofr) -> Value {
221-
Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
222-
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
223-
return val;
224-
};
225-
// Offsets are not supported (0 is used).
226-
offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
227-
offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
228-
// Get shape values from op fold results.
229-
baseShapeW = createOffset(mixedSizes[srcRank - 1]);
230-
if (srcRank == 2) {
231-
baseShapeH = createOffset(mixedSizes[0]);
232-
} else {
233-
// Generate compute chain for height (product of sizes of all but the last
234-
// dimension).
235-
baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
236-
baseShapeH = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy,
237-
baseShapeH);
238-
}
239-
if (sourceMemrefTy) {
240-
// Cast index to i64.
241-
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
242-
} else if (baseAddr.getType() != i64Ty) {
172+
Value baseAddr = adaptor.getSource();
173+
Type i64Ty = rewriter.getI64Type();
174+
if (baseAddr.getType() != i64Ty) {
243175
// Pointer type may be i32. Cast to i64 if needed.
244176
baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
245177
}
246-
// Populate payload.
247-
Value payLoadAsI64 =
248-
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
249-
payLoadAsI64 =
250-
vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
251-
static_cast<int>(NdTdescOffset::BasePtr));
252-
payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
253-
payload =
254-
vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
255-
static_cast<int>(NdTdescOffset::BaseShapeW));
256-
payload =
257-
vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
258-
static_cast<int>(NdTdescOffset::BaseShapeH));
259-
payload = vector::InsertOp::create(
260-
rewriter, loc, offsetW, payload,
261-
static_cast<int>(NdTdescOffset::TensorOffsetW));
262-
payload = vector::InsertOp::create(
263-
rewriter, loc, offsetH, payload,
264-
static_cast<int>(NdTdescOffset::TensorOffsetH));
265-
rewriter.replaceOp(op, payload);
178+
rewriter.replaceOp(op, baseAddr);
266179
return success();
267180
}
268181
};
@@ -291,7 +204,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
291204
auto loc = op.getLoc();
292205
auto ctxt = rewriter.getContext();
293206

294-
auto tdesc = adaptor.getTensorDesc();
295207
auto tdescTy = op.getTensorDescType();
296208
if (tdescTy.getRank() != 2)
297209
return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
@@ -301,15 +213,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
301213
return rewriter.notifyMatchFailure(
302214
op, "Expected element type bit width to be multiple of 8.");
303215

304-
VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
305-
Value payLoadAsI64 =
306-
vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
307-
Value basePtr = vector::ExtractOp::create(
308-
rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
309-
Value baseShapeW = vector::ExtractOp::create(
310-
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
311-
Value baseShapeH = vector::ExtractOp::create(
312-
rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
216+
Value basePtr = adaptor.getTensorDesc();
217+
// Utility for creating offset values from op fold result.
218+
Type payloadElemTy = rewriter.getIntegerType(32);
219+
auto createOffset = [&](OpFoldResult ofr) -> Value {
220+
Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
221+
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
222+
return val;
223+
};
224+
auto srcRank = mixedSizes.size();
225+
// Get shape values from op fold results.
226+
Value baseShapeW = createOffset(mixedSizes[srcRank - 1]);
227+
Value baseShapeH;
228+
if (srcRank == 2) {
229+
baseShapeH = createOffset(mixedSizes[0]);
230+
} else {
231+
// Generate compute chain for height (product of sizes of all but the last
232+
// dimension).
233+
baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
234+
baseShapeH = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy,
235+
baseShapeH);
236+
}
313237
// Offsets are provided by the op.
314238
// convert them to i32.
315239
// Offset computation assumes base memory layout is row major.
@@ -979,10 +903,7 @@ struct ConvertXeGPUToXeVMPass
979903
return VectorType::get(sum, elemType);
980904
});
981905
typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
982-
if (type.isScattered())
983-
return IntegerType::get(&getContext(), 64);
984-
auto i32Type = IntegerType::get(&getContext(), 32);
985-
return VectorType::get(8, i32Type);
906+
return IntegerType::get(&getContext(), 64);
986907
});
987908
// Convert MemDescType into flattened MemRefType for SLM
988909
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {

mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
1+
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
22

33
gpu.module @load_store_check {
44
gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {

mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
1+
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
22

33
gpu.module @load_store_check {
44
// CHECK: fail
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
2+
3+
gpu.module @load_store_check {
4+
gpu.func @load_store(%src: ui64, %dst: ui32) kernel {
5+
// CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
6+
// CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
7+
// CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
8+
// CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
9+
// CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
10+
// CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
11+
// CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
12+
// CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
13+
%c8 = arith.constant 8 : index
14+
%c16 = arith.constant 16 : index
15+
%c1 = arith.constant 1 : index
16+
%src_tdesc = xegpu.create_nd_tdesc %src, shape:[%c8, %c16], strides:[%c16, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
17+
18+
19+
//CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
20+
//CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
21+
//CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
22+
//CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
23+
//CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
24+
//CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
25+
//CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
26+
//CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
27+
//CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
28+
//CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
29+
//CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
30+
//CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
31+
//CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
32+
//CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
33+
//CHECK-SAME: pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
34+
//CHECK-SAME: v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
35+
%loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
36+
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
37+
//CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
38+
39+
%tid_x = gpu.thread_id x
40+
%tid_x_i32 = arith.index_cast %tid_x : index to i32
41+
%tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
42+
//CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
43+
%loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
44+
45+
// CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
46+
// CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
47+
// CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
48+
// CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
49+
// CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
50+
// CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
51+
// CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
52+
// CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
53+
%dst_tdesc = xegpu.create_nd_tdesc %dst, shape:[%c8, %c16], strides:[%c16, %c1] : ui32 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
54+
55+
//CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
56+
//CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
57+
//CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
58+
//CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
59+
//CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
60+
//CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
61+
//CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
62+
//CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
63+
//CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
64+
//CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
65+
//CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
66+
//CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
67+
//CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
68+
//CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
69+
//CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
70+
//CHECK-SAME: tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
71+
xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
72+
: vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
73+
gpu.return
74+
}
75+
}

mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s
1+
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
22

3-
gpu.module @fence_check {
4-
gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
3+
gpu.module @prefetch_nd_check {
4+
gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
55
%srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
66
%dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
77

0 commit comments

Comments
 (0)