77#include < cuda.h>
88#include < cstdlib>
99
10- #define CEIL_DIV (M, N ) (((M) + (N)- 1 ) / (N))
10+ #define CEIL_DIV (M, N ) (((M) + (N) - 1 ) / (N))
1111#define OFFSET (row, col, ld ) ((row) * (ld) + (col))
1212#define FETCH_FLOAT4 (pointer ) (reinterpret_cast <float4 *>(&(pointer))[0 ])
1313
14+ void free_resource (float *ptr, int is_cuda = 1 )
15+ {
16+ if (nullptr != ptr)
17+ {
18+ if (is_cuda)
19+ {
20+ cudaFree (ptr);
21+ }
22+ else
23+ {
24+ delete[] ptr;
25+ }
26+ }
27+ ptr = nullptr ;
28+ }
29+
1430void sgemm_naive_cpu (float *A, float *B, float *C, int M, int N, int K)
1531{
1632 for (int x = 0 ; x < M; x++)
@@ -36,8 +52,8 @@ __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) sgemm_vectorize_kern
3652 const uint c_row = blockIdx .y ;
3753 const uint c_col = blockIdx .x ;
3854
39- const int block_row_thread = BN / TN ;
40- const int block_col_thread = BM / TM ;
55+ const int block_row_thread = BM / TM ;
56+ const int block_col_thread = BN / TN ;
4157 // 一个线程负责计算 block 中 TM*TN 个元素
4258 const int thread_num = block_row_thread * block_col_thread;
4359
@@ -73,8 +89,8 @@ __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) sgemm_vectorize_kern
7389 C += c_row * BM * N + c_col * BN;
7490
7591 float thread_results[TM * TN] = {0.0 };
76- // 每个线程搬运ldg_a_num轮,寄存器缓存ldg_a_num个float4元素,用于转置As矩阵
77- float ldg_reg_a[4 * ldg_a_num ] = {0 .};
92+ // 转置时,只用大小为 4 的数组就可以
93+ float ldg_reg_a[4 ] = {0 .};
7894 float reg_a[TM] = {0.0 }; // 缓存 smem_a
7995 float reg_b[TN] = {0.0 }; // 缓存 smem_b
8096
@@ -83,13 +99,12 @@ __global__ void __launch_bounds__((BM * BN) / (TM * TN), 1) sgemm_vectorize_kern
8399 {
84100 for (int i = 0 ; i < BM; i += stride_a)
85101 {
86- int ldg_index = i / stride_a * 4 ;
87- FETCH_FLOAT4 (ldg_reg_a[ldg_index]) = FETCH_FLOAT4 (A[OFFSET (i + inner_row_a, inner_col_a, K)]);
102+ FETCH_FLOAT4 (ldg_reg_a[0 ]) = FETCH_FLOAT4 (A[OFFSET (i + inner_row_a, inner_col_a, K)]);
88103 // smem_a 转置存,其中 ldg_reg_a 做中间缓存,目的是读取时可以按FLOAT4读取
89- smem_a[OFFSET (inner_col_a, i + inner_row_a, BM)] = ldg_reg_a[ldg_index ];
90- smem_a[OFFSET (inner_col_a + 1 , i + inner_row_a, BM)] = ldg_reg_a[ldg_index + 1 ];
91- smem_a[OFFSET (inner_col_a + 2 , i + inner_row_a, BM)] = ldg_reg_a[ldg_index + 2 ];
92- smem_a[OFFSET (inner_col_a + 3 , i + inner_row_a, BM)] = ldg_reg_a[ldg_index + 3 ];
104+ smem_a[OFFSET (inner_col_a, i + inner_row_a, BM)] = ldg_reg_a[0 ];
105+ smem_a[OFFSET (inner_col_a + 1 , i + inner_row_a, BM)] = ldg_reg_a[1 ];
106+ smem_a[OFFSET (inner_col_a + 2 , i + inner_row_a, BM)] = ldg_reg_a[2 ];
107+ smem_a[OFFSET (inner_col_a + 3 , i + inner_row_a, BM)] = ldg_reg_a[3 ];
93108 }
94109
95110 for (int i = 0 ; i < BK; i += stride_b)
@@ -166,7 +181,7 @@ int main(int argc, char *argv[])
166181
167182 // Allocate memory for matrices
168183 float *A, *B, *C, *C_ref;
169- float *d_A, *d_B, *d_C, *d_C_ref ;
184+ float *d_A, *d_B, *d_C;
170185
171186 A = new float [m * k];
172187 B = new float [k * n];
@@ -183,17 +198,10 @@ int main(int argc, char *argv[])
183198 cudaMalloc ((void **)&d_B, k * n * sizeof (float ));
184199 cudaMalloc ((void **)&d_C, m * n * sizeof (float ));
185200
186- // Copy data to device
187- cudaMalloc ((void **)&d_A, m * k * sizeof (float ));
188- cudaMalloc ((void **)&d_B, k * n * sizeof (float ));
189- cudaMalloc ((void **)&d_C, m * n * sizeof (float ));
190- cudaMalloc ((void **)&d_C_ref, m * n * sizeof (float ));
191-
192201 // Copy matrices to device
193202 cudaMemcpy (d_A, A, m * k * sizeof (float ), cudaMemcpyHostToDevice);
194203 cudaMemcpy (d_B, B, k * n * sizeof (float ), cudaMemcpyHostToDevice);
195204 cudaMemcpy (d_C, C, m * n * sizeof (float ), cudaMemcpyHostToDevice);
196- cudaMemcpy (d_C_ref, C_ref, m * n * sizeof (float ), cudaMemcpyHostToDevice);
197205
198206 run_sgemm_vectorize (d_A, d_B, d_C, m, n, k);
199207
@@ -230,5 +238,15 @@ int main(int argc, char *argv[])
230238 cudaEventElapsedTime (&elapsed_time, start, stop);
231239 float avg_run_time = elapsed_time * 1000 / 100 ;
232240 printf (" Average run time: %f us\n " , avg_run_time);
241+
242+ free_resource (A, 0 );
243+ free_resource (B, 0 );
244+ free_resource (C, 0 );
245+ free_resource (C_ref, 0 );
246+
247+ free_resource (d_A, 1 );
248+ free_resource (d_B, 1 );
249+ free_resource (d_C, 1 );
250+
233251 return 0 ;
234252}
0 commit comments