@@ -2394,7 +2394,7 @@ template <typename T> struct MMAType {
 /// - m8n8k4 (f32.f16.f16.f32)
 /// - m8n8k16 (s32.s8.s8.s32)
 /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
-/// - m16n8k16 (f32.f16.f16.f32 & f32.bf16.bf16.f32 & s32.s8.s8.s32)
+/// - m16n8k16 (f32.f16.f16.f32 & f16.f16.f16.f16 & f32.bf16.bf16.f32 & s32.s8.s8.s32)
 /// - m16n8k32 (s32.s8.s8.s32)
 /// Here, m, n & k define the shapes of A, B & C matrices respectively
 /// (A = [m x k], B = [k x n], C = [m x n]).
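As an illustration of the shape convention for the newly listed f16.f16.f16.f16 variant of m16n8k16: A is 16x16, B is 16x8, and C/D are 16x8, all sycl::half. A minimal scalar reference sketch follows, assuming a hypothetical helper name that is not part of this header (the hardware MMA's rounding may differ from this naive loop):

#include <sycl/sycl.hpp>

// Illustrative scalar reference for the m16n8k16 f16.f16.f16.f16 case:
// D[16x8] = A[16x16] * B[16x8] + C[16x8], every element a sycl::half.
inline void mma_m16n8k16_f16_ref(const sycl::half A[16][16],
                                 const sycl::half B[16][8],
                                 const sycl::half C[16][8],
                                 sycl::half D[16][8]) {
  for (int m = 0; m < 16; ++m)      // rows of A and D
    for (int n = 0; n < 8; ++n) {   // columns of B and D
      sycl::half acc = C[m][n];
      for (int k = 0; k < 16; ++k)  // reduction dimension
        acc += A[m][k] * B[k][n];
      D[m][n] = acc;
    }
}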
@@ -2671,6 +2671,66 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
             static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j + 4]);
       }
     }
+  } else if constexpr (std::is_same_v<CDType, sycl::half>) {
+    // Init D matrix fragment with C matrix fragment
+    sycl::half *d0 = const_cast<sycl::half *>(d[0]);
+    sycl::half *d1 = d0 + 1;
+    sycl::half *d2 = const_cast<sycl::half *>(d[1]);
+    sycl::half *d3 = d2 + 1;
+    *d0 = c[0];
+    *d1 = c[1];
+    *d2 = c[2];
+    *d3 = c[3];
+
+    // Each sub-group is responsible for computing a fragment of 16*8
+    // elements of matrix D.
+    // Each work item computes 4 elements of matrix D by gathering
+    // its corresponding row & col matrix fragments of length k (16)
+    // from A & B matrices respectively, using the mapping logic below:
+    // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
+    // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
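+    // (Illustrative example, assuming the usual 32-lane sub-group: lane 6
+    // maps to row0 = 1, row1 = 9, col0 = 4, col1 = 5, i.e. it produces
+    // D[1][4], D[1][5], D[9][4] and D[9][5].)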
+    // As each row & col fragment of A & B matrices is distributed across
+    // 4 work items, each iteration of the loop below loads a partial fragment
+    // of matrix A (row) and matrix B (col) using the row & col offsets.
+    for (int i = 0; i < 4; i++) {
+      typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
+
+      // Load partial fragment from row0 of matrix A ({a0, a1})
+      recv_a[0] = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+      // Load partial fragment from row0 of matrix A ({a2, a3})
+      recv_a[1] = dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
+      // Load partial fragment from row1 of matrix A ({a0, a1})
+      recv_a[2] = dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+      // Load partial fragment from row1 of matrix A ({a2, a3})
+      recv_a[3] = dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
+
+      // Load partial fragment from col0 of matrix B ({b0, b1})
+      recv_b[0] = dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+      // Load partial fragment from col0 of matrix B ({b2, b3})
+      recv_b[1] = dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+      // Load partial fragment from col1 of matrix B ({b0, b1})
+      recv_b[2] =
+          dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
+      // Load partial fragment from col1 of matrix B ({b2, b3})
+      recv_b[3] =
+          dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
+
+      auto ra = reinterpret_cast<ABType *>(recv_a);
+      auto rb = reinterpret_cast<ABType *>(recv_b);
+
+      // Each work item calculates a partial product of A & B matrix fragments
+      // and adds it to the corresponding D matrix fragment:
+      // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+      // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+      // d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+      // d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
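+      // (Over the 4 iterations of the outer i-loop this accumulates
+      // k = 16 products into each of d0..d3.)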
+      for (int j = 0; j < 4; j++) {
+        *d0 += ra[j] * rb[j];
+        *d1 += ra[j] * rb[j + 4];
+        *d2 += ra[j + 4] * rb[j];
+        *d3 += ra[j + 4] * rb[j + 4];
+      }
+    }
   } else if constexpr (std::is_integral_v<ABType>) {
     // Init D matrix with fragments of C matrix
     *d[0] = c[0];