Commit dc0ab6d

jiawenliu64 authored and facebook-github-bot committed
Enable CUTLASS grouped GEMM for llama4x pretraining grad on GB200 and H100 (#4856)
Summary:
Pull Request resolved: #4856
X-link: facebookresearch/FBGEMM#1869

Enable CUTLASS grouped GEMM for llama4x pretraining grad on GB200 and H100.

Next steps:
1. Currently only dgrad is enabled; a new kernel for wgrad will follow.
2. Further optimize perf on GB200.

Reviewed By: jwfromm

Differential Revision: D81997154

fbshipit-source-id: 69366ac46e9881ff21d50066295f109ef5d2af97
1 parent c341f82 commit dc0ab6d
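
For orientation, the grouped GEMM this commit dispatches to is equivalent to running one matmul per expert group over a flattened activation tensor and concatenating the results. The sketch below is a reference loop inferred from the op signature added in this diff (X: [total_M, K], W: [G, N, K], m_sizes: per-group row counts); names and the loop itself are illustrative, not the CUTLASS kernel.

import torch

def grouped_gemm_reference(x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor) -> torch.Tensor:
    # x: [total_M, K] bf16, w: [G, N, K] bf16, m_sizes: [G] int64 rows per group.
    outputs = []
    start = 0
    for g in range(w.shape[0]):
        end = start + int(m_sizes[g])
        # Each group is an independent [m_g, K] x [K, N] matmul.
        outputs.append(x[start:end] @ w[g].transpose(0, 1))
        start = end
    return torch.cat(outputs, dim=0)  # [total_M, N]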

File tree

53 files changed: +2,436 −0 lines


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 45 additions & 0 deletions
@@ -2081,6 +2081,51 @@ def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16GroupedGrad(QuantizeOpBase):
    """
    BF16 grouped matmul with grad inputs backed by cutlass
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        # Convert m_values into offsets into grouped tensor.
        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
        # Group weights as single tensor.
        w = torch.stack(w, dim=0).contiguous()
        # Prepare online dgrad during pretraining backward.
        w_perm = w.permute(0, 2, 1).contiguous()
        # w.contiguous() is very expensive so handling it inside the gmm kernel for free
        w = w_perm.permute(0, 2, 1)

        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, w, m_sizes

    def quantize(self, x, w, m_sizes):
        return x, w, m_sizes

    def compute(self, x, w, m_sizes):
        return torch.ops.fbgemm.bf16bf16bf16_grouped_grad(x, w, m_sizes)

    def quantize_and_compute(self, x, w, m_sizes):
        x, w, m_sizes = self.quantize(x, w, m_sizes)
        return self.compute(x, w, m_sizes)

    @property
    def name(self) -> str:
        return "bf16_grouped_grad"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True


@register_quantize_op
class BF16GroupedStacked(QuantizeOpBase):
    """
Lines changed: 337 additions & 0 deletions
@@ -0,0 +1,337 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "bf16bf16bf16_grouped_grad/bf16bf16bf16_grouped_grad_manifest.cuh"
#include "fbgemm_gpu/quantize/tuning_cache.hpp"
#include "fbgemm_gpu/quantize/utils.h"

namespace fbgemm_gpu {

#if CUDART_VERSION >= 12000

namespace {
TuningCache& getTuningCache() {
  static TuningCache cache("bf16bf16bf16_grouped_grad");
  return cache;
}
} // namespace

Kernel_bf16bf16bf16_grouped_grad
get_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) {
  // Use heuristics to pick best kernel implementation.
  if (arch == 10) {
    // Llama4 shapes
    if ((N == 5120 && K == 1024) || (N == 2048 && K == 5120)) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_256_32_128_2_1_1_10_f;
      } else if (total_M <= 512) {
        return bf16bf16bf16_grouped_grad_256_64_128_2_1_1_10_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_10_f;
      } else {
        return bf16bf16bf16_grouped_grad_256_256_128_2_1_1_10_f;
      }
    }

    // Fallback to legacy heuristic.
    if (total_M <= 64 || (total_M <= 256 and N <= 1024)) {
      if (K <= 4096) {
        return bf16bf16bf16_grouped_grad_256_32_128_2_1_1_10_f;
      } else {
        return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_10_f;
      }
    } else if (total_M <= 512) {
      if (N <= 1024) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_10_f;
      } else if (N <= 8192) {
        if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_256_32_128_2_1_1_10_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_10_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_10_f;
        }
      }
    } else if (total_M <= 1024) {
      if (N <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_10_f;
      } else if (N <= 8192) {
        if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_256_64_128_2_1_1_10_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_10_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_10_f;
        }
      }
    } else if (total_M <= 2048) {
      if (N <= 1024) {
        return bf16bf16bf16_grouped_grad_256_256_128_2_1_1_10_f;
      } else if (N <= 8192) {
        if (K <= 2048) {
          return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_10_f;
        } else if (K <= 4096) {
          return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_10_f;
        }
      }
    }
    return bf16bf16bf16_grouped_grad_256_256_128_2_1_1_10_f;
  } else {
    // Llama4 128E
    if (G == 128) {
      if (N == 5120 && K == 1024) {
        if (total_M <= 128) {
          return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f;
        } else if (total_M <= 256) {
          return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
        } else if (total_M <= 2048) {
          return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f;
        } else if (total_M <= 4096) {
          return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
        } else if (total_M <= 8192) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else if (total_M <= 16384) {
          return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t;
        } else {
          return bf16bf16bf16_grouped_grad_128_256_128_2_1_1_9_f;
        }
      }

      if (N == 2048 && K == 5120) {
        if (total_M <= 2048) {
          return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t;
        }
      }
    }

    // Llama4 64E
    if (G == 16) {
      if (N == 5120 && K == 1024) {
        if (total_M <= 32) {
          return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f;
        } else if (total_M <= 64) {
          return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
        } else if (total_M <= 256) {
          return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f;
        } else if (total_M <= 512) {
          return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
        } else if (total_M <= 1024) {
          return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_256_128_2_1_1_9_f;
        }
      }

      if (N == 2048 && K == 5120) {
        if (total_M <= 16) {
          return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f;
        } else if (total_M <= 64) {
          return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
        } else if (total_M <= 256) {
          return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f;
        } else if (total_M <= 512) {
          return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
        } else if (total_M <= 1024) {
          return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
        } else {
          return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t;
        }
      }
    }

    // Llama4.x pretraining
    if (N == 1280 && K == 5120) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_2_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 2560 && K == 5120) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 1536 && K == 6144) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 3072 && K == 6144) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_9_f;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 5120 && K == 2560) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_2_1_9_t;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 5120 && K == 5120) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_2_2_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 6144 && K == 3072) {
      if (total_M <= 256) {
        return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f;
      } else if (total_M <= 1024) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f;
      } else if (total_M <= 4096) {
        return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t;
      } else {
        return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
      }
    } else if (N == 6144 && K == 6144) {
      return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t;
    }

    // Fallback to legacy heuristic for now.
    if (total_M <= 16) {
      return bf16bf16bf16_grouped_grad_128_16_128_1_1_1_9_f;
    } else if (total_M <= 32) {
      return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f;
    } else if (total_M <= 64) {
      return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f;
    } else if (total_M <= 128) {
      return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f;
    } else if (total_M <= 512) {
      return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_9_f;
    } else {
      return bf16bf16bf16_grouped_grad_128_256_128_2_1_1_9_f;
    }
  }
}

Kernel_bf16bf16bf16_grouped_grad get_kernel_via_tuning(
    int arch,
    int G,
    int total_M,
    int N,
    int K,
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> M_sizes = std::nullopt) {
  auto& cache = getTuningCache();

  // Reducing amount of auto tuning by rounding up total_M to next power of 2.
  total_M = nextPowerOf2(total_M);
  // Use (total_M, N, K, G) shape as the key.
  const std::string shape_key = std::to_string(total_M) + "_" +
      std::to_string(N) + "_" + std::to_string(K) + "_" + std::to_string(G);
  const auto& kernels = get_bf16bf16bf16_grouped_grad_kernels(arch);
  auto kernel = cache.findBestKernelMaybeAutotune(
      shape_key, kernels, X, W, output, M_sizes);

  return kernel;
}

// BF16 grouped cutlass kernel dispatch.
at::Tensor dispatch_bf16_grouped_kernel(
    int G,
    int total_M,
    int N,
    int K,
    at::Tensor X, // BF16
    at::Tensor W, // BF16
    at::Tensor output,
    std::optional<at::Tensor> M_sizes = std::nullopt) {
  static int arch = -1;
  // Avoid expensive cudaGetDeviceProperties call.
  if (arch < 0) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    if (prop.major >= 10) {
      arch = 10;
      int runtimeVersion;
      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
      TORCH_CHECK(
          runtimeVersion >= 12080,
          "FP8 grouped GEMM on sm100a or above requires cuda >= 12.8");
    } else {
      arch = 9;
    }
  }

  // Select kernel to run via heuristics or tuning.
  auto kernel = [&]() {
    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
      return get_kernel_via_tuning(
          arch, G, total_M, N, K, X, W, output, M_sizes);
    } else {
      return get_kernel_via_heuristic(arch, G, total_M, N, K);
    }
  }();
  // Invoke kernel
  return kernel(X, W, output, M_sizes);
}

at::Tensor
bf16bf16bf16_grouped_grad(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
  int64_t total_M = X.size(0);
  int64_t N = W.size(1);
  int64_t K = W.size(2);
  int64_t G = M_sizes.size(0);
  TORCH_CHECK(
      M_sizes.device() == X.device(),
      "M_sizes must be on same device as inputs.");
  TORCH_CHECK(
      W.dim() == 3 && W.size(0) == G, "Weights should be shape [G, N, K].")

  TORCH_CHECK(X.stride(-1) == 1, "Activation memory layout must be row-major.");
  TORCH_CHECK(W.stride(-2) == 1, "Weight memory layout must be column-major.");

  at::Tensor Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
  // Early exit for empty inputs.
  if (total_M == 0) {
    return Y.view({total_M, N});
  }
  // Return contiguous view of output.
  at::Tensor out =
      dispatch_bf16_grouped_kernel(G, total_M, N, K, X, W, Y, M_sizes);
  return out.view({total_M, N});
}

#else

at::Tensor bf16bf16bf16_grouped_grad(at::Tensor, at::Tensor, at::Tensor) {
  throw std::runtime_error(
      "CUDA version is older than 12.0"); // requires CUDA>=12
}

#endif

} // namespace fbgemm_gpu
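
A usage note on kernel selection: the dispatcher above uses the shape heuristic unless the FBGEMM_AUTOTUNE_ENABLE environment variable is set, in which case it consults the TuningCache keyed on (total_M rounded up to the next power of two, N, K, G). A hedged sketch of opting in from Python, assuming the same fbgemm_gpu build as in the earlier example:

import os
import torch

# The dispatcher reads this environment variable on every call, so it can be
# toggled at runtime before invoking the op.
os.environ["FBGEMM_AUTOTUNE_ENABLE"] = "1"

# x: [total_M, K], w: [G, N, K] with stride(-2) == 1, m_sizes: [G] int64 (see the earlier sketch).
# y = torch.ops.fbgemm.bf16bf16bf16_grouped_grad(x, w, m_sizes)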
