Skip to content

Commit 3fe420f

Browse files
authored
Support mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 (#2943)
* Support mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16
1 parent 6acdddd commit 3fe420f

File tree

3 files changed

+95
-6
lines changed

3 files changed

+95
-6
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,6 +1488,8 @@ class SYCLGen : public SYCLGenBase {
14881488
// Data types of A, B & C matrices respectively in the PTX arguments
14891489
std::string InMatrixType[3];
14901490

1491+
InMatrixType[2] = CDType;
1492+
14911493
if (Inst->hasAttr(InstAttr::m8n8k4)) {
14921494
M = "8";
14931495
N = "8";
@@ -1560,13 +1562,22 @@ class SYCLGen : public SYCLGenBase {
15601562
InMatrixType[0] = "uint32_t"; // A type is .f16/.bf16x2
15611563
InMatrixType[1] = "uint32_t"; // B type is .f16/.bf16x2
15621564

1563-
// If A matrix type is f16, then C&D matrix types can only be f32
1565+
// If A matrix type is f16, then C&D matrix types can be f32
15641566
if (CType->getKind() == InlineAsmBuiltinType::f32) {
15651567
NumVecElements[0] = 4; // A
15661568
NumVecElements[1] = 2; // B
15671569
NumVecElements[2] = 4; // C
15681570
NumVecElements[3] = 4; // D
1569-
} else
1571+
}
1572+
// C &D matrix types can be f16.
1573+
else if (CType->getKind() == InlineAsmBuiltinType::f16) {
1574+
NumVecElements[0] = 4; // A
1575+
NumVecElements[1] = 2; // B
1576+
NumVecElements[2] = 2; // C
1577+
NumVecElements[3] = 2; // D
1578+
InMatrixType[2] = "uint32_t"; // C type is f16*2
1579+
}
1580+
else
15701581
return SYCLGenError();
15711582
} else if (AType->getKind() == InlineAsmBuiltinType::s8) {
15721583
InMatrixType[0] = "uint32_t"; // A type is .s8x4
@@ -1605,8 +1616,6 @@ class SYCLGen : public SYCLGenBase {
16051616
} else
16061617
return SYCLGenError();
16071618

1608-
InMatrixType[2] = CDType;
1609-
16101619
// Check the register sizes for vector elements of A, B, C & D matrices
16111620
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
16121621
InputOp++) {

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2394,7 +2394,7 @@ template <typename T> struct MMAType {
23942394
/// - m8n8k4 (f32.f16.f16.f32)
23952395
/// - m8n8k16 (s32.s8.s8.s32)
23962396
/// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
2397-
/// - m16n8k16 (f32.f16.f16.f32 & f32.bf16.bf16.f32 & s32.s8.s8.s32)
2397+
/// - m16n8k16 (f32.f16.f16.f32 & f16.f16.f16.f16 & f32.bf16.bf16.f32 & s32.s8.s8.s32)
23982398
/// - m16n8k32 (s32.s8.s8.s32)
23992399
/// Here, m, n & k define the shapes of A, B & C matrices respectively
24002400
/// (A = [m x k], B = [k x n], C = [m x n]).
@@ -2671,6 +2671,66 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
26712671
static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j + 4]);
26722672
}
26732673
}
2674+
} else if constexpr (std::is_same_v<CDType, sycl::half>) {
2675+
// Init D matrix fragment with C matrix fragment
2676+
sycl::half *d0 = const_cast<sycl::half *>(d[0]);
2677+
sycl::half *d1 = d0 + 1;
2678+
sycl::half *d2 = const_cast<sycl::half *>(d[1]);
2679+
sycl::half *d3 = d2 + 1;
2680+
*d0 = c[0];
2681+
*d1 = c[1];
2682+
*d2 = c[2];
2683+
*d3 = c[3];
2684+
2685+
// Each sub-group is responsible for computing a fragment size of 16*8
2686+
// elements of matrix D.
2687+
// Each work item computes 4 elements of matrix D by gathering
2688+
// their corresponding row & col matrix fragments of length k (8)
2689+
// from A & B matrices respectively using below mapping logic:
2690+
// row0 = (lane >> 2) & row1 = (lane >> 2) + 8
2691+
// col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
2692+
// As each row & col fragment of A & B matrices is distributed across
2693+
// 4 work items, each iteration of below loop loads a partial fragment
2694+
// of matrix A (row) and matrix B (col) using the row & col offsets.
2695+
for (int i = 0; i < 4; i++) {
2696+
typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
2697+
2698+
// Load partial fragment from row0 of matrix A ({a0, a1})
2699+
recv_a[0] = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
2700+
// Load partial fragment from row0 of matrix A ({a2, a3})
2701+
recv_a[1] = dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
2702+
// Load partial fragment from row1 of matrix A ({a0, a1})
2703+
recv_a[2] = dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
2704+
// Load partial fragment from row1 of matrix A ({a2, a3})
2705+
recv_a[3] = dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
2706+
2707+
// Load partial fragment from col0 of matrix B ({b0, b1})
2708+
recv_b[0] = dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
2709+
// Load partial fragment from col0 of matrix B ({b2, b3})
2710+
recv_b[1] = dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
2711+
// Load partial fragment from col1 of matrix B ({b0, b1})
2712+
recv_b[2] =
2713+
dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
2714+
// Load partial fragment from col1 of matrix B ({b2, b3})
2715+
recv_b[3] =
2716+
dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
2717+
2718+
auto ra = reinterpret_cast<ABType *>(recv_a);
2719+
auto rb = reinterpret_cast<ABType *>(recv_b);
2720+
2721+
// Each work item calculates a partial product of A & B matrix fragments
2722+
// and adds it to the corresponding D matrix fragment
2723+
// d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
2724+
// d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
2725+
// d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
2726+
// d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
2727+
for (int j = 0; j < 4; j++) {
2728+
*d0 += ra[j] * rb[j];
2729+
*d1 += ra[j] * rb[j + 4];
2730+
*d2 += ra[j + 4] * rb[j];
2731+
*d3 += ra[j + 4] * rb[j + 4];
2732+
}
2733+
}
26742734
} else if constexpr (std::is_integral_v<ABType>) {
26752735
// Init D matrix with fragments of C matrix
26762736
*d[0] = c[0];

clang/test/dpct/asm/mma.cu

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ m8n8k16 .s8 .s8 .s32
1919
m16n8k8 .f16/.bf16 .f16/.bf16 .f32
2020
m16n8k16 .f16 .f16 .f32
2121
.bf16 .bf16 .f32
22-
.s8 .s8 .s32
22+
.s8 .s8 .s32
23+
.f16 .f16 .f16
2324
m16n8k32 .s8 .s8 .s32
2425
2526
Except for m8n8k4, all other shapes are supported for row/col layout of A/B matrices respectively.
@@ -100,6 +101,25 @@ __global__ void mma_kernel_m16n8k8(int *a, int *b, float *fc, float *fd) {
100101
"f"(fc[0]), "f"(fc[1]), "f"(fc[2]), "f"(fc[3]));
101102
}
102103

104+
__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, int *d) {
105+
// CHECK: {
106+
// CHECK-NEXT: volatile void *d_mat_frag_ct1[2] = { &d[0], &d[1] };
107+
// CHECK-NEXT: sycl::vec<uint32_t, 4> a_mat_frag_ct1(a[0], a[1], a[2], a[3]);
108+
// CHECK-NEXT: sycl::vec<uint32_t, 2> b_mat_frag_ct1(b[0], b[1]);
109+
// CHECK-NEXT: sycl::vec<uint32_t, 2> c_mat_frag_ct1(c[0], c[1]);
110+
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, sycl::half>(reinterpret_cast<volatile void **>(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1);
111+
// CHECK-NEXT: }
112+
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
113+
" { %0, %1 }, "
114+
" { %2, %3, %4, %5 }, "
115+
" { %6, %7 }, "
116+
" { %8, %9 };"
117+
: "+r"(d[0]), "+r"(d[1])
118+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
119+
"r"(b[0]), "r"(b[1]),
120+
"r"(c[0]), "r"(c[1]));
121+
}
122+
103123
__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) {
104124
// CHECK: {
105125
// CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] };

0 commit comments

Comments
 (0)