
Commit de1da63

cthi authored and facebook-github-bot committed
Save some binary size (#4900)
Summary:
Pull Request resolved: #4900

X-link: facebookresearch/FBGEMM#1927

We need to urgently trim down the fbgemm binary size, as new kernels being added are running into relocation issues. Long/mid term, we will need to split up gemm into per-op buck targets and update targets to only pull in what is required, as it's becoming unwieldy to include all gemm targets (quantize_ops), and this will likely continue to cause problems in fbcode usage, especially since many targets (e.g. torchAO, Sigrid predictor (MRS + Ads)) are pulling in fbgemm now. For now, let's do some housekeeping to trim down the lib size.

[FP8 Batched GEMM](https://www.internalfb.com/code/search?q=repo%3Afbcode%20torch.ops.fbgemm.f8f8bf16_rowwise_batched&leading_context=5&trailing_context=5):
- Remove `fast_accum=False`, as no one uses it.
- Only support fp32 bias; remove bf16 bias. Bias itself is unused except for the unit test.
- Significantly reduce FP8 e5m2 to only a single kernel instance. It's highly unlikely this is used, but it is hard for us to validate that confidently right now.

FP8 Int4 mixed-precision GEMM (I believe this kernel was purely exploratory and should be unused):
- Remove FP8 e5m2 completely.

Reviewed By: jiawenliu64, q10

Differential Revision: D82842915

fbshipit-source-id: a483080d319aae5b7f24492db4a1403ebf0fab85
1 parent ac8ac3e commit de1da63

14 files changed: 172 additions, 211 deletions
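Before the per-file diffs, a hedged Python sketch of the call pattern that remains supported after this change: FP8 e4m3 inputs, FP32 row-wise scales, an optional FP32 bias, and `use_fast_accum=True`. The tensor shapes, scale layouts, registering import, and positional argument order (mirroring the C++ signature in the diff below) are illustrative assumptions, not an authoritative usage reference.

```python
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401 -- assumed to register torch.ops.fbgemm ops

# Hypothetical sizes: B batches of an (M x K) @ (N x K)^T GEMM.
B, M, N, K = 2, 128, 256, 512

# Row-wise quantized FP8 e4m3 operands; scale and bias shapes are assumptions.
XQ = torch.randn(B, M, K, device="cuda").to(torch.float8_e4m3fn)
WQ = torch.randn(B, N, K, device="cuda").to(torch.float8_e4m3fn)
x_scale = torch.rand(B, M, device="cuda", dtype=torch.float32)
w_scale = torch.rand(B, N, device="cuda", dtype=torch.float32)
bias = torch.rand(B, N, device="cuda", dtype=torch.float32)  # FP32 only after this change

# Positional order follows the C++ signature shown in the diff:
# (XQ, WQ, x_scale, w_scale, bias, use_fast_accum, output).
out = torch.ops.fbgemm.f8f8bf16_rowwise_batched(XQ, WQ, x_scale, w_scale, bias, True)
print(out.shape, out.dtype)  # expected: torch.Size([2, 128, 256]) torch.bfloat16
```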

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched.cu

Lines changed: 19 additions & 10 deletions
@@ -29,7 +29,6 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
     at::Tensor WQ, // FP8
     at::Tensor x_scale, // FP32
     at::Tensor w_scale, // FP32
-    bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   const int arch = getDeviceArch();
@@ -41,6 +40,14 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
     M = XQ.size(1);
     N = WQ.size(1);
 
+  const bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2;
+  if (use_e5m2) {
+    TORCH_CHECK(
+        arch == 9, "f8f8bf16_rowwise_batched only supports FP8 e5m2 on SM90");
+    return f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f_e5m2(
+        XQ, WQ, x_scale, w_scale, bias, output);
+  }
+
   if (arch == 10) {
     if ((M * N <= 4096 * 4096) || (N % 256 > 0 && M % 256 == 0) ||
         (M % 256 > 0 && N % 256 > 0) || M >= 1024 && N >= 1024) {
@@ -49,21 +56,21 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
               cute::size(
                   cute::Shape<cute::Int<2>, cute::Int<1>, cute::Int<1>>{})) {
         return f8f8bf16_rowwise_batched_64_128_128_2_1_1_10_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+            XQ, WQ, x_scale, w_scale, bias, output);
       } else {
         return f8f8bf16_rowwise_batched_128_128_128_2_1_1_10_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+            XQ, WQ, x_scale, w_scale, bias, output);
       }
     } else {
       if ((ceildiv(M, 64 * 2) * ceildiv(N, 128 * 1)) <=
           kNumSMsForGB200 /
               cute::size(
                   cute::Shape<cute::Int<1>, cute::Int<2>, cute::Int<1>>{})) {
         return f8f8bf16_rowwise_batched_64_128_128_1_2_1_10_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+            XQ, WQ, x_scale, w_scale, bias, output);
       } else {
         return f8f8bf16_rowwise_batched_128_128_128_1_2_1_10_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+            XQ, WQ, x_scale, w_scale, bias, output);
       }
     }
   } else {
@@ -74,21 +81,21 @@ at::Tensor dispatch_fp8_rowwise_batched_kernel(
               cute::size(
                   cute::Shape<cute::Int<2>, cute::Int<1>, cute::Int<1>>{})) {
         return f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+            XQ, WQ, x_scale, w_scale, bias, output);
       } else {
         return f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+            XQ, WQ, x_scale, w_scale, bias, output);
       }
     } else {
       if ((ceildiv(M, 64 * 2) * ceildiv(N, 128 * 1)) <=
           kNumSMsForGB200 /
               cute::size(
                   cute::Shape<cute::Int<1>, cute::Int<2>, cute::Int<1>>{})) {
         return f8f8bf16_rowwise_batched_64_128_128_1_2_1_9_f(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+            XQ, WQ, x_scale, w_scale, bias, output);
       } else {
         return f8f8bf16_rowwise_batched_128_128_128_1_2_1_9_t(
-            XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+            XQ, WQ, x_scale, w_scale, bias, output);
       }
     }
   }
@@ -102,8 +109,10 @@ at::Tensor f8f8bf16_rowwise_batched(
     std::optional<at::Tensor> bias = std::nullopt,
     bool use_fast_accum = true,
     std::optional<at::Tensor> output = std::nullopt) {
+  TORCH_CHECK(
+      use_fast_accum, "f8f8bf16_rowwise_batched only supports fast_accum=True");
   return dispatch_fp8_rowwise_batched_kernel(
-      XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+      XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 #else
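Two behavioral consequences of the dispatch change above, sketched under the same illustrative assumptions as the earlier snippet (shapes, scale layouts, registering import): `use_fast_accum=False` now trips the new TORCH_CHECK, and FP8 e5m2 activations route to the single remaining kernel instance, which is restricted to SM90. Which operands may be e5m2 is an assumption here; the dispatcher keys only on `XQ.dtype()`.

```python
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401 -- assumed to register torch.ops.fbgemm ops

B, M, N, K = 2, 128, 256, 512
x_scale = torch.rand(B, M, device="cuda", dtype=torch.float32)
w_scale = torch.rand(B, N, device="cuda", dtype=torch.float32)

# 1) fast_accum=False is gone: the new TORCH_CHECK surfaces as a RuntimeError.
XQ = torch.randn(B, M, K, device="cuda").to(torch.float8_e4m3fn)
WQ = torch.randn(B, N, K, device="cuda").to(torch.float8_e4m3fn)
try:
    torch.ops.fbgemm.f8f8bf16_rowwise_batched(XQ, WQ, x_scale, w_scale, None, False)
except RuntimeError as err:
    print(err)  # "f8f8bf16_rowwise_batched only supports fast_accum=True"

# 2) e5m2 activations (XQ) now hit a single kernel instance, and only on SM90;
#    on other architectures the dispatcher's TORCH_CHECK raises instead.
XQ_e5m2 = torch.randn(B, M, K, device="cuda").to(torch.float8_e5m2)
out = torch.ops.fbgemm.f8f8bf16_rowwise_batched(XQ_e5m2, WQ, x_scale, w_scale, None, True)
print(out.shape)  # expected: torch.Size([2, 128, 256]) on an SM90 device
```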

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_1_2_1_10_t.cu

Lines changed: 10 additions & 3 deletions
@@ -15,12 +15,19 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_1_2_1_10_t(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<128, 128, 128, 1, 2, 1, 10, true>(
-      XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+  return f8f8bf16_rowwise_batched_wrapper<
+      128,
+      128,
+      128,
+      1,
+      2,
+      1,
+      10,
+      true,
+      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_1_2_1_9_t.cu

Lines changed: 10 additions & 3 deletions
@@ -15,12 +15,19 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_1_2_1_9_t(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<128, 128, 128, 1, 2, 1, 9, true>(
-      XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+  return f8f8bf16_rowwise_batched_wrapper<
+      128,
+      128,
+      128,
+      1,
+      2,
+      1,
+      9,
+      true,
+      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_2_1_1_10_t.cu

Lines changed: 10 additions & 3 deletions
@@ -15,12 +15,19 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_2_1_1_10_t(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<128, 128, 128, 2, 1, 1, 10, true>(
-      XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+  return f8f8bf16_rowwise_batched_wrapper<
+      128,
+      128,
+      128,
+      2,
+      1,
+      1,
+      10,
+      true,
+      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t.cu

Lines changed: 10 additions & 3 deletions
@@ -15,12 +15,19 @@ at::Tensor f8f8bf16_rowwise_batched_128_128_128_2_1_1_9_t(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<128, 128, 128, 2, 1, 1, 9, true>(
-      XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+  return f8f8bf16_rowwise_batched_wrapper<
+      128,
+      128,
+      128,
+      2,
+      1,
+      1,
+      9,
+      true,
+      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_1_2_1_10_f.cu

Lines changed: 10 additions & 3 deletions
@@ -15,12 +15,19 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_1_2_1_10_f(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<64, 128, 128, 1, 2, 1, 10, false>(
-      XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+  return f8f8bf16_rowwise_batched_wrapper<
+      64,
+      128,
+      128,
+      1,
+      2,
+      1,
+      10,
+      false,
+      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_1_2_1_9_f.cu

Lines changed: 10 additions & 3 deletions
@@ -15,12 +15,19 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_1_2_1_9_f(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<64, 128, 128, 1, 2, 1, 9, false>(
-      XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+  return f8f8bf16_rowwise_batched_wrapper<
+      64,
+      128,
+      128,
+      1,
+      2,
+      1,
+      9,
+      false,
+      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_2_1_1_10_f.cu

Lines changed: 10 additions & 3 deletions
@@ -15,12 +15,19 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_2_1_1_10_f(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<64, 128, 128, 2, 1, 1, 10, false>(
-      XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+  return f8f8bf16_rowwise_batched_wrapper<
+      64,
+      128,
+      128,
+      2,
+      1,
+      1,
+      10,
+      false,
+      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_batched/f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f.cu

Lines changed: 10 additions & 3 deletions
@@ -15,12 +15,19 @@ at::Tensor f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    bool use_fast_accum = true,
     std::optional<at::Tensor> bias = std::nullopt,
     std::optional<at::Tensor> output = std::nullopt) {
   // Dispatch this kernel to the correct underlying implementation.
-  return f8f8bf16_rowwise_batched_wrapper<64, 128, 128, 2, 1, 1, 9, false>(
-      XQ, WQ, x_scale, w_scale, use_fast_accum, bias, output);
+  return f8f8bf16_rowwise_batched_wrapper<
+      64,
+      128,
+      128,
+      2,
+      1,
+      1,
+      9,
+      false,
+      cutlass::float_e4m3_t>(XQ, WQ, x_scale, w_scale, bias, output);
 }
 
 } // namespace fbgemm_gpu
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "f8f8bf16_rowwise_batched_common.cuh"
10+
11+
namespace fbgemm_gpu {
12+
13+
at::Tensor f8f8bf16_rowwise_batched_64_128_128_2_1_1_9_f_e5m2(
14+
at::Tensor XQ,
15+
at::Tensor WQ,
16+
at::Tensor x_scale,
17+
at::Tensor w_scale,
18+
std::optional<at::Tensor> bias = std::nullopt,
19+
std::optional<at::Tensor> output = std::nullopt) {
20+
// Dispatch this kernel to the correct underlying implementation.
21+
return f8f8bf16_rowwise_batched_wrapper<
22+
64,
23+
128,
24+
128,
25+
2,
26+
1,
27+
1,
28+
9,
29+
false,
30+
cutlass::float_e5m2_t>(XQ, WQ, x_scale, w_scale, bias, output);
31+
}
32+
33+
} // namespace fbgemm_gpu
