Skip to content

Commit e9f8178

Browse files
authored
[SYCLomatic][PTX][MMA] Support migration of PTX mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 (#2941)
Signed-off-by: [email protected]
1 parent 5ba90d2 commit e9f8178

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,10 +1554,11 @@ class SYCLGen : public SYCLGenBase {
15541554
N = "8";
15551555
K = "16";
15561556

1557-
// Only f16/s8 types are supported for A and B matrices of m16n8k16
1558-
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1559-
InMatrixType[0] = "uint32_t"; // A type is .f16x2
1560-
InMatrixType[1] = "uint32_t"; // B type is .f16x2
1557+
// Only f16/s8/bf16 types are supported for A and B matrices of m16n8k16
1558+
if (AType->getKind() == InlineAsmBuiltinType::f16 ||
1559+
AType->getKind() == InlineAsmBuiltinType::bf16) {
1560+
InMatrixType[0] = "uint32_t"; // A type is .f16x2/.bf16x2
1561+
InMatrixType[1] = "uint32_t"; // B type is .f16x2/.bf16x2
15611562

15621563
// If A matrix type is f16 or bf16, then C&D matrix types can only be f32
15631564
if (CType->getKind() == InlineAsmBuiltinType::f32) {

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2394,7 +2394,7 @@ template <typename T> struct MMAType {
23942394
/// - m8n8k4 (f32.f16.f16.f32)
23952395
/// - m8n8k16 (s32.s8.s8.s32)
23962396
/// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
2397-
/// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
2397+
/// - m16n8k16 (f32.f16.f16.f32 & f32.bf16.bf16.f32 & s32.s8.s8.s32)
23982398
/// - m16n8k32 (s32.s8.s8.s32)
23992399
/// Here, m, n & k define the shapes of A, B & C matrices respectively
24002400
/// (A = [m x k], B = [k x n], C = [m x n]).

clang/test/dpct/asm/mma.cu

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ As per PTX ASM 8.1, below is the status of supported configurations
1717
m8n8k4 .f16 .f16 .f32
1818
m8n8k16 .s8 .s8 .s32
1919
m16n8k8 .f16/.bf16 .f16/.bf16 .f32
20-
m16n8k16 .f16 .f16 .f32
20+
m16n8k16 .f16 .f16 .f32
21+
.bf16 .bf16 .f32
2122
.s8 .s8 .s32
2223
m16n8k32 .s8 .s8 .s32
2324
@@ -116,6 +117,22 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) {
116117
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
117118
"r"(b[0]), "r"(b[1]));
118119

120+
// CHECK: {
121+
// CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] };
122+
// CHECK-NEXT: sycl::vec<uint32_t, 4> a_mat_frag_ct1(a[0], a[1], a[2], a[3]);
123+
// CHECK-NEXT: sycl::vec<uint32_t, 2> b_mat_frag_ct1(b[0], b[1]);
124+
// CHECK-NEXT: sycl::vec<float, 4> c_mat_frag_ct1(fc[0], fc[1], fc[2], fc[3]);
125+
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::ext::oneapi::bfloat16, float>(reinterpret_cast<volatile void **>(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1);
126+
// CHECK-NEXT: }
127+
asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
128+
" { %0, %1, %2, %3 }, "
129+
" { %4, %5, %6, %7 }, "
130+
" { %8, %9 }, "
131+
" { %0, %1, %2, %3 };"
132+
: "+f"(fc[0]), "+f"(fc[1]), "+f"(fc[2]), "+f"(fc[3])
133+
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
134+
"r"(b[0]), "r"(b[1]));
135+
119136
// CHECK: {
120137
// CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &d[0], &d[1], &d[2], &d[3] };
121138
// CHECK-NEXT: sycl::vec<uint32_t, 2> a_mat_frag_ct1(a[0], a[1]);

0 commit comments

Comments
 (0)