Skip to content

Commit 43d9ddb

Browse files
committed
support memref subview in xegpu to xevm type conversion
1 parent 637a230 commit 43d9ddb

File tree

2 files changed

+52
-13
lines changed

2 files changed

+52
-13
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -999,19 +999,52 @@ struct ConvertXeGPUToXeVMPass
999999
// LLVM type converter puts unrealized casts for the following cases:
10001000
// add materialization casts to handle them.
10011001

1002-
// Materialization to convert memref to i64
1002+
// Materialization to convert a memref to its integer address (i64 or i32, depending on global/SLM)
10031003
auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
10041004
ValueRange inputs,
10051005
Location loc) -> Value {
10061006
if (inputs.size() != 1)
10071007
return {};
10081008
auto input = inputs.front();
10091009
if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1010+
unsigned rank = memrefTy.getRank();
1011+
Type indexType = builder.getIndexType();
10101012

1011-
Value addr =
1012-
memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
1013-
return arith::IndexCastUIOp::create(builder, loc, type, addr)
1014-
.getResult();
1013+
SmallVector<Type> resultTypes;
1014+
// Result types: [base_memref, offset, size0, size1, ..., sizeN-1,
1015+
// stride0, stride1, ..., strideN-1]
1016+
resultTypes.push_back(MemRefType::get(
1017+
{}, memrefTy.getElementType(), MemRefLayoutAttrInterface(),
1018+
memrefTy.getMemorySpace())); // base memref (rank-0, same element type and memory space)
1019+
resultTypes.push_back(indexType); // offset
1020+
for (unsigned i = 0; i < rank; ++i)
1021+
resultTypes.push_back(indexType); // sizes
1022+
for (unsigned i = 0; i < rank; ++i)
1023+
resultTypes.push_back(indexType); // strides
1024+
1025+
auto meta = memref::ExtractStridedMetadataOp::create(
1026+
builder, loc, resultTypes, input);
1027+
1028+
auto addr = memref::ExtractAlignedPointerAsIndexOp::create(
1029+
builder, loc, meta.getBaseBuffer());
1030+
auto offset = meta.getOffset();
1031+
1032+
auto addr_casted =
1033+
arith::IndexCastUIOp::create(builder, loc, type, addr);
1034+
auto offset_casted =
1035+
arith::IndexCastUIOp::create(builder, loc, type, offset);
1036+
1037+
// Compute the final address: base address + byte offset. The metadata
// offset is scaled by the element size in bytes (element bit width / 8).
// NOTE(review): the integer division yields 0 for sub-byte element types - confirm such types cannot reach this cast.
1038+
auto byte_size = arith::ConstantOp::create(
1039+
builder, loc, type,
1040+
builder.getIntegerAttr(type,
1041+
memrefTy.getElementTypeBitWidth() / 8));
1042+
auto byte_offset =
1043+
arith::MulIOp::create(builder, loc, offset_casted, byte_size);
1044+
auto addr_with_offset =
1045+
arith::AddIOp::create(builder, loc, addr_casted, byte_offset);
1046+
1047+
return addr_with_offset.getResult();
10151048
}
10161049
return {};
10171050
};

mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,29 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
3333

3434
%subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
3535

36+
//CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]]:2, %[[strides:.*]]:2 = memref.extract_strided_metadata %{{.*}} : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> memref<f32, 3>, index, index, index, index, index
37+
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<f32, 3> -> index
38+
//CHECK: %[[ptr_i32:.*]] = arith.index_castui %[[intptr]] : index to i32
39+
//CHECK: %[[offset_i32:.*]] = arith.index_castui %[[offset]] : index to i32
40+
//CHECK: %[[c4_i32:.*]] = arith.constant 4 : i32
41+
//CHECK: %[[mul:.*]] = arith.muli %[[offset_i32]], %[[c4_i32]] : i32
42+
//CHECK: %[[add:.*]] = arith.addi %[[ptr_i32]], %[[mul]] : i32
43+
3644
%0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
3745

3846
//CHECK: %[[TID:.*]] = gpu.thread_id x
3947
//CHECK: %[[C1:.*]] = arith.constant 1 : index
4048
//CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
41-
//CHECK: %[[C4:.*]] = arith.constant 4 : i32
42-
//CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
49+
//CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, {{.*}} : i32
4350
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
4451

4552
%tid_x = gpu.thread_id x
46-
53+
4754
%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
4855

4956
//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
5057

51-
xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
58+
xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
5259

5360
gpu.return %1: f32
5461
}
@@ -99,8 +106,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
99106
//CHECK-LABEL: load_store_matrix_blocked_nostride
100107
gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
101108

102-
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
103-
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
104109
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
105110

106111
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -178,7 +183,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
178183
//CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
179184
gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
180185

181-
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
186+
//CHECK: %[[base_buffer:.*]], %[[offset:.*]], %[[sizes:.*]], %[[strides:.*]] = memref.extract_strided_metadata %arg0 : memref<4096xi8, 3> -> memref<i8, 3>, index, index, index
187+
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[base_buffer]] : memref<i8, 3> -> index
182188
//CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
183189
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
184190

@@ -206,7 +212,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
206212
//CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
207213
//CHECK: %[[c2:.*]] = arith.constant 2 : i32
208214
//CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
209-
//CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
215+
//CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32:.*]], %[[byteOffset]] : i32
210216
//CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
211217
//CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
212218
//CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>

0 commit comments

Comments
 (0)