@@ -17,7 +17,8 @@ As per PTX ASM 8.1, below is the status of supported configurations
1717m8n8k4 .f16 .f16 .f32
1818m8n8k16 .s8 .s8 .s32
1919m16n8k8 .f16/.bf16 .f16/.bf16 .f32
20- m16n8k16 .f16 .f16 .f32
20+ m16n8k16 .f16 .f16 .f32
21+ .bf16 .bf16 .f32
2122 .s8 .s8 .s32
2223m16n8k32 .s8 .s8 .s32
2324
@@ -116,6 +117,22 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) {
116117 : " r" (a[0 ]), " r" (a[1 ]), " r" (a[2 ]), " r" (a[3 ]),
117118 " r" (b[0 ]), " r" (b[1 ]));
118119
120+ // CHECK: {
121+ // CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] };
122+ // CHECK-NEXT: sycl::vec<uint32_t, 4> a_mat_frag_ct1(a[0], a[1], a[2], a[3]);
123+ // CHECK-NEXT: sycl::vec<uint32_t, 2> b_mat_frag_ct1(b[0], b[1]);
124+ // CHECK-NEXT: sycl::vec<float, 4> c_mat_frag_ct1(fc[0], fc[1], fc[2], fc[3]);
125+ // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::ext::oneapi::bfloat16, float>(reinterpret_cast<volatile void **>(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1);
126+ // CHECK-NEXT: }
127+ asm (" mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
128+ " { %0, %1, %2, %3 }, "
129+ " { %4, %5, %6, %7 }, "
130+ " { %8, %9 }, "
131+ " { %0, %1, %2, %3 };"
132+ : " +f" (fc[0 ]), " +f" (fc[1 ]), " +f" (fc[2 ]), " +f" (fc[3 ])
133+ : " r" (a[0 ]), " r" (a[1 ]), " r" (a[2 ]), " r" (a[3 ]),
134+ " r" (b[0 ]), " r" (b[1 ]));
135+
119136 // CHECK: {
120137 // CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &d[0], &d[1], &d[2], &d[3] };
121138 // CHECK-NEXT: sycl::vec<uint32_t, 2> a_mat_frag_ct1(a[0], a[1]);
0 commit comments