1
+ #include < stdio.h>
2
+ #include < cstdint>
3
+ #include < iostream>
4
+ #include < cuda_bf16.h>
5
+ #include < cuda_fp16.h>
6
+
7
+ #define NO_HALVES_PER_BLOCK 1024
8
+ using bf16_2 = __nv_bfloat162;
9
+
10
+ // Syntax:
11
+ // movmatrix.sync.aligned.shape.trans.type d, a;
12
+ // .shape = {.m8n8};
13
+ // .type = {.b16};
14
+ // Only .m8n8.b16
15
// Matrix shape exercised by the test. The host code indexes the matrix as
// [row * Shape_N + column] with row < Shape_M and column < Shape_N.
int Shape_M = 8;  // number of rows
int Shape_N = 8;  // number of columns
16
+
17
// Runs the test function FN (invoked with no arguments; its result is used as
// a condition) and prints a PASS or FAIL line that includes FN's name.
//
// On failure the macro executes `return 1;` in the *calling* function, so it
// may only be used inside a function whose return type accepts the value 1
// (such as main).
//
// The body is wrapped in do { ... } while (0) so that `TEST(fn);` expands to a
// single statement and composes safely with unbraced if/else at the call site.
#define TEST(FN)                         \
  do {                                   \
    if (FN()) {                          \
      printf(" Test " #FN " PASS\n ");   \
    } else {                             \
      printf(" Test " #FN " FAIL\n ");   \
      return 1;                          \
    }                                    \
  } while (0)
26
+
27
+
28
// Issues the PTX instruction `movmatrix.sync.aligned.m8n8.trans.b16`, which
// transposes an 8x8 matrix of 16-bit elements that is distributed across the
// 32-bit registers of a warp.
//
// dst: receives this thread's 32-bit fragment (two packed 16-bit elements) of
//      the transposed matrix. It is written only; its previous value is not
//      read.
// src: this thread's 32-bit fragment of the source matrix.
//
// bf16_2 is used purely as a 32-bit container here; the bits are moved, not
// interpreted. The instruction carries the `.sync.aligned` qualifiers, so it
// must be reached by all threads of the warp together (see the PTX ISA
// description of movmatrix for the architecture requirements).
__device__ inline void movmatrix(bf16_2 &dst, const bf16_2 &src) {
  // Read the source bits without casting away const.
  const uint32_t src_bits = *reinterpret_cast<const uint32_t *>(&src);
  uint32_t dst_bits;

  // "=r" (write-only) instead of "+r": the instruction never reads the
  // destination, so the caller does not need to initialise dst.
  asm volatile(" movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n "
               : "=r"(dst_bits)
               : "r"(src_bits));

  *reinterpret_cast<uint32_t *>(&dst) = dst_bits;
}
33
+
34
// Issues the PTX instruction `ldmatrix.sync.aligned.m8n8.x1.shared.b16`,
// loading one 8x8 matrix of 16-bit elements from shared memory into a single
// 32-bit register per thread.
//
// addr: generic pointer into shared memory giving this thread's row address;
//       it is converted to a shared-state-space address before use.
// r:    destination; r[0] receives this thread's 32-bit fragment.
//
// The instruction carries the `.sync.aligned` qualifiers, so it must be
// reached by all threads of the warp together.
__device__ void ldmatrix(void *addr, volatile int *r) {
  // __cvta_generic_to_shared returns a wider integer; the "r" constraint needs
  // a 32-bit value, so narrow explicitly.
  const unsigned int smem_addr =
      static_cast<unsigned int>(__cvta_generic_to_shared(addr));

  int fragment;
  // The "memory" clobber tells the compiler that this statement reads memory,
  // so earlier shared-memory stores are not reordered past it.
  asm volatile(" ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n "
               : "=r"(fragment)
               : "r"(smem_addr)
               : "memory");

  r[0] = fragment;
}
41
+
42
// Test kernel: stages `input` in shared memory, loads an 8x8 matrix of 16-bit
// elements into registers with ldmatrix, transposes it with movmatrix, and
// writes each lane's two resulting elements to output[2 * lane] and
// output[2 * lane + 1].
//
// input:          device pointer to at least TOTAL_ELEMENTS halves.
// output:         device pointer to at least TOTAL_ELEMENTS halves.
// TOTAL_ELEMENTS: number of valid elements in input/output. At most
//                 NO_HALVES_PER_BLOCK elements are staged; tile elements
//                 beyond TOTAL_ELEMENTS are zero-filled.
//
// Launch layout: ldmatrix/movmatrix are `.sync.aligned`, so the block must
// consist of whole warps (the host test launches <<<1, 32>>>). Only block-local
// indices are used, so every block performs the same work.
// Shared memory: NO_HALVES_PER_BLOCK halves, statically allocated.
__global__ void test_movmatrix_fragment(half *input, half *output,
                                        const int TOTAL_ELEMENTS) {
  constexpr int kTileElems = 8 * 8;  // one m8n8 tile
  static_assert(NO_HALVES_PER_BLOCK >= kTileElems,
                "shared buffer must hold at least one 8x8 tile");

  // ldmatrix requires each row address to be 16-byte aligned; a plain half
  // array only guarantees 2-byte alignment, so request it explicitly.
  __shared__ __align__(16) half shared_data[NO_HALVES_PER_BLOCK];

  const int lane_id = threadIdx.x % 32;

  // Stage the input in shared memory, never writing past the shared buffer.
  const int staged = max(min(TOTAL_ELEMENTS, NO_HALVES_PER_BLOCK), 0);
  for (int i = threadIdx.x; i < staged; i += blockDim.x) {
    shared_data[i] = input[i];
  }
  // Zero-fill whatever part of the tile the input did not cover so ldmatrix
  // never reads uninitialised shared memory.
  for (int i = staged + threadIdx.x; i < kTileElems; i += blockDim.x) {
    shared_data[i] = __float2half(0.0f);
  }

  __syncthreads();

  // Lanes 0-7 point at rows 0-7 of the tile (8 halves = 16 bytes per row);
  // all other lanes pass the tile base. Per the PTX description of
  // ldmatrix .x1, only threads 0-7 supply the row addresses.
  const int row_offset = (lane_id < 8) ? 8 * lane_id : 0;
  void *addr = shared_data + row_offset;

  // Plain (non-volatile) registers, initialised so no indeterminate value is
  // ever read; the asm statements themselves are volatile.
  int fragment = 0;
  int fragment_t = 0;

  // Load matrix fragment from shared memory to register.
  ldmatrix(addr, &fragment);

  // Transpose matrix fragment in register.
  movmatrix(*reinterpret_cast<bf16_2 *>(&fragment_t),
            *reinterpret_cast<const bf16_2 *>(&fragment));

  // Each lane owns two consecutive 16-bit elements of the transposed matrix.
  const int d_ind = 2 * lane_id;
  if (d_ind + 1 < TOTAL_ELEMENTS) {
    const half *pair = reinterpret_cast<const half *>(&fragment_t);
    output[d_ind] = pair[0];
    output[d_ind + 1] = pair[1];
  }
}
78
+
79
// Host-side test for the ldmatrix + movmatrix (m8n8, b16) path.
//
// Fills a Shape_M x Shape_N row-major matrix with 0, 1, 2, ..., runs
// test_movmatrix_fragment on it with a single 32-thread block, and compares
// the device result element-by-element against the transposed matrix
// computed on the host.
//
// Returns true when every element matches; returns false on any mismatch or
// if any CUDA call fails (the failure is reported on stderr). All host and
// device buffers are released on every path.
bool ldmatrix_movmatrix_m8n8_b16() {
  const int TOTAL_ELEMENTS = Shape_M * Shape_N;
  const size_t bytes = TOTAL_ELEMENTS * sizeof(half);

  bool res = true;
  // Reports a failed CUDA call, marks the test as failed, and tells the
  // caller whether it is safe to continue.
  auto check = [&res](cudaError_t err, const char *what) {
    if (err != cudaSuccess) {
      fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
      res = false;
      return false;
    }
    return true;
  };

  // Allocate host memory for matrices
  half *h_input = new half[TOTAL_ELEMENTS];
  half *h_output = new half[TOTAL_ELEMENTS];
  half *exp_output = new half[TOTAL_ELEMENTS];

  // Initialize input matrix: element i holds the value i (row-major).
  for (int i = 0; i < TOTAL_ELEMENTS; i++) {
    h_input[i] = __float2half(static_cast<float>(i));
  }

  // Expected result is the transpose: walking the output column by column
  // yields the input's row-major sequence 0, 1, 2, ...
  int val = 0;
  for (int c = 0; c < Shape_N; c++) {
    for (int r = 0; r < Shape_M; r++) {
      exp_output[r * Shape_N + c] = __float2half(static_cast<float>(val++));
    }
  }

  // Device buffers start as nullptr so the cleanup below is always valid.
  half *d_input = nullptr;
  half *d_output = nullptr;

  const bool staged =
      check(cudaMalloc(&d_input, bytes), "cudaMalloc(d_input)") &&
      check(cudaMalloc(&d_output, bytes), "cudaMalloc(d_output)") &&
      check(cudaMemset(d_output, 0, bytes), "cudaMemset(d_output)") &&
      check(cudaMemcpy(d_input, h_input, bytes, cudaMemcpyHostToDevice),
            "cudaMemcpy(host to device)");

  if (staged) {
    test_movmatrix_fragment<<<1, 32>>>(d_input, d_output, TOTAL_ELEMENTS);

    // cudaGetLastError catches launch-configuration errors; the synchronize
    // surfaces faults raised while the kernel was running.
    const bool ran =
        check(cudaGetLastError(), "kernel launch") &&
        check(cudaDeviceSynchronize(), "cudaDeviceSynchronize") &&
        check(cudaMemcpy(h_output, d_output, bytes, cudaMemcpyDeviceToHost),
              "cudaMemcpy(device to host)");

    if (ran) {
      // Compare device output against the expected transpose.
      for (int index = 0; index < TOTAL_ELEMENTS; index++) {
        float out = __half2float(h_output[index]);
        float exp_out = __half2float(exp_output[index]);
        if (out != exp_out) {
          std::cout << " Mismatch at index " << index << " : expected "
                    << exp_out << " , got " << out << std::endl;
          res = false;
        }
      }
    }
  }

  delete[] h_input;
  delete[] h_output;
  delete[] exp_output;  // previously leaked
  cudaFree(d_input);
  cudaFree(d_output);

  return res;
}
144
+
145
// Program entry point. Runs the ldmatrix/movmatrix test through the TEST
// macro, which returns 1 from main when the test fails; otherwise the process
// exits with status 0.
int main() {
  TEST(ldmatrix_movmatrix_m8n8_b16);
  return 0;
}
0 commit comments