66
77在介绍二维 Thread Tile 之前,我们先来回顾一下一维 Thread Tile 的优化方法。在初级系列中,我们使用了一维线程块来优化矩阵乘法的性能,我们将矩阵乘法的计算任务分配给了一维线程块,每个线程块负责计算一个小的矩阵块。这样做的好处是可以充分利用共享内存,减少全局内存的访问次数,从而提高矩阵乘法的性能。
88
9- 我们在每个线程中计算了一维的矩阵块。想要继续优化这个 Kernel 的性能,我们可以使用二维线程块来计算二维的矩阵块。
9+ 还记得一维 Thread Tile中的例子吗?如果输入的 A 和 B 都是 8x8 的矩阵:
10+
11+ 1 . 如果我们一次读取 1 行 A 和 1 列 B,当每一个线程只计算一个结果的时候,我们需要从 A 中读取 8 个数据,从 B 中读取 8 个数据,从 C 中读取 1 个数据,然后写一次 C。这样的话,每个线程需要读取 16 个数据,写一次数据。一共需要 64 个线程,共 64x17 = 1088 次 IO。
12+ 2 . 如果我们一次读取 4 行 A 和 1 列 B,那么每一个线程计算 4 个结果,次数需要从 A 中读取 4x8 个数据,从 B 中读取 8 个数据,从 C 中读取 4 个数据,然后写 4 次 C。一共需要 64/4=16 个线程,共 16x48 = 768 次 IO。
13+ 3 . 如果我们一次读取 4 行 A 和 4 列 B,那么每一个线程计算 16 个结果,次数需要从 A 中读取 4x8 个数据,从 B 中读取 4x8 个数据,从 C 中读取 16 个数据,然后写 16 次 C。一共需要 64/16=4 个线程,共 4x96 = 384 次 IO。
14+
15+ 上述的 ` 2 ` 就是一维 Thread Tile 优化,上述的 ` 3 ` 就是 二维 Thread Tile 优化,计算结果不变的同时,减少 IO 次数,提升算法效率。所以想要继续优化这个 Kernel 的性能,我们可以使用二维线程块来计算二维的矩阵块。
1016
1117## 2. 二维 Thread Tile
1218
1319### 2.1 优化思路
1420
1521本文的主要优化思路就是让每个线程计算一个 8\* 8 的网格。下面我们来看一下这个 Kernel 的主题思路图:
1622
17- ![ picture 1] ( images/9047246849f79b5117961c15e1a3a340a44ab003566140ecc00600058c70a9e2.png )
23+ ![ picture 1] ( images/9047246849f79b5117961c15e1a3a340a44ab003566140ecc00600058c70a9e2.png )
1824
1925首先在内核的第一阶段, 所有线程协同工作, 从全局内存中加载矩阵 A 和矩阵 B 到共享内存中。
2026
@@ -74,9 +80,9 @@ float thread_results[TM * TN] = {0.0};
7480float reg_m[ TM] = {0.0};
7581float reg_n[ TN] = {0.0};
7682
77- A += c_row * BM * K;
78- B += c_col * BN;
79- C += c_row * BM * N + c_col * BN;
83+ A += c_row * BM * K;
84+ B += c_col * BN;
85+ C += c_row * BM * N + c_col * BN;
8086
8187// 外层循环
8288for (uint bkIdx = 0; bkIdx < K; bkIdx += BK)
@@ -107,7 +113,7 @@ B += BK * N;
107113
108114下图可以更好的帮助我们理解上面的代码:
109115
110- ![ picture 2] ( images/f507ad687528e8bbb14a85c1fa3016cce50be55b5670ebc425c549cc5c5bd5a6.png )
116+ ![ picture 2] ( images/f507ad687528e8bbb14a85c1fa3016cce50be55b5670ebc425c549cc5c5bd5a6.png )
111117
112118图中画出了矩阵 A 加载共享内存的过程。在每一步中,每个线程负责加载一个元素到共享内存中。这个元素的索引是 ` inner_row_A ` 和 ` inner_col_A ` 。for 循环中的 ` load_offset ` 递增的步长是 ` stride_A ` 。在图中就是向下移动了 ` stride_A ` 个元素。
113119
@@ -130,7 +136,7 @@ for (uint dot_idx = 0; dot_idx < BK; ++dot_idx)
130136 {
131137 for (uint reg_idx_n = 0; reg_idx_n < TN; ++reg_idx_n)
132138 {
133- thread_results[reg_idx_m * TN + reg_idx_n] +=
139+ thread_results[reg_idx_m * TN + reg_idx_n] +=
134140 reg_m[reg_idx_m] * reg_n[reg_idx_n];
135141 }
136142 }
@@ -158,7 +164,7 @@ for (uint reg_idx_m = 0; reg_idx_m < TM; ++reg_idx_m)
158164{
159165 for (uint reg_idx_n = 0; reg_idx_n < TN; ++reg_idx_n)
160166 {
161- C[(thread_row * TM + reg_idx_m) * N + thread_col * TN + reg_idx_n] =
167+ C[(thread_row * TM + reg_idx_m) * N + thread_col * TN + reg_idx_n] =
162168 thread_results[reg_idx_m * TN + reg_idx_n];
163169 }
164170}
@@ -174,20 +180,20 @@ nvcc -o sgemm_tiled2d sgemm_tiled2d.cu
174180## 3. 性能测试
175181
176182我们将上该内核的性能和之前的内核进行比较,我们分别计算 256x256、512x512、1024x1024、2048x2048 (Matrix 1、Matrix 2、Matrix 3、Matrix 4、Matrix 5)的矩阵乘法的性能 (us)。在 1080Ti 上运行,结果如下:
177-
178183
179- | Algorithm | Matrix 1 | Matrix 2 | Matrix 3 | Matrix 4 |
180- | --------- | -------- | -------- | -------- | -------- |
181- | Naive | 95.5152 | 724.396 | 28424 | 228681 |
182- | 共享内存缓存块 | 40.5293 | 198.374 | 8245.68 | 59048.4 |
183- | 一维 Thread Tile | 35.215 | 174.731 | 894.779 | 5880.03 |
184- | 二维 Thread Tile | 34.708 | 92.946 | 557.829 | 3509.920 |
184+
185+ | Algorithm | Matrix 1 | Matrix 2 | Matrix 3 | Matrix 4 |
186+ | ---------------- | -------- | -------- | -------- | -------- |
187+ | Naive | 95.5152 | 724.396 | 28424 | 228681 |
188+ | 共享内存缓存块 | 40.5293 | 198.374 | 8245.68 | 59048.4 |
189+ | 一维 Thread Tile | 35.215 | 174.731 | 894.779 | 5880.03 |
190+ | 二维 Thread Tile | 34.708 | 92.946 | 557.829 | 3509.920 |
185191
186192## 4. 总结
187193
188194本文我们介绍了二维 Thread Tile 并行优化方法。我们将矩阵乘法的计算任务分配给了二维线程块,每个线程块负责计算一个小的矩阵块。这样做的好处是可以充分利用共享内存,减少全局内存的访问次数,从而提高矩阵乘法的性能。
189195
190- ## Reference
196+ ## Reference
191197
1921981 . https://siboehm.com/articles/22/CUDA-MMM
1931992 . https://space.keter.top/docs/high_performance/GEMM%E4%BC%98%E5%8C%96%E4%B8%93%E9%A2%98/%E4%BA%8C%E7%BB%B4Thread%20Tile%E5%B9%B6%E8%A1%8C%E4%BC%98%E5%8C%96
0 commit comments