Skip to content

Commit 5d749a5

Browse files
committed
more enhancement
more enhancement
1 parent b1c973c commit 5d749a5

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,6 +1488,8 @@ class SYCLGen : public SYCLGenBase {
14881488
// Data types of A, B & C matrices respectively in the PTX arguments
14891489
std::string InMatrixType[3];
14901490

1491+
InMatrixType[2] = CDType;
1492+
14911493
if (Inst->hasAttr(InstAttr::m8n8k4)) {
14921494
M = "8";
14931495
N = "8";
@@ -1573,6 +1575,7 @@ class SYCLGen : public SYCLGenBase {
15731575
NumVecElements[1] = 2; // B
15741576
NumVecElements[2] = 2; // C
15751577
NumVecElements[3] = 2; // D
1578+
InMatrixType[2] = "uint32_t"; // C type is f16*2
15761579
}
15771580
else
15781581
return SYCLGenError();
@@ -1613,8 +1616,6 @@ class SYCLGen : public SYCLGenBase {
16131616
} else
16141617
return SYCLGenError();
16151618

1616-
InMatrixType[2] = CDType;
1617-
16181619
// Check the register sizes for vector elements of A, B, C & D matrices
16191620
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
16201621
InputOp++) {

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2673,10 +2673,14 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
26732673
}
26742674
} else if constexpr (std::is_same_v<CDType, sycl::half>) {
26752675
// Init D matrix fragment with C matrix fragment
2676-
*const_cast<sycl::half *>(d[0]) = c[0];
2677-
*const_cast<sycl::half *>(d[1]) = c[1];
2678-
*const_cast<sycl::half *>(d[2]) = c[2];
2679-
*const_cast<sycl::half *>(d[3]) = c[3];
2676+
sycl::half *d0 = const_cast<sycl::half *>(d[0]);
2677+
sycl::half *d1 = d0 + 1;
2678+
sycl::half *d2 = const_cast<sycl::half *>(d[1]);
2679+
sycl::half *d3 = d2 + 1;
2680+
*d0 = c[0];
2681+
*d1 = c[1];
2682+
*d2 = c[2];
2683+
*d3 = c[3];
26802684

26812685
// Each sub-group is responsible for computing a 16*8-element fragment
26822686
// of matrix D.
@@ -2731,10 +2735,10 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
27312735
// static_cast<CDType>(rb[j + 4]);
27322736

27332737
for (int j = 0; j < 4; j++) {
2734-
*const_cast<sycl::half *>(d[0]) += ra[j] * rb[j];
2735-
*const_cast<sycl::half *>(d[1]) += ra[j] * rb[j + 4];
2736-
*const_cast<sycl::half *>(d[2]) += ra[j + 4] * rb[j];
2737-
*const_cast<sycl::half *>(d[3]) += ra[j + 4] * rb[j + 4];
2738+
*d0 += ra[j] * rb[j];
2739+
*d1 += ra[j] * rb[j + 4];
2740+
*d2 += ra[j + 4] * rb[j];
2741+
*d3 += ra[j + 4] * rb[j + 4];
27382742
}
27392743
}
27402744
} else if constexpr (std::is_integral_v<ABType>) {

clang/test/dpct/asm/mma.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ m8n8k16 .s8 .s8 .s32
1919
m16n8k8 .f16/.bf16 .f16/.bf16 .f32
2020
m16n8k16 .f16 .f16 .f32
2121
.bf16 .bf16 .f32
22-
.s8 .s8 .s32
22+
.s8 .s8 .s32
23+
.f16 .f16 .f16
2324
m16n8k32 .s8 .s8 .s32
2425
2526
All shapes except m8n8k4 are supported for row/col layouts of the A/B matrices, respectively.

0 commit comments

Comments
 (0)