// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s

//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
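// #map and #map1 delinearize the linear subgroup id into (row, col) coordinates
// for an [8, 4] subgroup layout: row = id floordiv 4, col = id mod 4.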
gpu.module @test_distribution {
  // CHECK-LABEL: create_nd_tdesc_no_offset
  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
  gpu.func @create_nd_tdesc_no_offset(%src: memref<256x128xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
    gpu.return
  }

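  // Each subgroup loads its own 32x32 tile: the linear subgroup id is delinearized
  // with #map/#map1, scaled by sg_data, shifted by the load offsets, and wrapped
  // with index.remu so it stays inside the 256x128 tensor.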
  // CHECK-LABEL: load_nd_tdesc_with_offset
  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
  gpu.func @load_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
    //CHECK: [[C8:%.+]] = arith.constant 8 : index
    //CHECK: [[C4:%.+]] = arith.constant 4 : index
    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
    //CHECK: [[C32:%.+]] = arith.constant 32 : index
    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
    //CHECK: [[C0:%.+]] = arith.constant 0 : index
    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
    //CHECK: [[C256:%.+]] = arith.constant 256 : index
    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
    //CHECK: [[C128:%.+]] = arith.constant 128 : index
    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
    //CHECK: %[[LOAD:.*]] = xegpu.load_nd {{%.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<32x32xf32>
    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
    %load = xegpu.load_nd %tdesc[0, 0]
      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
      -> vector<256x128xf32>
    gpu.return
  }

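  // Same per-subgroup offset computation as the load case; the store is rewritten
  // to a 32x32 store_nd on the distributed descriptor.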
  // CHECK-LABEL: store_nd_with_offsets
  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
  gpu.func @store_nd_with_offsets(%src: memref<256x128xf32>) {
    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
    //CHECK: [[C8:%.+]] = arith.constant 8 : index
    //CHECK: [[C4:%.+]] = arith.constant 4 : index
    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
    //CHECK: [[C32:%.+]] = arith.constant 32 : index
    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
    //CHECK: [[C0:%.+]] = arith.constant 0 : index
    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
    //CHECK: [[C256:%.+]] = arith.constant 256 : index
    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
    //CHECK: [[C128:%.+]] = arith.constant 128 : index
    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
    //CHECK: xegpu.store_nd %{{.*}}, {{%.*}}[{{%.*}}, {{%.*}}] : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
    %tdesc = xegpu.create_nd_tdesc %src: memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
    %load = xegpu.load_nd %tdesc[0, 0]
      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
      -> vector<256x128xf32>
    xegpu.store_nd %load, %tdesc[0, 0]
      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
    gpu.return
  }

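  // prefetch_nd with explicit offsets is distributed the same way, prefetching
  // one 32x32 tile per subgroup.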
  // CHECK-LABEL: prefetch_nd_tdesc_with_offset
  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
  gpu.func @prefetch_nd_tdesc_with_offset(%src: memref<256x128xf32>) {
    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
    //CHECK: [[C8:%.+]] = arith.constant 8 : index
    //CHECK: [[C4:%.+]] = arith.constant 4 : index
    //CHECK: [[C4_1:%.+]] = arith.constant 4 : index
    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
    //CHECK: [[C32:%.+]] = arith.constant 32 : index
    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
    //CHECK: [[C32_1:%.+]] = arith.constant 32 : index
    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32_1]]
    //CHECK: [[C0:%.+]] = arith.constant 0 : index
    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
    //CHECK: [[C256:%.+]] = arith.constant 256 : index
    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
    //CHECK: [[C128:%.+]] = arith.constant 128 : index
    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
    //CHECK: xegpu.prefetch_nd %{{.*}}[{{%.*}}, {{%.*}}] : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
    %cst0 = arith.constant 0 : index
    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
    xegpu.prefetch_nd %tdesc[%cst0, %cst0]
      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
    gpu.return
  }

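  // With sg_layout [8, 8] and sg_data [16, 128] / [128, 16], each subgroup computes
  // a 16x16 dpas tile; only the lane-level layout is kept on the result.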
  // CHECK-LABEL: dpas
  gpu.func @dpas(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
    %load_a = xegpu.load_nd %tdesc_a[0, 0]
      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1]>>
      -> vector<128x128xf16>
    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
    %load_b = xegpu.load_nd %tdesc_b[0, 0]
      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1]>>
      -> vector<128x128xf16>
    %dpas = xegpu.dpas %load_a, %load_b
      {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
    gpu.return
  }

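  // Without sg_data, the per-subgroup shape is derived from the [8, 8] sg_layout
  // alone (128 / 8 = 16 in each dimension), giving 16x16 operands and results.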
  // CHECK-LABEL: dpas_no_sg_data
  gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
    %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
         order = [1, 0]>>
    %load_a = xegpu.load_nd %tdesc_a[0, 0]
      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
        order = [1, 0]>>
      -> vector<128x128xf16>
    %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
         order = [1, 0]>>
    %load_b = xegpu.load_nd %tdesc_b[0, 0]
      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
        order = [1, 0]>>
      -> vector<128x128xf16>
    %dpas = xegpu.dpas %load_a, %load_b
      {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
      : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
    gpu.return
  }

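  // No create_nd_tdesc/load_nd chain feeds this dpas, so its result is expected to
  // stay at workgroup granularity (no 32x32 vectors appear after the pass).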
  // CHECK-LABEL: dpas_with_no_create_nd_desc
  gpu.func @dpas_with_no_create_nd_desc(%a: vector<256x128xf32>, %b: vector<128x256xf32>) {
    // CHECK-NOT: vector<32x32xf32>
    %dpas = xegpu.dpas %a, %b
      {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
      : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
    gpu.return
  }

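  // Broadcast along dim 1: each subgroup's 32x1 slice is expanded to 32x32; only
  // the lane-level layout is kept on the result.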
  // CHECK-LABEL: broadcast_dim1
  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x1xf32>
  gpu.func @broadcast_dim1(%src: memref<256x1xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
      -> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
    %load = xegpu.load_nd %tdesc[0, 0]
      : !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 1], lane_layout = [8, 1], lane_data = [1, 1]>>
      -> vector<256x1xf32>
    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<32x1xf32> to vector<32x32xf32>
    %broadcast = vector.broadcast %load
      {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [32, 32], lane_layout = [8, 1], lane_data = [1, 1]>}
      : vector<256x1xf32> to vector<256x32xf32>
    gpu.return
  }

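  // Broadcast along dim 0: the per-subgroup 1x32 slice becomes 32x32.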
  // CHECK-LABEL: broadcast_dim0
  // CHECK-SAME: %[[ARG_0:.*]]: memref<1x128xf32>
  gpu.func @broadcast_dim0(%src: memref<1x128xf32>) {
    %tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
      -> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
    %load = xegpu.load_nd %tdesc[0, 0]
      : !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
      -> vector<1x128xf32>
    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
    // CHECK-SAME: : vector<1x32xf32> to vector<32x32xf32>
    %broadcast = vector.broadcast %load
      {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
      : vector<1x128xf32> to vector<32x128xf32>
    gpu.return
  }

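  // GEMM-style K loop: the A/B descriptors and the accumulator are carried through
  // scf.for as distributed 16x128, 128x16 and 16x16 values, and the loop-carried
  // types are rewritten to the per-subgroup shapes.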
  // CHECK-LABEL: scf_for
  gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
    //CHECK: [[c0:%.+]] = arith.constant 0 : index
    //CHECK: [[c128:%.+]] = arith.constant 128 : index
    //CHECK: [[c1024:%.+]] = arith.constant 1024 : index
    %c0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %c1024 = arith.constant 1024 : index
    %block_id_x = gpu.block_id x
    %block_id_y = gpu.block_id y
    %0 = arith.muli %block_id_x, %c128 : index
    %1 = arith.muli %block_id_y, %c128 : index
    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
    %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>

    // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]]
    // CHECK-SAME: iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) ->
    // CHECK-SAME: (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
    // CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
    // CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
    // CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
    // CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
    // CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
    // CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3)
        -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
            !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
      %8 = xegpu.load_nd %arg4 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
      %9 = xegpu.load_nd %arg5 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>}
        : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
      %11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
      %12 = xegpu.update_nd_offset %arg5, [%c128, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
      scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>,
                                !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>
    }
    %7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32>
      -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
    xegpu.store_nd %6#2, %7 : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
    gpu.return
  }
}