@@ -189,9 +189,104 @@ int main() {
189189 // This combination is not currently supported for sub group size = 32 in
190190 // IGC
191191#if (!defined(SG_SZ) || SG_SZ != 32)
192+ // 8x16x16 float/bfloat16
193+ std::cout << " 8x16x16 float/bfloat16" << std::endl;
194+ // A
195+ test_get_coord_op<bfloat16, float , /* TM*/ 8 , /* TK*/ 16 , use::a,
196+ layout::row_major, 1 >();
197+ // B
198+ test_get_coord_op<bfloat16, float , /* TK*/ 16 , /* TN*/ 16 , use::b,
199+ layout::ext_intel_packed, 2 >();
192200 test_get_coord_op<bfloat16, float , /* TK*/ 16 , /* TN*/ 16 , use::b,
193201 layout::row_major, 1 >();
194- test_get_coord_op<int8_t , int32_t , /* TK*/ 32 , /* TN*/ 16 , use::b,
202+ // Accumulator
203+ test_get_coord_op<bfloat16, float , /* TM*/ 8 , /* TN*/ 16 , use::accumulator,
204+ layout::row_major, 1 >();
205+ test_get_coord_op<float , float , /* TM*/ 8 , /* TN*/ 16 , use::accumulator,
206+ layout::row_major, 1 >();
207+
208+
209+ // 16x16x16 float/bfloat16
210+ std::cout << " 16x16x16 float/bfloat16" << std::endl;
211+ // A
212+ test_get_coord_op<bfloat16, float , /* TM*/ 16 , /* TK*/ 16 , use::a,
213+ layout::row_major, 1 >();
214+ // B
215+ // Duplicate from 8x16x16
216+ // test_get_coord_op<bfloat16, float, /*TK*/ 16, /*TN*/ 16, use::b,
217+ // layout::ext_intel_packed, 2>();
218+ // test_get_coord_op<bfloat16, float, /*TK*/ 16, /*TN*/ 16, use::b,
219+ // layout::row_major, 1>();
220+ // Accumulator
221+ test_get_coord_op<bfloat16, float , /* TM*/ 16 , /* TN*/ 16 , use::accumulator,
222+ layout::row_major, 1 >();
223+ test_get_coord_op<float , float , /* TM*/ 16 , /* TN*/ 16 , use::accumulator,
224+ layout::row_major, 1 >();
225+
226+ // 1x64x16 float/bfloat16
227+ std::cout << " 1x64x16 float/bfloat16" << std::endl;
228+ // A
229+ test_get_coord_op<bfloat16, float , /* TM*/ 1 , /* TK*/ 16 , use::a,
230+ layout::row_major, 1 >();
231+ // B
232+ test_get_coord_op<bfloat16, float , /* TK*/ 16 , /* TN*/ 64 , use::b,
233+ layout::ext_intel_packed, 2 >();
234+ test_get_coord_op<bfloat16, float , /* TK*/ 16 , /* TN*/ 64 , use::b,
235+ layout::row_major, 1 >();
236+ // Accumulator
237+ test_get_coord_op<bfloat16, float , /* TM*/ 1 , /* TN*/ 64 , use::accumulator,
238+ layout::row_major, 1 >();
239+ test_get_coord_op<float , float , /* TM*/ 1 , /* TN*/ 64 , use::accumulator,
240+ layout::row_major, 1 >();
241+
242+ // 1x64x32 float/bfloat16
243+ std::cout << " 1x64x32 float/bfloat16" << std::endl;
244+ // A
245+ test_get_coord_op<bfloat16, float , /* TM*/ 1 , /* TK*/ 32 , use::a,
246+ layout::row_major, 1 >();
247+ // B
248+ test_get_coord_op<bfloat16, float , /* TK*/ 32 , /* TN*/ 64 , use::b,
249+ layout::ext_intel_packed, 2 >();
250+ test_get_coord_op<bfloat16, float , /* TK*/ 32 , /* TN*/ 64 , use::b,
251+ layout::row_major, 1 >();
252+ // Accumulator
253+ test_get_coord_op<bfloat16, float , /* TM*/ 1 , /* TN*/ 64 , use::accumulator,
254+ layout::row_major, 1 >();
255+ test_get_coord_op<float , float , /* TM*/ 1 , /* TN*/ 64 , use::accumulator,
256+ layout::row_major, 1 >();
257+
258+ // 32x64x16 float/bfloat16
259+ std::cout << " 32x64x16 float/bfloat16" << std::endl;
260+ // A
261+ test_get_coord_op<bfloat16, float , /* TM*/ 32 , /* TK*/ 16 , use::a,
262+ layout::row_major, 1 >();
263+ // B
264+ // Duplicate from 1x64x16
265+ // test_get_coord_op<bfloat16, float, /*TK*/ 16, /*TN*/ 64, use::b,
266+ // layout::ext_intel_packed, 2>();
267+ // test_get_coord_op<bfloat16, float, /*TK*/ 16, /*TN*/ 64, use::b,
268+ // layout::row_major, 1>();
269+ // Accumulator
270+ test_get_coord_op<bfloat16, float , /* TM*/ 32 , /* TN*/ 64 , use::accumulator,
271+ layout::row_major, 1 >();
272+ test_get_coord_op<float , float , /* TM*/ 32 , /* TN*/ 64 , use::accumulator,
273+ layout::row_major, 1 >();
274+
275+ // 32x64x32 float/bfloat16
276+ std::cout << " 32x64x32 float/bfloat16" << std::endl;
277+ // A
278+ test_get_coord_op<bfloat16, float , /* TM*/ 32 , /* TK*/ 32 , use::a,
279+ layout::row_major, 1 >();
280+ // B
281+ // Duplicate from 1x64x32
282+ // test_get_coord_op<bfloat16, float, /*TK*/ 32, /*TN*/ 64, use::b,
283+ // layout::ext_intel_packed, 2>();
284+ // test_get_coord_op<bfloat16, float, /*TK*/ 32, /*TN*/ 64, use::b,
285+ // layout::row_major, 1>();
286+ // Accumulator
287+ test_get_coord_op<bfloat16, float , /* TM*/ 32 , /* TN*/ 64 , use::accumulator,
288+ layout::row_major, 1 >();
289+ test_get_coord_op<float , float , /* TM*/ 32 , /* TN*/ 64 , use::accumulator,
195290 layout::row_major, 1 >();
196291#endif
197292 break ;
0 commit comments