@@ -189,9 +189,104 @@ int main() {
189
189
// This combination is not currently supported for sub group size = 32 in
190
190
// IGC
191
191
#if (!defined(SG_SZ) || SG_SZ != 32)
192
+ // 8x16x16 float/bfloat16
193
+ std::cout << " 8x16x16 float/bfloat16" << std::endl;
194
+ // A
195
+ test_get_coord_op<bfloat16, float , /* TM*/ 8 , /* TK*/ 16 , use::a,
196
+ layout::row_major, 1 >();
197
+ // B
198
+ test_get_coord_op<bfloat16, float , /* TK*/ 16 , /* TN*/ 16 , use::b,
199
+ layout::ext_intel_packed, 2 >();
192
200
test_get_coord_op<bfloat16, float , /* TK*/ 16 , /* TN*/ 16 , use::b,
193
201
layout::row_major, 1 >();
194
- test_get_coord_op<int8_t , int32_t , /* TK*/ 32 , /* TN*/ 16 , use::b,
202
+ // Accumulator
203
+ test_get_coord_op<bfloat16, float , /* TM*/ 8 , /* TN*/ 16 , use::accumulator,
204
+ layout::row_major, 1 >();
205
+ test_get_coord_op<float , float , /* TM*/ 8 , /* TN*/ 16 , use::accumulator,
206
+ layout::row_major, 1 >();
207
+
208
+
209
+ // 16x16x16 float/bfloat16
210
+ std::cout << " 16x16x16 float/bfloat16" << std::endl;
211
+ // A
212
+ test_get_coord_op<bfloat16, float , /* TM*/ 16 , /* TK*/ 16 , use::a,
213
+ layout::row_major, 1 >();
214
+ // B
215
+ // Duplicate from 8x16x16
216
+ // test_get_coord_op<bfloat16, float, /*TK*/ 16, /*TN*/ 16, use::b,
217
+ // layout::ext_intel_packed, 2>();
218
+ // test_get_coord_op<bfloat16, float, /*TK*/ 16, /*TN*/ 16, use::b,
219
+ // layout::row_major, 1>();
220
+ // Accumulator
221
+ test_get_coord_op<bfloat16, float , /* TM*/ 16 , /* TN*/ 16 , use::accumulator,
222
+ layout::row_major, 1 >();
223
+ test_get_coord_op<float , float , /* TM*/ 16 , /* TN*/ 16 , use::accumulator,
224
+ layout::row_major, 1 >();
225
+
226
+ // 1x64x16 float/bfloat16
227
+ std::cout << " 1x64x16 float/bfloat16" << std::endl;
228
+ // A
229
+ test_get_coord_op<bfloat16, float , /* TM*/ 1 , /* TK*/ 16 , use::a,
230
+ layout::row_major, 1 >();
231
+ // B
232
+ test_get_coord_op<bfloat16, float , /* TK*/ 16 , /* TN*/ 64 , use::b,
233
+ layout::ext_intel_packed, 2 >();
234
+ test_get_coord_op<bfloat16, float , /* TK*/ 16 , /* TN*/ 64 , use::b,
235
+ layout::row_major, 1 >();
236
+ // Accumulator
237
+ test_get_coord_op<bfloat16, float , /* TM*/ 1 , /* TN*/ 64 , use::accumulator,
238
+ layout::row_major, 1 >();
239
+ test_get_coord_op<float , float , /* TM*/ 1 , /* TN*/ 64 , use::accumulator,
240
+ layout::row_major, 1 >();
241
+
242
+ // 1x64x32 float/bfloat16
243
+ std::cout << " 1x64x32 float/bfloat16" << std::endl;
244
+ // A
245
+ test_get_coord_op<bfloat16, float , /* TM*/ 1 , /* TK*/ 32 , use::a,
246
+ layout::row_major, 1 >();
247
+ // B
248
+ test_get_coord_op<bfloat16, float , /* TK*/ 32 , /* TN*/ 64 , use::b,
249
+ layout::ext_intel_packed, 2 >();
250
+ test_get_coord_op<bfloat16, float , /* TK*/ 32 , /* TN*/ 64 , use::b,
251
+ layout::row_major, 1 >();
252
+ // Accumulator
253
+ test_get_coord_op<bfloat16, float , /* TM*/ 1 , /* TN*/ 64 , use::accumulator,
254
+ layout::row_major, 1 >();
255
+ test_get_coord_op<float , float , /* TM*/ 1 , /* TN*/ 64 , use::accumulator,
256
+ layout::row_major, 1 >();
257
+
258
+ // 32x64x16 float/bfloat16
259
+ std::cout << " 32x64x16 float/bfloat16" << std::endl;
260
+ // A
261
+ test_get_coord_op<bfloat16, float , /* TM*/ 32 , /* TK*/ 16 , use::a,
262
+ layout::row_major, 1 >();
263
+ // B
264
+ // Duplicate from 1x64x16
265
+ // test_get_coord_op<bfloat16, float, /*TK*/ 16, /*TN*/ 64, use::b,
266
+ // layout::ext_intel_packed, 2>();
267
+ // test_get_coord_op<bfloat16, float, /*TK*/ 16, /*TN*/ 64, use::b,
268
+ // layout::row_major, 1>();
269
+ // Accumulator
270
+ test_get_coord_op<bfloat16, float , /* TM*/ 32 , /* TN*/ 64 , use::accumulator,
271
+ layout::row_major, 1 >();
272
+ test_get_coord_op<float , float , /* TM*/ 32 , /* TN*/ 64 , use::accumulator,
273
+ layout::row_major, 1 >();
274
+
275
+ // 32x64x32 float/bfloat16
276
+ std::cout << " 32x64x32 float/bfloat16" << std::endl;
277
+ // A
278
+ test_get_coord_op<bfloat16, float , /* TM*/ 32 , /* TK*/ 32 , use::a,
279
+ layout::row_major, 1 >();
280
+ // B
281
+ // Duplicate from 1x64x32
282
+ // test_get_coord_op<bfloat16, float, /*TK*/ 32, /*TN*/ 64, use::b,
283
+ // layout::ext_intel_packed, 2>();
284
+ // test_get_coord_op<bfloat16, float, /*TK*/ 32, /*TN*/ 64, use::b,
285
+ // layout::row_major, 1>();
286
+ // Accumulator
287
+ test_get_coord_op<bfloat16, float , /* TM*/ 32 , /* TN*/ 64 , use::accumulator,
288
+ layout::row_major, 1 >();
289
+ test_get_coord_op<float , float , /* TM*/ 32 , /* TN*/ 64 , use::accumulator,
195
290
layout::row_major, 1 >();
196
291
#endif
197
292
break ;
0 commit comments