
Commit 9dc4ebf

[MLIR][XeGPU] Allow create mem desc from 2d memref (#167767)
This PR relaxes create_mem_desc's restriction on the source memref, allowing it to be a 2D memref.
1 parent f38cf01 commit 9dc4ebf
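
As a quick illustration, a minimal sketch of the change in accepted IR, assembled from this commit's tests (%buf, %m, %desc1d, and %desc2d are placeholder names; the 2048-byte buffer matches a 16x64xf16 descriptor):

    // Previously the source had to be a 1D i8 memref in shared memory (address space 3):
    %desc1d = xegpu.create_mem_desc %buf : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16>
    // Now a statically shaped, contiguous 2D memref of a scalar type is also accepted:
    %m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
    %desc2d = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>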

File tree: 6 files changed (+97, -67 lines)


mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 3 additions & 8 deletions
@@ -1282,12 +1282,6 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
   let hasCanonicalizer = 1;
 }
 
-def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
-class StaticShared1DMemRefOf<list<Type> allowedTypes> :
-    ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
-                 "statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory",
-                 "mlir::MemRefType">;
-
 class SizeInBits<string name> :
   StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()"
           "*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">;
@@ -1304,11 +1298,12 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
     as the underlying shared local memory.
 
     Arguments:
-    - `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
+    - `source` : 1D or 2D statically shape memref, representing the raw SLM buffer.
+      The provided memref must be contiguous.
     Results:
     - `mem_desc` : the memory descriptor.
   }];
-  let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
+  let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[XeGPU_ScalarType]>, StaticShared2DMemRefOf<[XeGPU_ScalarType]>]>:$source);
   let results = (outs XeGPU_MemDesc:$mem_desc);
   let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
 }

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 11 additions & 0 deletions
@@ -35,6 +35,17 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
   let mnemonic = typeMnemonic;
 }
 
+def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
+class StaticShared1DMemRefOf<list<Type> allowedTypes> :
+    ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
+                 "reside in share memory and statically 1d shaped " # MemRefOf<allowedTypes>.summary # " ",
+                 "mlir::MemRefType">;
+
+class StaticShared2DMemRefOf<list<Type> allowedTypes>:
+    ConfinedType<MemRefRankOf<allowedTypes, [2]>, [HasStaticShapePred, isSharedPred],
+                 "reside in share memory and statically 2d shaped " # MemRefOf<allowedTypes>.summary # " ",
+                 "mlir::MemRefType">;
+
 def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
         [ShapedTypeInterface], "::mlir::TensorType"> {
   let summary = "TensorDesc describing regions of interested data.";
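
The two constraints above differ only in rank; both require a static shape and shared-memory placement. As a rough illustration, with concrete types taken from this commit's tests (address space 3 denotes SLM/workgroup memory):

    memref<2048xi8, 3>    // satisfies StaticShared1DMemRefOf: rank 1, static shape, shared memory
    memref<16x64xf16, 3>  // satisfies StaticShared2DMemRefOf: rank 2, static shape, shared memory
    memref<2048xi8, 1>    // rejected: not in shared memory (see invalid.mlir below)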

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 9 additions & 29 deletions
@@ -579,9 +579,6 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   }
 };
 
-// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
-// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
-// 32 bits will be converted to 32 bits.
 class CreateMemDescOpPattern final
     : public OpConversionPattern<xegpu::CreateMemDescOp> {
 public:
@@ -590,16 +587,7 @@ class CreateMemDescOpPattern final
   matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto resTy = op.getMemDesc();
-
-    // Create the result MemRefType with the same shape, element type, and
-    // memory space
-    auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-
-    Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
-    auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
-                                         op.getSource(), zero, ValueRange());
-    rewriter.replaceOp(op, viewOp);
+    rewriter.replaceOp(op, adaptor.getSource());
     return success();
   }
 };
@@ -619,7 +607,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
 
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
-    Value basePtrStruct = adaptor.getMemDesc();
+    Value baseAddr32 = adaptor.getMemDesc();
     Value mdescVal = op.getMemDesc();
     // Load result or Store value Type can be vector or scalar.
     Value data;
@@ -647,21 +635,14 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
 
     auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
 
-    Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
-        rewriter, loc, basePtrStruct);
-
-    // Convert base pointer (ptr) to i32
-    Value basePtrI32 = arith::IndexCastUIOp::create(
-        rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
-
     Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
     linearOffset = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), linearOffset);
-    basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
-                                     elemByteSize);
+    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
+                                           linearOffset, elemByteSize);
 
     // convert base pointer (i32) to LLVM pointer type
-    basePtrLLVM =
+    Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
 
     if (op.getSubgroupBlockIoAttr()) {
@@ -1005,15 +986,14 @@ struct ConvertXeGPUToXeVMPass
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
     });
-    // Convert MemDescType into flattened MemRefType for SLM
+    // Convert MemDescType into i32 for SLM
    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
-      Type elemTy = type.getElementType();
-      int numElems = type.getNumElements();
-      return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+      return IntegerType::get(&getContext(), 32);
     });
 
     typeConverter.addConversion([&](MemRefType type) -> Type {
-      // Convert MemRefType to i64 type.
+      if (type.getMemorySpaceAsInt() == 3)
+        return IntegerType::get(&getContext(), 32);
       return IntegerType::get(&getContext(), 64);
     });
 
mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 52 additions & 29 deletions
@@ -4,8 +4,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
   // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
   // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
-  //CHECK-LABEL: load_store_matrix_1
-  gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
+  //CHECK-LABEL: load_store_matrix_plain
+  gpu.func @load_store_matrix_plain(%arg0: memref<4096xi8, 3>) -> f32 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
 
     //CHECK: %[[TID:.*]] = gpu.thread_id x
@@ -26,20 +26,48 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     gpu.return %1: f32
   }
 
+  //CHECK-LABEL: load_store_matrix_plain_2d_input
+  gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
+    %c0 = arith.constant 0 : index
+    %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
+
+    %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
+
+    %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
+
+    //CHECK: %[[TID:.*]] = gpu.thread_id x
+    //CHECK: %[[C1:.*]] = arith.constant 1 : index
+    //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+    //CHECK: %[[C4:.*]] = arith.constant 4 : i32
+    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
+    %tid_x = gpu.thread_id x
+
+    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+    //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+    xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
+    gpu.return %1: f32
+  }
+
+
   // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
   // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
-  //CHECK-LABEL: load_store_matrix_2
-  gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
+  //CHECK-LABEL: load_store_matrix_blocked_strided
+  gpu.func @load_store_matrix_blocked_strided(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
     //CHECK: %[[c13:.*]] = arith.constant 13 : index
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
     //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
     //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
     //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c256:.*]] = arith.constant 256 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
     //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -68,24 +96,24 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
-  //CHECK-LABEL: load_store_matrix_3
-  gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+  //CHECK-LABEL: load_store_matrix_blocked_nostride
+  gpu.func @load_store_matrix_blocked_nostride(%arg0: memref<4096xi8, 3>) -> f16 {
+
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
 
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
     //CHECK: %[[c19:.*]] = arith.constant 19 : index
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19: index
 
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
     //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
     //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
     //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
     //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -97,7 +125,6 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[c1:.*]] = arith.constant 1 : index
     //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
 
@@ -110,19 +137,17 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
   // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
   // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
-  //CHECK-LABEL: load_store_matrix_4
-  gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+  //CHECK-LABEL: load_store_matrix_blocked_strided_return_vector
+  gpu.func @load_store_matrix_blocked_strided_return_vector(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
 
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
-
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
     //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
-
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c256:.*]] = arith.constant 256 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
     //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -150,25 +175,23 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
-  //CHECK-LABEL: load_store_matrix_5
-  gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
-    //CHECK: %[[c0:.*]] = arith.constant 0 : index
-    //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
-
-    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-
+  //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
+  gpu.func @load_store_matrix_blocked_subgroupblockio(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
+    //CHECK: %[[basePtrI32:.*]] = arith.index_castui %[[intptr]] : index to i32
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[c48:.*]] = arith.constant 48 : index
-
     %c16 = arith.constant 16 : index
     %c48 = arith.constant 48 : index
 
-    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
     //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
     //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
     //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
     //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
@@ -183,7 +206,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
     //CHECK: %[[c2:.*]] = arith.constant 2 : i32
     //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
-    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32
+    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI32]], %[[byteOffset]] : i32
     //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
     //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
     //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 1 addition & 1 deletion
@@ -836,7 +836,7 @@ func.func @slice_attr_repeat_dim() {
 // -----
 func.func @create_mem_desc_non_slm() {
   %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
-  // expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}}
+  // expected-error@+1 {{operand #0 must be reside in share memory and statically 1d shaped memref }}
   %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16>
   return
 }
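
For contrast, a sketch of the valid counterpart under the new constraint: the same alloca placed in workgroup memory (address space 3) is accepted by the verifier:

    %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
    %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16>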

mlir/test/Dialect/XeGPU/ops.mlir

Lines changed: 21 additions & 0 deletions
@@ -834,6 +834,27 @@ gpu.func @create_mem_desc_with_stride() {
   gpu.return
 }
 
+
+// CHECK-LABEL: gpu.func @create_mem_desc_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_from_2d_memref() {
+  //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<16x64xf16, 3>
+  //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+  %m = memref.alloca() {alignment = 1024} : memref<16x64xf16, 3>
+  %mem_desc = xegpu.create_mem_desc %m : memref<16x64xf16, 3> -> !xegpu.mem_desc<16x64xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_with_stride_from_2d_memref({{.*}}) {
+gpu.func @create_mem_desc_with_stride_from_2d_memref() {
+  //CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 1024 : i64} : memref<32x64xf16, 3>
+  //CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][16, 0] [16, 64] [1, 1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+  //CHECK: %{{.+}} = xegpu.create_mem_desc %[[SUBVIEW]] : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  %m = memref.alloca() {alignment = 1024} : memref<32x64xf16, 3>
+  %m_sub = memref.subview %m[16, 0][16, 64][1,1] : memref<32x64xf16, 3> to memref<16x64xf16, strided<[64, 1], offset: 1024>, 3>
+  %mem_desc = xegpu.create_mem_desc %m_sub : memref<16x64xf16, strided<[64, 1], offset: 1024>, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
 gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
   // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
