Skip to content

Commit c341f82

Browse files
jiawenliu64 authored and facebook-github-bot committed
Improve CUTLASS GMM for llama4x pretraining fprop (#4855)
Summary: Pull Request resolved: #4855. X-link: facebookresearch/FBGEMM#1868. Optimize BF16 CUTLASS GMM to bring a 1.1x - 1.3x speedup for llama4x pretraining fprop shapes. More results can be found in this [spreadsheet](https://docs.google.com/spreadsheets/d/172Nm0F9K6XJenNFoNFqC5Sp1Ll2KhLtfOJpIfkuHDzc/edit?usp=sharing) Reviewed By: jwfromm Differential Revision: D81704026 fbshipit-source-id: 9919e05f8915c6c5db4a44d580a41ee2d997c00c
1 parent 6bd58c8 commit c341f82

8 files changed

+406
-0
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,76 @@ get_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) {
152152
}
153153
}
154154

155+
// Llama4.x pretraining
156+
if (N == 2560 && K == 5120) {
157+
if (total_M <= 256) {
158+
return bf16bf16bf16_grouped_128_64_128_2_2_1_9_f;
159+
} else if (total_M <= 512) {
160+
return bf16bf16bf16_grouped_128_128_128_2_1_1_9_f;
161+
} else if (total_M <= 1024) {
162+
return bf16bf16bf16_grouped_128_128_128_2_2_1_9_t;
163+
} else {
164+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_t;
165+
}
166+
} else if (N == 5120 && K == 5120) {
167+
if (total_M <= 256) {
168+
return bf16bf16bf16_grouped_128_128_128_2_1_1_9_f;
169+
} else if (total_M <= 1024) {
170+
return bf16bf16bf16_grouped_128_128_128_2_2_1_9_t;
171+
} else if (total_M <= 4096) {
172+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_t;
173+
} else {
174+
return bf16bf16bf16_grouped_128_128_128_4_4_1_9_t;
175+
}
176+
} else if (N == 3072 && K == 6144) {
177+
if (total_M <= 512) {
178+
return bf16bf16bf16_grouped_128_128_128_2_1_1_9_f;
179+
} else if (total_M <= 1024) {
180+
return bf16bf16bf16_grouped_128_128_128_2_2_1_9_t;
181+
} else if (total_M <= 2048) {
182+
return bf16bf16bf16_grouped_128_128_128_2_1_1_9_t;
183+
} else {
184+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_t;
185+
}
186+
} else if (N == 6144 && K == 6144) {
187+
if (total_M <= 512) {
188+
return bf16bf16bf16_grouped_128_128_128_4_1_1_9_f;
189+
} else if (total_M <= 1024) {
190+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_t;
191+
} else {
192+
return bf16bf16bf16_grouped_128_128_128_4_4_1_9_t;
193+
}
194+
195+
} else if (N == 5120 && K == 1280) {
196+
if (total_M <= 256) {
197+
return bf16bf16bf16_grouped_128_128_128_4_1_1_9_f;
198+
} else {
199+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_t;
200+
}
201+
} else if (N == 5120 && K == 2560) {
202+
if (total_M <= 256) {
203+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_f;
204+
} else if (total_M <= 1024) {
205+
return bf16bf16bf16_grouped_128_128_128_2_2_1_9_t;
206+
} else {
207+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_t;
208+
}
209+
} else if (N == 6144 && K == 1536) {
210+
if (total_M <= 4096) {
211+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_f;
212+
} else {
213+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_t;
214+
}
215+
} else if (N == 6144 && K == 3072) {
216+
if (total_M <= 256) {
217+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_f;
218+
} else if (total_M <= 4096) {
219+
return bf16bf16bf16_grouped_128_128_128_1_2_1_9_t;
220+
} else {
221+
return bf16bf16bf16_grouped_128_128_128_1_4_1_9_t;
222+
}
223+
}
224+
155225
// Fallback to legacy heuristic for now.
156226
if (total_M <= 16) {
157227
return bf16bf16bf16_grouped_128_16_128_1_1_1_9_f;
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "bf16bf16bf16_grouped_common.cuh"
10+
11+
namespace fbgemm_gpu {
12+
13+
at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
14+
at::Tensor X, // BF16
15+
at::Tensor W, // BF16
16+
at::Tensor output,
17+
std::optional<at::Tensor> zero_start_index_M,
18+
std::optional<at::Tensor> M_sizes) {
19+
return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 2, 1, false>(
20+
X, W, output, zero_start_index_M, M_sizes);
21+
}
22+
23+
at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_f(
24+
at::TensorList X, // BF16
25+
at::TensorList W, // BF16
26+
at::Tensor output,
27+
std::optional<at::Tensor> zero_start_index_M,
28+
std::optional<at::Tensor> M_sizes) {
29+
return bf16bf16bf16_grouped_impl<
30+
at::TensorList,
31+
128,
32+
128,
33+
128,
34+
1,
35+
2,
36+
1,
37+
false>(X, W, output, zero_start_index_M, M_sizes);
38+
}
39+
40+
} // namespace fbgemm_gpu
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

// Grouped BF16 GEMM instantiation: 128x128x128 tile, 1x2x1 cluster.
// The "_9" presumably denotes the SM90 dispatch path and the trailing "_t"
// maps to the final `true` template argument -- a kernel-schedule toggle;
// confirm semantics in bf16bf16bf16_grouped_common.cuh.
//
// Stacked-tensor variant: X and W hold all groups in single tensors.
at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 2, 1, true>(
      X, W, output, zero_start_index_M, M_sizes);
}

// Tensor-list variant: one tensor per group; same template configuration.
at::Tensor bf16bf16bf16_grouped_128_128_128_1_2_1_9_t(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList, 128, 128, 128, 1, 2, 1, true>(
      X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

// Grouped BF16 GEMM instantiation: 128x128x128 tile, 1x4x1 cluster.
// The "_9" presumably denotes the SM90 dispatch path and the trailing "_t"
// maps to the final `true` template argument -- a kernel-schedule toggle;
// confirm semantics in bf16bf16bf16_grouped_common.cuh.
//
// Stacked-tensor variant: X and W hold all groups in single tensors.
at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 1, 4, 1, true>(
      X, W, output, zero_start_index_M, M_sizes);
}

// Tensor-list variant: one tensor per group; same template configuration.
at::Tensor bf16bf16bf16_grouped_128_128_128_1_4_1_9_t(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList, 128, 128, 128, 1, 4, 1, true>(
      X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

// Grouped BF16 GEMM instantiation: 128x128x128 tile, 2x2x1 cluster.
// The "_9" presumably denotes the SM90 dispatch path and the trailing "_t"
// maps to the final `true` template argument -- a kernel-schedule toggle;
// confirm semantics in bf16bf16bf16_grouped_common.cuh.
//
// Stacked-tensor variant: X and W hold all groups in single tensors.
at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 2, 2, 1, true>(
      X, W, output, zero_start_index_M, M_sizes);
}

// Tensor-list variant: one tensor per group; same template configuration.
at::Tensor bf16bf16bf16_grouped_128_128_128_2_2_1_9_t(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList, 128, 128, 128, 2, 2, 1, true>(
      X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

// Grouped BF16 GEMM instantiation: 128x128x128 tile, 4x4x1 cluster.
// The "_9" presumably denotes the SM90 dispatch path and the trailing "_t"
// maps to the final `true` template argument -- a kernel-schedule toggle;
// confirm semantics in bf16bf16bf16_grouped_common.cuh.
//
// Stacked-tensor variant: X and W hold all groups in single tensors.
at::Tensor bf16bf16bf16_grouped_128_128_128_4_4_1_9_t(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 128, 128, 4, 4, 1, true>(
      X, W, output, zero_start_index_M, M_sizes);
}

// Tensor-list variant: one tensor per group; same template configuration.
at::Tensor bf16bf16bf16_grouped_128_128_128_4_4_1_9_t(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList, 128, 128, 128, 4, 4, 1, true>(
      X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "bf16bf16bf16_grouped_common.cuh"

namespace fbgemm_gpu {

// Grouped BF16 GEMM instantiation: 128x64x128 tile, 2x2x1 cluster.
// The "_9" presumably denotes the SM90 dispatch path and the trailing "_f"
// maps to the final `false` template argument -- a kernel-schedule toggle;
// confirm semantics in bf16bf16bf16_grouped_common.cuh.
//
// Stacked-tensor variant: X and W hold all groups in single tensors.
at::Tensor bf16bf16bf16_grouped_128_64_128_2_2_1_9_f(
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<at::Tensor, 128, 64, 128, 2, 2, 1, false>(
      X, W, output, zero_start_index_M, M_sizes);
}

// Tensor-list variant: one tensor per group; same template configuration.
at::Tensor bf16bf16bf16_grouped_128_64_128_2_2_1_9_f(
    at::TensorList X, // BF16
    at::TensorList W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> zero_start_index_M,
    std::optional<at::Tensor> M_sizes) {
  return bf16bf16bf16_grouped_impl<
      at::TensorList, 128, 64, 128, 2, 2, 1, false>(
      X, W, output, zero_start_index_M, M_sizes);
}

} // namespace fbgemm_gpu

0 commit comments

Comments
 (0)