@@ -2671,6 +2671,72 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
26712671 static_cast <CDType>(ra[j + 4 ]) * static_cast <CDType>(rb[j + 4 ]);
26722672 }
26732673 }
2674+ } else if constexpr (std::is_same_v<CDType, sycl::half>) {
2675+ // Init D matrix fragment with C matrix fragment
2676+ *const_cast <sycl::half *>(d[0 ]) = c[0 ];
2677+ *const_cast <sycl::half *>(d[1 ]) = c[1 ];
2678+ *const_cast <sycl::half *>(d[2 ]) = c[2 ];
2679+ *const_cast <sycl::half *>(d[3 ]) = c[3 ];
2680+
2681+ // Each sub-group is responsible for computing a fragment size of 16*8
2682+ // elements of matrix D.
2683+ // Each work item computes 4 elements of matrix D by gathering
2684+ // their corresponding row & col matrix fragments of length k (8)
2685+ // from A & B matrices respectively using below mapping logic:
2686+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
2687+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
2688+ // As each row & col fragment of A & B matrices is distributed across
2689+ // 4 work items, each iteration of below loop loads a partial fragment
2690+ // of matrix A (row) and matrix B (col) using the row & col offsets.
2691+ for (int i = 0 ; i < 4 ; i++) {
2692+ typename MMAType<ABType>::PackType recv_a[4 ], recv_b[4 ];
2693+
2694+ // Load partial fragment from row0 of matrix A ({a0, a1})
2695+ recv_a[0 ] = dpct::select_from_sub_group (sg, a[0 ], row_load_offset + i);
2696+ // Load partial fragment from row0 of matrix A ({a2, a3})
2697+ recv_a[1 ] = dpct::select_from_sub_group (sg, a[2 ], row_load_offset + i);
2698+ // Load partial fragment from row1 of matrix A ({a0, a1})
2699+ recv_a[2 ] = dpct::select_from_sub_group (sg, a[1 ], row_load_offset + i);
2700+ // Load partial fragment from row1 of matrix A ({a2, a3})
2701+ recv_a[3 ] = dpct::select_from_sub_group (sg, a[3 ], row_load_offset + i);
2702+
2703+ // Load partial fragment from col0 of matrix B ({b0, b1})
2704+ recv_b[0 ] = dpct::select_from_sub_group (sg, b[0 ], col_load_offset + i);
2705+ // Load partial fragment from col0 of matrix B ({b2, b3})
2706+ recv_b[1 ] = dpct::select_from_sub_group (sg, b[1 ], col_load_offset + i);
2707+ // Load partial fragment from col1 of matrix B ({b0, b1})
2708+ recv_b[2 ] =
2709+ dpct::select_from_sub_group (sg, b[0 ], col_load_offset + 4 + i);
2710+ // Load partial fragment from col1 of matrix B ({b2, b3})
2711+ recv_b[3 ] =
2712+ dpct::select_from_sub_group (sg, b[1 ], col_load_offset + 4 + i);
2713+
2714+ auto ra = reinterpret_cast <ABType *>(recv_a);
2715+ auto rb = reinterpret_cast <ABType *>(recv_b);
2716+
2717+ // Each work item calculates a partial product of A & B matrix
2718+ // fragments and adds it to the corresponding D matrix fragment:
2719+ //   d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
2720+ //   d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
2721+ //   d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
2722+ //   d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
2723+ // for (int j = 0; j < 4; j++) {
2724+ // *d[0] +=
2725+ // static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
2726+ // *d[1] += static_cast<CDType>(ra[j]) *
2727+ // static_cast<CDType>(rb[j + 4]);
2728+ // *d[2] += static_cast<CDType>(ra[j + 4]) *
2729+ // static_cast<CDType>(rb[j]);
2730+ // *d[3] += static_cast<CDType>(ra[j + 4]) *
2731+ // static_cast<CDType>(rb[j + 4]);
2732+
2733+ for (int j = 0 ; j < 4 ; j++) {
2734+ *const_cast <sycl::half *>(d[0 ]) += ra[j] * rb[j];
2735+ *const_cast <sycl::half *>(d[1 ]) += ra[j] * rb[j + 4 ];
2736+ *const_cast <sycl::half *>(d[2 ]) += ra[j + 4 ] * rb[j];
2737+ *const_cast <sycl::half *>(d[3 ]) += ra[j + 4 ] * rb[j + 4 ];
2738+ }
2739+ }
26742740 } else if constexpr (std::is_integral_v<ABType>) {
26752741 // Init D matrix with fragments of C matrix
26762742 *d[0 ] = c[0 ];
0 commit comments