Commit 639e997

Add tests

1 parent 7783591 commit 639e997

3 files changed: +269 −5 lines


mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (6 additions, 4 deletions)

@@ -120,9 +120,10 @@ struct SgOffsetInfo {
 template <typename OpTy>
 std::optional<SgOffsetInfo>
 extractSgOffsetInfo(OpTy op, ConversionPatternRewriter &rewriter) {
+
   int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-  if (offsetSize == 0)
-    return std::nullopt;
+  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
+    return std::nullopt;
 
   Location loc = op.getLoc();
   Value tdesc = op.getTensorDesc();
@@ -832,8 +833,9 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns
       .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
-           WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
-           WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
+           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
       patterns.getContext());
 }
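The guard change above relaxes extractSgOffsetInfo: previously it gave up whenever an op carried no dynamic offset operands, whereas now it also accepts ops whose offsets were folded into a const-offsets attribute (as in load_nd %tdesc[0, 0] in the tests below). A minimal standalone sketch of the new predicate, where FakeNdOp and extractSgOffsetInfoLike are hypothetical stand-ins rather than the MLIR API:

#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for an xegpu nd op: dynamic offsets arrive as SSA
// operands, while constant offsets may instead be folded into an attribute.
struct FakeNdOp {
  std::vector<int64_t> dynamicOffsets; // models op.getOffsets()
  bool hasConstOffsetsAttr;            // models op.getConstOffsetsAttr() != nullptr
};

// Mirrors the new guard: only bail out when the op has neither dynamic
// offsets nor a folded const-offsets attribute.
std::optional<bool> extractSgOffsetInfoLike(const FakeNdOp &op) {
  int64_t offsetSize = static_cast<int64_t>(op.dynamicOffsets.size());
  if (offsetSize == 0 && !op.hasConstOffsetsAttr)
    return std::nullopt;
  return true;
}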

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir (43 additions, 1 deletion)

@@ -10,5 +10,47 @@ gpu.module @test_distribution {
     %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
-  }
+  }
+
+  // CHECK-LABEL: load_nd_tdesc_with_offset
+  gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
+    // CHECK-NOT: xegpu.load_nd
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: store_nd_with_offset
+  gpu.func @store_nd_with_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.store_nd
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    xegpu.store_nd %load, %tdesc[0, 0]
+      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.prefetch_nd {{%.*}}[{{%.*}}, {{%.*}}]
+    // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK-NOT: xegpu.prefetch_nd
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
 }
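For context on the CHECK-COUNT-4 directives above: in this round-robin variant a 256x128 tensor is distributed over sg_layout = [8, 4] with sg_data = [16, 16], so the subgroup grid covers a 128x64 block at a time and each subgroup presumably ends up with (256 / (8 * 16)) * (128 / (4 * 16)) = 2 * 2 = 4 tiles, hence the four unrolled load/store/prefetch ops checked per function.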

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (220 additions, 0 deletions)

@@ -1,5 +1,7 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
+//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
+//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -21,4 +23,222 @@ gpu.module @test_distribution {
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }
+
+  // CHECK-LABEL: load_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[C8:%.+]] = arith.constant 8 : index
+    //CHECK: [[C4:%.+]] = arith.constant 4 : index
+    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: store_nd_with_offsets
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[C8:%.+]] = arith.constant 8 : index
+    //CHECK: [[C4:%.+]] = arith.constant 4 : index
+    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    xegpu.store_nd %load, %tdesc[0, 0]
+      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: prefetch_nd_tdesc_with_offset
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[C8:%.+]] = arith.constant 8 : index
+    //CHECK: [[C4:%.+]] = arith.constant 4 : index
+    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %cst0 = arith.constant 0 : index
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.prefetch_nd %tdesc[%cst0, %cst0]
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas
+  gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load_a = xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+    %load_b = xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
+      -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas_no_sg_data
+  gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+         order = [1, 0]>>
+    %load_a = xegpu.load_nd %tdesc_a[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+        order = [1, 0]>>
+      -> vector<128x128xf16>
+    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+         order = [1, 0]>>
+    %load_b = xegpu.load_nd %tdesc_b[0, 0]
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+        order = [1, 0]>>
+      -> vector<128x128xf16>
+    %dpas = xegpu.dpas %load_a, %load_b
+      {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: dpas_with_no_create_nd_desc
+  gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
+    // CHECK-NOT: vector<32x32xf32>
+    %dpas = xegpu.dpas %a, %b
+      {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: broadcast_dim1
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
+  gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
+      -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
+      -> vector<256x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
+      : vector<256x1xf32> to vector<256x32xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: broadcast_dim0
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
+  gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
+      -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc[0, 0]
+      : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+      -> vector<1x128xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
+      : vector<1x128xf32> to vector<32x128xf32>
+    gpu.return
+  }
+
+  gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %0 = arith.muli %block_id_x, %c128 : index
+    %1 = arith.muli %block_id_y, %c128 : index
+    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
+    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+
+    // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
+    // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
+    // CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
+    // CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    // CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    // CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    // CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
+    // CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
+    // CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
+        -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
+            !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
+      %8 = xegpu.load_nd %arg4 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+      %9 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
+        : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+      %11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+      %12 = xegpu.update_nd_offset %arg5, [%c128, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+      scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
+        !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>
+    }
+    %7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32>
+      -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    xegpu.store_nd %6#2, %7 : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    gpu.return
+  }
 }
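The CHECK sequences in the three *_with_offset tests above all spell out the same per-subgroup offset computation: gpu.subgroup_id is split into a row/column id via #map and #map1 (floordiv/mod by the subgroup-grid width of 4), scaled by sg_data, shifted by the workgroup-level offset, and wrapped into the tensor shape with index.remu. A minimal C++ sketch of that arithmetic, hardcoding the sg_layout = [8, 4], sg_data = [32, 32], 256x128 case from these tests (sgOffsets is a hypothetical helper, not part of the pass):

#include <cstdint>
#include <utility>

// Computes the (y, x) tile origin for one subgroup, mirroring the lowered
// IR: affine floordiv/mod, index.mul, arith.addi, then index.remu.
std::pair<int64_t, int64_t> sgOffsets(int64_t sgId, int64_t baseY, int64_t baseX) {
  const int64_t sgLayoutX = 4;              // width of the [8, 4] subgroup grid
  const int64_t sgDataY = 32, sgDataX = 32; // per-subgroup tile shape
  const int64_t shapeY = 256, shapeX = 128; // workgroup-level tensor shape
  int64_t sgIdY = sgId / sgLayoutX;         // #map:  s0 floordiv 4
  int64_t sgIdX = sgId % sgLayoutX;         // #map1: s0 mod 4
  int64_t y = (sgIdY * sgDataY + baseY) % shapeY; // index.mul + arith.addi + index.remu
  int64_t x = (sgIdX * sgDataX + baseX) % shapeX;
  return {y, x};
}

For example, subgroup 5 with a [0, 0] base offset lands at tile origin (32, 32).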
