Skip to content

Commit 00142c1

Browse files
authored
[SYCLomatic][PTX] Support migration of asm PTX instruction movmatrix.sync.aligned.m8n8.trans.b16 (#2947)
Signed-off-by: [email protected]
1 parent 3fe420f commit 00142c1

File tree

3 files changed

+88
-0
lines changed

3 files changed

+88
-0
lines changed

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3525,6 +3525,34 @@ class SYCLGen : public SYCLGenBase {
35253525
endstmt();
35263526
return SYCLGenSuccess();
35273527
}
3528+
3529+
3530+
bool handle_movmatrix(const InlineAsmInstruction *Inst) override {
3531+
if (Inst->getNumInputOperands() != 1)
3532+
return SYCLGenError();
3533+
3534+
const auto *Type = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
3535+
3536+
if (!Type || Type->getKind() != InlineAsmBuiltinType::b16)
3537+
return SYCLGenError();
3538+
3539+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::movmatrix(";
3540+
3541+
3542+
if (emitStmt(Inst->getOutputOperand()))
3543+
return SYCLGenError();
3544+
3545+
OS() << ", ";
3546+
3547+
if (emitStmt(Inst->getInputOperand(0)))
3548+
return SYCLGenError();
3549+
3550+
OS() << ")";
3551+
3552+
endstmt();
3553+
3554+
return SYCLGenSuccess();
3555+
}
35283556
};
35293557

35303558
/// Clean the special character in identifier.

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2859,6 +2859,37 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag,
28592859
}
28602860
}
28612861

2862+
/// Transpose 1 8x8 b16 (128 bytes) matrix per sub-group. Requires the sub-group
2863+
/// size of kernel calling this function to be 32.
2864+
/// \param [output] output: The register to store the transposed matrix fragment. It refers to 2
2865+
/// b16 type elements.
2866+
/// \param [in] input: The register to store the matrix fragment. It refers to 2 b16
2867+
/// type elements.
2868+
void movmatrix(uint32_t &output, uint32_t &input) {
2869+
auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
2870+
int laneid = sg.get_local_linear_id();
2871+
2872+
int elm0_row = laneid / 4;
2873+
int elm0_col = (laneid % 4) * 2;
2874+
int elm1_row = elm0_row;
2875+
int elm1_col = elm0_col + 1;
2876+
int src0_row = elm0_col;
2877+
int src0_col = elm0_row;
2878+
int src1_row = elm1_col;
2879+
int src1_col = elm1_row;
2880+
int src0_laneid = src0_row * 4 + src0_col / 2;
2881+
int src0_pos = src0_col % 2;
2882+
int src1_laneid = src1_row * 4 + src1_col / 2;
2883+
int src1_pos = src1_col % 2;
2884+
2885+
auto recv0 = dpct::select_from_sub_group(sg, *(uint32_t *)(&input), src0_laneid);
2886+
auto recv1 = dpct::select_from_sub_group(sg, *(uint32_t *)(&input), src1_laneid);
2887+
2888+
auto ptr_out = reinterpret_cast<sycl::half *>(&output);
2889+
ptr_out[0] = reinterpret_cast<sycl::half *>(&recv0)[src0_pos];
2890+
ptr_out[1] = reinterpret_cast<sycl::half *>(&recv1)[src1_pos];
2891+
}
2892+
28622893
} // namespace matrix
28632894
} // namespace experimental
28642895

clang/test/dpct/asm/movmatrix.cu

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2
2+
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2
3+
// RUN: dpct --format-range=none -out-root %T/movmatrix %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
4+
// RUN: FileCheck %s --match-full-lines --input-file %T/movmatrix/movmatrix.dp.cpp
5+
// RUN: %if build_lit %{icpx -c -fsycl %T/movmatrix/movmatrix.dp.cpp -o %T/movmatrix/movmatrix.dp.o %}
6+
7+
// clang-format off
8+
#include <cuda_runtime.h>
9+
#include <cstdint>
10+
#include <cuda_bf16.h>
11+
12+
using bf16_2 = __nv_bfloat162;
13+
14+
//Syntax:
15+
//movmatrix.sync.aligned.shape.trans.type d, a;
16+
//.shape = {.m8n8};
17+
//.type = {.b16};
18+
// Only .m8n8.b16
19+
//
20+
21+
// Kernel under test: wraps a single inline-PTX
// movmatrix.sync.aligned.m8n8.trans.b16 instruction. `dst` is bound as a
// read-write ("+r") register operand and `src` as an input ("r") register
// operand, each reinterpreted as a uint32_t. The FileCheck directive inside
// the body gives the call the migrated file is expected to contain.
__global__ void movmatrix(bf16_2 &dst, const bf16_2 &src) {
22+
23+
// CHECK: dpct::experimental::matrix::movmatrix(*(uint32_t *)(&dst), (*(uint32_t *)(&src)));
24+
asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n"
25+
: "+r"(*(uint32_t *)(&dst))
26+
: "r"(*(uint32_t *)(&src)));
27+
}
28+
29+
// clang-format on

0 commit comments

Comments
 (0)