@@ -186,6 +186,47 @@ module {
186
186
gpu.return
187
187
}
188
188
189
// Test case: max-reduction along dim 1 of an 8x32 f32 tile, written back as
// an 8x1 tile. The FileCheck directives below expect the 8x32 load to appear
// as two 8x16 xegpu.load_nd ops, and the reduction to appear as groups of
// (vector.shuffle, vector.shuffle, arith.maximumf) on 16-lane vectors, with
// the final group producing a vector<8xf32>.
//
// Anchor the directive sequence on this function so the load_nd patterns
// cannot match output that belongs to an earlier function in the file.
//CHECK: gpu.func @inner_reduction_1
gpu.func @inner_reduction_1(%a: memref<8x32xf32>, %b: memref<8x1xf32>) {
  %c0 = arith.constant 0 : index
  // Accumulator for the max-reduction; 0xFF800000 is the f32 bit pattern used for -inf.
  %neg_inf = arith.constant dense<0xFF800000> : vector<8xf32> // -inf

  // Tiles over the whole input (8x32) and the whole output (8x1), both at offset [0, 0].
  %a_tile = xetile.init_tile %a[%c0, %c0] : memref<8x32xf32> -> !xetile.tile<8x32xf32>
  %b_tile = xetile.init_tile %b[%c0, %c0] : memref<8x1xf32> -> !xetile.tile<8x1xf32>

  // Expected: the single 8x32 tile load shows up as two 8x16 block loads.
  //CHECK: xegpu.load_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>> -> vector<8x16xf32>
  //CHECK: xegpu.load_nd %{{.*}} <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_scope = global, array_length = 1 : i64, boundary_check = true>> -> vector<8x16xf32>
  %a_loaded = xetile.load_tile %a_tile : !xetile.tile<8x32xf32> -> vector<8x32xf32>

  // Expected stage 1: four shuffle/shuffle/maximumf groups on unnamed
  // 16-lane operands, selecting 8-lane halves (results R3, R6, R9, R12).
  //CHECK: %[[R1:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R2:.*]] = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R3:.*]] = arith.maximumf %[[R1]], %[[R2]] : vector<16xf32>
  //CHECK: %[[R4:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R5:.*]] = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R6:.*]] = arith.maximumf %[[R4]], %[[R5]] : vector<16xf32>
  //CHECK: %[[R7:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R8:.*]] = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R9:.*]] = arith.maximumf %[[R7]], %[[R8]] : vector<16xf32>
  //CHECK: %[[R10:.*]] = vector.shuffle %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R11:.*]] = vector.shuffle %{{.*}}, %{{.*}} [8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R12:.*]] = arith.maximumf %[[R10]], %[[R11]] : vector<16xf32>
  // Expected stage 2: pair up stage-1 results (R3 with R6, R9 with R12),
  // selecting 4-lane groups.
  //CHECK: %[[R13:.*]] = vector.shuffle %[[R3]], %[[R6]] [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R14:.*]] = vector.shuffle %[[R3]], %[[R6]] [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R15:.*]] = arith.maximumf %[[R13]], %[[R14]] : vector<16xf32>
  //CHECK: %[[R16:.*]] = vector.shuffle %[[R9]], %[[R12]] [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R17:.*]] = vector.shuffle %[[R9]], %[[R12]] [4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R18:.*]] = arith.maximumf %[[R16]], %[[R17]] : vector<16xf32>
  // Expected stage 3: combine R15 with R18, selecting 2-lane groups.
  //CHECK: %[[R19:.*]] = vector.shuffle %[[R15]], %[[R18]] [0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R20:.*]] = vector.shuffle %[[R15]], %[[R18]] [2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R21:.*]] = arith.maximumf %[[R19]], %[[R20]] : vector<16xf32>
  // Expected final stage: even lanes vs. odd lanes of R21, giving 8 lanes.
  //CHECK: %[[R22:.*]] = vector.shuffle %[[R21]], %[[R21]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R23:.*]] = vector.shuffle %[[R21]], %[[R21]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xf32>, vector<16xf32>
  //CHECK: %[[R24:.*]] = arith.maximumf %[[R22]], %[[R23]] : vector<8xf32>
  %3 = vector.multi_reduction <maximumf>, %a_loaded, %neg_inf [1] : vector<8x32xf32> to vector<8xf32> // fastmath<nnan> is implicit here
  // Reshape the 8 reduced values to 8x1 so they match the output tile shape.
  %reduced = vector.shape_cast %3 : vector<8xf32> to vector<8x1xf32>
  xetile.store_tile %reduced, %b_tile : vector<8x1xf32>, !xetile.tile<8x1xf32>
  gpu.return
}
229
+
189
230
//CHECK: gpu.func @outter_reduction(%[[arg0:.*]]: memref<128x256xf16>, %[[arg1:.*]]: memref<128x256xf16>) {
190
231
gpu.func @outter_reduction (%a: memref <128 x256 xf16 >, %b: memref <128 x256 xf16 >) {
191
232
//CHECK: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<32xf16>
0 commit comments