Skip to content

Commit b1c973c

Browse files
committed
Support mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16
Support mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16
1 parent e9f8178 commit b1c973c

File tree

3 files changed

+94
-2
lines changed

3 files changed

+94
-2
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,13 +1560,21 @@ class SYCLGen : public SYCLGenBase {
15601560
InMatrixType[0] = "uint32_t"; // A type is .f16/.bf16x2
15611561
InMatrixType[1] = "uint32_t"; // B type is .f16/.bf16x2
15621562

1563-
// If A matrix type is f16, then C&D matrix types can only be f32
1563+
// If A matrix type is f16, then C&D matrix types can be f32
15641564
if (CType->getKind() == InlineAsmBuiltinType::f32) {
15651565
NumVecElements[0] = 4; // A
15661566
NumVecElements[1] = 2; // B
15671567
NumVecElements[2] = 4; // C
15681568
NumVecElements[3] = 4; // D
1569-
} else
1569+
}
1570+
// C & D matrix types can be f16.
1571+
else if (CType->getKind() == InlineAsmBuiltinType::f16) {
1572+
NumVecElements[0] = 4; // A
1573+
NumVecElements[1] = 2; // B
1574+
NumVecElements[2] = 2; // C
1575+
NumVecElements[3] = 2; // D
1576+
}
1577+
else
15701578
return SYCLGenError();
15711579
} else if (AType->getKind() == InlineAsmBuiltinType::s8) {
15721580
InMatrixType[0] = "uint32_t"; // A type is .s8x4

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2671,6 +2671,72 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
26712671
static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j + 4]);
26722672
}
26732673
}
2674+
} else if constexpr (std::is_same_v<CDType, sycl::half>) {
2675+
// Init D matrix fragment with C matrix fragment
2676+
*const_cast<sycl::half *>(d[0]) = c[0];
2677+
*const_cast<sycl::half *>(d[1]) = c[1];
2678+
*const_cast<sycl::half *>(d[2]) = c[2];
2679+
*const_cast<sycl::half *>(d[3]) = c[3];
2680+
2681+
// Each sub-group is responsible for computing a fragment size of 16*8
2682+
// elements of matrix D.
2683+
// Each work item computes 4 elements of matrix D by gathering
2684+
// their corresponding row & col matrix fragments of length k (8)
2685+
// from A & B matrices respectively using below mapping logic:
2686+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
2687+
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
2688+
// As each row & col fragment of A & B matrices is distributed across
2689+
// 4 work items, each iteration of below loop loads a partial fragment
2690+
// of matrix A (row) and matrix B (col) using the row & col offsets.
2691+
for (int i = 0; i < 4; i++) {
2692+
typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
2693+
2694+
// Load partial fragment from row0 of matrix A ({a0, a1})
2695+
recv_a[0] = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
2696+
// Load partial fragment from row0 of matrix A ({a2, a3})
2697+
recv_a[1] = dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
2698+
// Load partial fragment from row1 of matrix A ({a0, a1})
2699+
recv_a[2] = dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
2700+
// Load partial fragment from row1 of matrix A ({a2, a3})
2701+
recv_a[3] = dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
2702+
2703+
// Load partial fragment from col0 of matrix B ({b0, b1})
2704+
recv_b[0] = dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
2705+
// Load partial fragment from col0 of matrix B ({b2, b3})
2706+
recv_b[1] = dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
2707+
// Load partial fragment from col1 of matrix B ({b0, b1})
2708+
recv_b[2] =
2709+
dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
2710+
// Load partial fragment from col1 of matrix B ({b2, b3})
2711+
recv_b[3] =
2712+
dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
2713+
2714+
auto ra = reinterpret_cast<ABType *>(recv_a);
2715+
auto rb = reinterpret_cast<ABType *>(recv_b);
2716+
2717+
// Each work item calculates a partial product of A & B matrix
2718+
// fragments and adds it to the corresponding D matrix fragment d0
2719+
// += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
2720+
// a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1,
2721+
// a2, a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } *
2722+
// col1{ b0, b1, b2, b3 }
2723+
// for (int j = 0; j < 4; j++) {
2724+
// *d[0] +=
2725+
// static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
2726+
// *d[1] += static_cast<CDType>(ra[j]) *
2727+
// static_cast<CDType>(rb[j + 4]);
2728+
// *d[2] += static_cast<CDType>(ra[j + 4]) *
2729+
// static_cast<CDType>(rb[j]);
2730+
// *d[3] += static_cast<CDType>(ra[j + 4]) *
2731+
// static_cast<CDType>(rb[j + 4]);
2732+
2733+
for (int j = 0; j < 4; j++) {
2734+
*const_cast<sycl::half *>(d[0]) += ra[j] * rb[j];
2735+
*const_cast<sycl::half *>(d[1]) += ra[j] * rb[j + 4];
2736+
*const_cast<sycl::half *>(d[2]) += ra[j + 4] * rb[j];
2737+
*const_cast<sycl::half *>(d[3]) += ra[j + 4] * rb[j + 4];
2738+
}
2739+
}
26742740
} else if constexpr (std::is_integral_v<ABType>) {
26752741
// Init D matrix with fragments of C matrix
26762742
*d[0] = c[0];

clang/test/dpct/asm/mma.cu

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,24 @@ __global__ void mma_kernel_m16n8k8(int *a, int *b, float *fc, float *fd) {
100100
"f"(fc[0]), "f"(fc[1]), "f"(fc[2]), "f"(fc[3]));
101101
}
102102

103+
__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, int *d) {
104+
// CHECK: {
105+
// CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &c[0], &c[1] };
106+
// CHECK-NEXT: sycl::vec<uint32_t, 4> a_mat_frag_ct1(a[0], a[1], a[2], a[3]);
107+
// CHECK-NEXT: sycl::vec<uint32_t, 2> b_mat_frag_ct1(b[0], b[1]);
108+
// CHECK-NEXT: sycl::vec<uint32_t, 2> c_mat_frag_ct1(c[0], c[1]);
109+
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, sycl::half>(reinterpret_cast<volatile void **>(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1);
110+
// CHECK-NEXT: }
111+
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
112+
" { %0, %1 }, "
113+
" { %2, %3, %4, %5 }, "
114+
" { %6, %7 }, "
115+
" { %0, %1 };"
116+
: "+r"(c[0]), "+r"(c[1])
117+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
118+
"r"(d[0]), "r"(d[1]));
119+
}
120+
103121
__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) {
104122
// CHECK: {
105123
// CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] };

0 commit comments

Comments
 (0)