Skip to content

Commit 18f8cb9

Browse files
committed
adding tests
1 parent d9a5376 commit 18f8cb9

File tree

2 files changed

+27
-51
lines changed

2 files changed

+27
-51
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 8 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -590,20 +590,10 @@ class CreateMemDescOpPattern final
590590
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
591591
ConversionPatternRewriter &rewriter) const override {
592592

593-
// auto resTy = op.getMemDesc();
594-
595-
// Create the result MemRefType with the same shape, element type, and
596-
// memory space
597-
// auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
598-
599-
// Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
600-
// auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
601-
// op.getSource(), zero, ValueRange());
602-
603593
Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(
604-
rewriter, op.getLoc(), op.getSource());
605-
auto baseAddr32 = arith::IndexCastUIOp::create(
606-
rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr);
594+
rewriter, op.getLoc(), op.getSource());
595+
auto baseAddr32 = arith::IndexCastUIOp::create(
596+
rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr);
607597

608598
rewriter.replaceOp(op, baseAddr32);
609599
return success();
@@ -653,18 +643,11 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
653643

654644
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
655645

656-
// Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
657-
// rewriter, loc, basePtrStruct);
658-
659-
// Convert base pointer (ptr) to i32
660-
//Value basePtrI32 = arith::IndexCastUIOp::create(
661-
// rewriter, loc, rewriter.getI32Type(), baseAddr);
662-
663646
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
664647
linearOffset = arith::IndexCastUIOp::create(
665648
rewriter, loc, rewriter.getI32Type(), linearOffset);
666-
Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, linearOffset,
667-
elemByteSize);
649+
Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
650+
linearOffset, elemByteSize);
668651

669652
// convert base pointer (i32) to LLVM pointer type
670653
Value basePtrLLVM =
@@ -1012,12 +995,9 @@ struct ConvertXeGPUToXeVMPass
1012995
return VectorType::get(8, i32Type);
1013996
});
1014997
// Convert MemDescType into i32 for SLM
1015-
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
1016-
// Type elemTy = type.getElementType();
1017-
// int numElems = type.getNumElements();
1018-
// return MemRefType::get(numElems, elemTy, AffineMap(), 3);
1019-
return IntegerType::get(&getContext(), 32);
1020-
});
998+
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
999+
return IntegerType::get(&getContext(), 32);
1000+
});
10211001

10221002
typeConverter.addConversion([&](MemRefType type) -> Type {
10231003
// Convert MemRefType to i64 type.

mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 19 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -27,11 +27,11 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
2727
}
2828

2929
//CHECK-LABEL: load_store_matrix_plain_2d_input
30-
gpu.func @load_store_matrix_plain(%arg0: memref<8192xi8, 3>) -> f32 {
31-
32-
%view = memref.view %arg0[0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
30+
gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
31+
%c0 = arith.constant 0 : index
32+
%view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
3333

34-
%subview = memref.subview %view[64, 0] [64, 128] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
34+
%subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
3535

3636
%0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
3737

@@ -43,7 +43,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
4343
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
4444

4545
%tid_x = gpu.thread_id x
46-
%c0 = arith.constant 0 : index
46+
4747
%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
4848

4949
//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
@@ -59,15 +59,15 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
5959
//CHECK-LABEL: load_store_matrix_blocked_strided
6060
gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
6161
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
62-
//CHECK: %[[c0:.*]] = arith.constant 0 : index
62+
6363
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
6464
//CHECK: %[[c13:.*]] = arith.constant 13 : index
6565
//CHECK: %[[c16:.*]] = arith.constant 16 : index
6666
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
6767
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
6868
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
6969
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
70-
70+
//CHECK: %[[c0:.*]] = arith.constant 0 : index
7171
//CHECK: %[[c256:.*]] = arith.constant 256 : index
7272
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
7373
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -98,22 +98,22 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
9898
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
9999
//CHECK-LABEL: load_store_matrix_blocked_nostride
100100
gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
101-
//CHECK: %[[c0:.*]] = arith.constant 0 : index
102-
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
101+
102+
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
103+
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
103104
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
104105

105106
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
106107
//CHECK: %[[c19:.*]] = arith.constant 19 : index
107108
%tid_x = gpu.thread_id x
108109
%c19 = arith.constant 19: index
109110

110-
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
111-
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
112111
//CHECK: %[[c16:.*]] = arith.constant 16 : index
113112
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
114113
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
115114
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
116115
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
116+
//CHECK: %[[c0:.*]] = arith.constant 0 : index
117117
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
118118
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
119119
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -125,7 +125,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
125125
//CHECK: %[[c1:.*]] = arith.constant 1 : index
126126
//CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
127127
//CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
128-
129128
//CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
130129
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
131130

@@ -142,15 +141,13 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
142141
gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
143142
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
144143

145-
//CHECK: %[[c0:.*]] = arith.constant 0 : index
146144
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
147-
148145
//CHECK: %[[c16:.*]] = arith.constant 16 : index
149146
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
150147
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
151148
//CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
152149
//CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
153-
150+
//CHECK: %[[c0:.*]] = arith.constant 0 : index
154151
//CHECK: %[[c256:.*]] = arith.constant 256 : index
155152
//CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
156153
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -180,23 +177,22 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
180177
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
181178
//CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
182179
gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
183-
//CHECK: %[[c0:.*]] = arith.constant 0 : index
184-
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
185-
186-
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
187-
180+
181+
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
182+
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
183+
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
184+
185+
188186
//CHECK: %[[c16:.*]] = arith.constant 16 : index
189187
//CHECK: %[[c48:.*]] = arith.constant 48 : index
190-
191188
%c16 = arith.constant 16 : index
192189
%c48 = arith.constant 48 : index
193190

194-
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
195-
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
196191
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
197192
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
198193
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
199194
//CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
195+
//CHECK: %[[c0:.*]] = arith.constant 0 : index
200196
//CHECK: %[[c1024:.*]] = arith.constant 1024 : index
201197
//CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
202198
//CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index

0 commit comments

Comments
 (0)