Skip to content

Commit d9a5376

Browse files
committed
initial implementation
1 parent 540250c commit d9a5376

File tree

4 files changed

+86
-30
lines changed

4 files changed

+86
-30
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1308,7 +1308,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
13081308
Results:
13091309
- `mem_desc` : the memory descriptor.
13101310
}];
1311-
let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
1311+
let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[I8]>, ConfinedType<MemRefRankOf<[XeGPU_ScalarType], [2]>, [HasStaticShapePred, isSharedPred]>]>:$source);
13121312
let results = (outs XeGPU_MemDesc:$mem_desc);
13131313
let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
13141314
}

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -590,16 +590,22 @@ class CreateMemDescOpPattern final
590590
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
591591
ConversionPatternRewriter &rewriter) const override {
592592

593-
auto resTy = op.getMemDesc();
593+
// auto resTy = op.getMemDesc();
594594

595595
// Create the result MemRefType with the same shape, element type, and
596596
// memory space
597-
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
597+
// auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
598598

599-
Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
600-
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
601-
op.getSource(), zero, ValueRange());
602-
rewriter.replaceOp(op, viewOp);
599+
// Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
600+
// auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
601+
// op.getSource(), zero, ValueRange());
602+
603+
Value baseAddr = memref::ExtractAlignedPointerAsIndexOp::create(
604+
rewriter, op.getLoc(), op.getSource());
605+
auto baseAddr32 = arith::IndexCastUIOp::create(
606+
rewriter, op.getLoc(), rewriter.getI32Type(), baseAddr);
607+
608+
rewriter.replaceOp(op, baseAddr32);
603609
return success();
604610
}
605611
};
@@ -619,7 +625,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
619625

620626
auto loc = op.getLoc();
621627
auto ctxt = rewriter.getContext();
622-
Value basePtrStruct = adaptor.getMemDesc();
628+
Value baseAddr32 = adaptor.getMemDesc();
623629
Value mdescVal = op.getMemDesc();
624630
// Load result or Store value Type can be vector or scalar.
625631
Value data;
@@ -647,21 +653,21 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
647653

648654
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
649655

650-
Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
651-
rewriter, loc, basePtrStruct);
656+
// Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
657+
// rewriter, loc, basePtrStruct);
652658

653659
// Convert base pointer (ptr) to i32
654-
Value basePtrI32 = arith::IndexCastUIOp::create(
655-
rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
660+
//Value basePtrI32 = arith::IndexCastUIOp::create(
661+
// rewriter, loc, rewriter.getI32Type(), baseAddr);
656662

657663
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
658664
linearOffset = arith::IndexCastUIOp::create(
659665
rewriter, loc, rewriter.getI32Type(), linearOffset);
660-
basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
666+
Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32, linearOffset,
661667
elemByteSize);
662668

663669
// convert base pointer (i32) to LLVM pointer type
664-
basePtrLLVM =
670+
Value basePtrLLVM =
665671
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
666672

667673
if (op.getSubgroupBlockIoAttr()) {
@@ -1005,12 +1011,13 @@ struct ConvertXeGPUToXeVMPass
10051011
auto i32Type = IntegerType::get(&getContext(), 32);
10061012
return VectorType::get(8, i32Type);
10071013
});
1008-
// Convert MemDescType into flattened MemRefType for SLM
1009-
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
1010-
Type elemTy = type.getElementType();
1011-
int numElems = type.getNumElements();
1012-
return MemRefType::get(numElems, elemTy, AffineMap(), 3);
1013-
});
1014+
// Convert MemDescType into i32 for SLM
1015+
typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
1016+
// Type elemTy = type.getElementType();
1017+
// int numElems = type.getNumElements();
1018+
// return MemRefType::get(numElems, elemTy, AffineMap(), 3);
1019+
return IntegerType::get(&getContext(), 32);
1020+
});
10141021

10151022
typeConverter.addConversion([&](MemRefType type) -> Type {
10161023
// Convert MemRefType to i64 type.

mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
44

55
// e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
66
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
7-
//CHECK-LABEL: load_store_matrix_1
8-
gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
7+
//CHECK-LABEL: load_store_matrix_plain
8+
gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
99
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
1010

1111
//CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -26,10 +26,38 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
2626
gpu.return %1: f32
2727
}
2828

29+
//CHECK-LABEL: load_store_matrix_plain_2d_input
30+
gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
31+
32+
%view_off = arith.constant 0 : index
%view = memref.view %arg0[%view_off][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
33+
34+
%subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
35+
36+
%0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
37+
38+
//CHECK: %[[TID:.*]] = gpu.thread_id x
39+
//CHECK: %[[C1:.*]] = arith.constant 1 : index
40+
//CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
41+
//CHECK: %[[C4:.*]] = arith.constant 4 : i32
42+
//CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
43+
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
44+
45+
%tid_x = gpu.thread_id x
46+
%c0 = arith.constant 0 : index
47+
%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
48+
49+
//CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
50+
51+
xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
52+
53+
gpu.return %1: f32
54+
}
55+
56+
2957
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
3058
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
31-
//CHECK-LABEL: load_store_matrix_2
32-
gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
59+
//CHECK-LABEL: load_store_matrix_blocked_strided
60+
gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
3361
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
3462
//CHECK: %[[c0:.*]] = arith.constant 0 : index
3563
//CHECK: %[[tid_x:.*]] = gpu.thread_id x
@@ -68,8 +96,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
6896

6997
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
7098
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
71-
//CHECK-LABEL: load_store_matrix_3
72-
gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
99+
//CHECK-LABEL: load_store_matrix_blocked_nostride
100+
gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
73101
//CHECK: %[[c0:.*]] = arith.constant 0 : index
74102
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
75103
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
@@ -110,8 +138,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
110138

111139
// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
112140
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
113-
//CHECK-LABEL: load_store_matrix_4
114-
gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
141+
//CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
142+
gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
115143
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
116144

117145
//CHECK: %[[c0:.*]] = arith.constant 0 : index
@@ -150,8 +178,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
150178

151179
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
152180
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
153-
//CHECK-LABEL: load_store_matrix_5
154-
gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
181+
//CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
182+
gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
155183
//CHECK: %[[c0:.*]] = arith.constant 0 : index
156184
//CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
157185

mlir/test/Dialect/XeGPU/ops.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() {
834834
gpu.return
835835
}
836836

837+
838+
// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) {
839+
gpu.func @create_mem_desc_from_2d_memref() {
840+
//CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3>
841+
//CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
842+
%m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
843+
%mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
844+
gpu.return
845+
}
846+
847+
// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) {
848+
gpu.func @create_mem_desc_with_stride_from_2d_memref() {
849+
//CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3>
850+
//CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
851+
//CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
852+
%m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3>
853+
%m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
854+
%mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
855+
gpu.return
856+
}
857+
837858
// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
838859
gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
839860
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>

0 commit comments

Comments
 (0)