
Commit c1f22a9

jiawenliu64 authored and facebook-github-bot committed
Add output as an option in CUTLASS grouped GEMM (#4931)
Summary:
Pull Request resolved: #4931
X-link: facebookresearch/FBGEMM#1954

Enable output as an option in CUTLASS grouped GEMM: pretraining needs to pass a preallocated (empty) output tensor for the fprop and dgrad use cases.

Reviewed By: cthi

Differential Revision: D83126291

fbshipit-source-id: 26761307d472f9421f115a8c83cc01ceaf28c7ce
1 parent bb87e43 commit c1f22a9
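
Usage-wise, the change adds an optional `out` argument to the grouped GEMM ops. A minimal sketch of the call pattern (assuming the FBGEMM gen_ai extension is built and its ops are registered; sizes are illustrative and mirror the fprop test added below):

    import torch

    G, M, N, K = 4, 1024, 2048, 512  # illustrative sizes
    dev = torch.accelerator.current_accelerator()

    x = torch.randn(M, K, dtype=torch.bfloat16, device=dev)       # stacked activations
    w = torch.randn(G, N, K, dtype=torch.bfloat16, device=dev)    # per-group weights
    m_sizes = torch.full((G,), M // G, dtype=torch.int64, device=dev)  # rows per group

    # Previous behavior (still works): the op allocates its own output.
    y = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)

    # New behavior: pass a preallocated flat bf16 buffer of total_M * N elements
    # and the kernel writes into it, returning a (total_M, N) view.
    y_buf = torch.empty(M * N, dtype=torch.bfloat16, device=dev)
    y2 = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes, y_buf)

The dgrad counterpart, bf16bf16bf16_grouped_grad, gains the same optional argument.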

File tree

6 files changed: +217 -42 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/bf16_grouped/bf16_grouped_gemm.hip

Lines changed: 2 additions & 1 deletion
@@ -515,7 +515,8 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
 at::Tensor bf16bf16bf16_grouped_stacked(
     at::Tensor X,
     at::Tensor W,
-    at::Tensor M_sizes) {
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> out) {
   // Check that input datatypes are valid.
   // First confirm that there are the same number of groups in all inputs.
   int64_t group_count = M_sizes.size(0);

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped.cu

Lines changed: 20 additions & 6 deletions
@@ -345,8 +345,11 @@ at::Tensor bf16bf16bf16_grouped_cat(at::TensorList X, at::TensorList W) {
   return _bf16bf16bf16_grouped<at::Tensor>(X, W);
 }
 
-at::Tensor
-bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
+at::Tensor bf16bf16bf16_grouped_stacked(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> out) {
   int64_t total_M = X.size(0);
   int64_t N = W.size(1);
   int64_t K = W.size(2);
@@ -356,15 +359,22 @@ bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
       "M_sizes must be on same device as inputs.");
   TORCH_CHECK(
       W.dim() == 3 && W.size(0) == G, "Weights should be shape [G, N, K].")
-  at::Tensor Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+
+  at::Tensor Y;
+  if (out.has_value()) {
+    Y = out.value();
+  } else {
+    Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+  }
+
   // Early exit for empty inputs.
   if (total_M == 0) {
     return Y.view({total_M, N});
   }
   // Return continuous view of output.
-  at::Tensor out = dispatch_bf16_grouped_kernel<at::Tensor>(
+  at::Tensor output = dispatch_bf16_grouped_kernel<at::Tensor>(
       G, total_M, N, K, X, W, Y, std::nullopt, M_sizes);
-  return out.view({total_M, N});
+  return output.view({total_M, N});
 }
 
 at::Tensor bf16bf16bf16_grouped_dynamic(
@@ -411,7 +421,11 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
 
-at::Tensor bf16bf16bf16_grouped_stacked(at::Tensor, at::Tensor, at::Tensor) {
+at::Tensor bf16bf16bf16_grouped_stacked(
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    std::optional<at::Tensor>) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
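
The wrapper now adopts the caller's tensor as Y when `out` is given and otherwise keeps the old allocation path; either way the kernel result is returned as a (total_M, N) view. A small sketch of the buffer-reuse pattern this enables (assuming the op is registered as above and, as the wrapper suggests, that the returned tensor is a view over the provided buffer; note the diff adds no shape or dtype checks on `out`, so the caller must supply a flat total_M * N bf16 buffer on the right device):

    import torch

    dev = torch.accelerator.current_accelerator()
    G, M, N, K = 2, 256, 512, 128
    w = torch.randn(G, N, K, dtype=torch.bfloat16, device=dev)
    m_sizes = torch.tensor([M // 2, M - M // 2], dtype=torch.int64, device=dev)

    # One persistent output buffer reused across training steps.
    y_buf = torch.empty(M * N, dtype=torch.bfloat16, device=dev)

    for _ in range(3):
        x = torch.randn(M, K, dtype=torch.bfloat16, device=dev)
        y = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes, y_buf)
        # y is a (M, N) view over y_buf's storage, so no per-step allocation occurs.
        assert y.data_ptr() == y_buf.data_ptr()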

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_grad.cu

Lines changed: 29 additions & 11 deletions
@@ -300,8 +300,11 @@ at::Tensor dispatch_bf16_grouped_kernel(
   return kernel(X, W, output, M_sizes);
 }
 
-at::Tensor
-bf16bf16bf16_grouped_grad(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
+at::Tensor bf16bf16bf16_grouped_grad(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> out) {
   int64_t total_M = X.size(0);
   int64_t N = W.size(1);
   int64_t K = W.size(2);
@@ -315,20 +318,29 @@ bf16bf16bf16_grouped_grad(at::Tensor X, at::Tensor W, at::Tensor M_sizes) {
   TORCH_CHECK(X.stride(-1) == 1, "Activation memory layout must be row-major.");
   TORCH_CHECK(W.stride(-2) == 1, "Weight memory layout must be column-major.");
 
-  at::Tensor Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+  at::Tensor Y;
+  if (out.has_value()) {
+    Y = out.value();
+  } else {
+    Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16));
+  }
   // Early exit for empty inputs.
   if (total_M == 0) {
     return Y.view({total_M, N});
   }
   // Return continuous view of output.
-  at::Tensor out =
+  at::Tensor output =
       dispatch_bf16_grouped_kernel(G, total_M, N, K, X, W, Y, M_sizes);
-  return out.view({total_M, N});
+  return output.view({total_M, N});
 }
 
 #else
 
-at::Tensor bf16bf16bf16_grouped_grad(at::Tensor, at::Tensor, at::Tensor) {
+at::Tensor bf16bf16bf16_grouped_grad(
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    std::optional<at::Tensor>) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
@@ -338,12 +350,18 @@ at::Tensor bf16bf16bf16_grouped_grad(at::Tensor, at::Tensor, at::Tensor) {
 at::Tensor bf16bf16bf16_grouped_grad_meta(
     at::Tensor X,
     at::Tensor W,
-    at::Tensor /* M_sizes */) {
+    at::Tensor /* M_sizes */,
+    std::optional<at::Tensor> out) {
   const at::SymInt total_M = X.sym_size(0);
   const at::SymInt N = W.sym_size(1);
-  at::Tensor Y =
-      at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
-  return Y;
+
+  if (out.has_value()) {
+    return out.value();
+  } else {
+    at::Tensor output =
+        at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
+    return output;
+  }
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
@@ -356,7 +374,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
 
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
-      "bf16bf16bf16_grouped_grad(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
+      "bf16bf16bf16_grouped_grad(Tensor X, Tensor W, Tensor M_sizes, Tensor? out=None) -> Tensor");
 }
 
 } // namespace fbgemm_gpu
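
For the dgrad path the same pattern applies, with the weight passed as its transposed ([G, K, N]) view so that its stride along dim -2 is 1, satisfying the column-major check above. A condensed usage sketch (sizes illustrative; follows the dgrad test added below):

    import torch

    dev = torch.accelerator.current_accelerator()
    G, M, N, K = 2, 256, 512, 128
    dy = torch.randn(M, N, dtype=torch.bfloat16, device=dev)   # stacked upstream grads
    w = torch.randn(G, N, K, dtype=torch.bfloat16, device=dev)
    m_sizes = torch.tensor([M // 2, M - M // 2], dtype=torch.int64, device=dev)

    # Reusable dgrad buffer: flat total_M * K elements, bf16.
    dx_buf = torch.empty(M * K, dtype=torch.bfloat16, device=dev)

    # w.permute(0, 2, 1) is a [G, K, N] view with stride(-2) == 1 (column-major weights).
    dx = torch.ops.fbgemm.bf16bf16bf16_grouped_grad(
        dy, w.permute(0, 2, 1), m_sizes, dx_buf
    )  # returns a (M, K) view of dx_buf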

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 15 additions & 6 deletions
@@ -76,8 +76,11 @@ at::Tensor bf16bf16bf16_grouped_dynamic(
     at::Tensor X,
     at::Tensor W,
     at::Tensor zero_start_index_M);
-at::Tensor
-bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes);
+at::Tensor bf16bf16bf16_grouped_stacked(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> out = std::nullopt);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -781,12 +784,18 @@ at::Tensor bf16bf16bf16_grouped_dynamic_meta(
 at::Tensor bf16bf16bf16_grouped_stacked_meta(
     at::Tensor X,
     at::Tensor W,
-    at::Tensor /* M_sizes */) {
+    at::Tensor /* M_sizes */,
+    std::optional<at::Tensor> out) {
   const at::SymInt total_M = X.sym_size(0);
   const at::SymInt N = W.sym_size(1);
-  at::Tensor Y =
-      at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
-  return Y;
+
+  if (out.has_value()) {
+    return out.value();
+  } else {
+    at::Tensor output =
+        at::empty_symint({total_M, N}, X.options().dtype(at::kBFloat16));
+    return output;
+  }
 }
 
 at::Tensor f8f8bf16_rowwise_grouped_stacked_meta(
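
The meta kernels follow the same rule: a supplied `out` is returned as-is, otherwise a fresh (total_M, N) bf16 tensor is created, which is what shape inference (e.g. under torch.compile) will see. A hypothetical pure-Python restatement of that contract, for illustration only:

    from typing import Optional
    import torch

    def grouped_stacked_meta(
        X: torch.Tensor,
        W: torch.Tensor,
        out: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Illustrative restatement of bf16bf16bf16_grouped_stacked_meta.
        total_M, N = X.shape[0], W.shape[1]
        if out is not None:
            # The preallocated tensor is passed through unchanged.
            return out
        return torch.empty(total_M, N, dtype=torch.bfloat16, device="meta")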

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "bf16bf16bf16_grouped_dynamic(Tensor X, Tensor W, Tensor zero_start_index_M) -> Tensor");
   m.def(
-      "bf16bf16bf16_grouped_stacked(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
+      "bf16bf16bf16_grouped_stacked(Tensor X, Tensor W, Tensor M_sizes, Tensor? out=None) -> Tensor");
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
   m.def(
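
The `Tensor? out=None` fragment in the schema is what makes the new argument optional and nullable on the Python side, mapping to the `std::optional<at::Tensor>` parameter in C++. A generic sketch of the same schema convention using torch.library, with a made-up `demo::scaled_copy` op rather than anything from FBGEMM:

    import torch

    lib = torch.library.Library("demo", "DEF")  # hypothetical namespace
    lib.define("scaled_copy(Tensor x, Tensor? out=None) -> Tensor")

    def scaled_copy_impl(x, out=None):
        # `Tensor? out=None` arrives here as None or a torch.Tensor,
        # just like std::optional<at::Tensor> on the C++ side.
        if out is None:
            out = torch.empty_like(x)
        out.copy_(x * 2)
        return out

    lib.impl("scaled_copy", scaled_copy_impl, "CompositeExplicitAutograd")

    x = torch.randn(4)
    buf = torch.empty_like(x)
    y = torch.ops.demo.scaled_copy(x, buf)   # writes into buf
    y2 = torch.ops.demo.scaled_copy(x)       # allocates internally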

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 150 additions & 17 deletions
@@ -2208,9 +2208,23 @@ class BF16Tests(unittest.TestCase):
     def setUpClass(cls):
         cls.device = torch.accelerator.current_accelerator()
 
+    def generate_random_splits(G: int, M: int) -> torch.Tensor:
+        m_cumsums = torch.sort(
+            torch.randint(
+                0,
+                M,
+                (G + 1,),
+                dtype=torch.int32,
+                device=torch.accelerator.current_accelerator(),
+            )
+        ).values
+        m_cumsums[0], m_cumsums[-1] = 0, M
+        m_sizes = m_cumsums[1:] - m_cumsums[:-1]
+        return m_sizes
+
     @unittest.skipIf(
         not torch.version.cuda,
-        "Skip on AMD: test_bf16_grouped_gemmw_wgrad not yet suported.",
+        "Skip on AMD: test_grouped_gemm_wgrad not yet suported.",
     )
     @settings(deadline=None)
     @given(
@@ -2220,7 +2234,7 @@ def setUpClass(cls):
         K=st.sampled_from([128, 1024]),
         output_accum=st.booleans(),
     )
-    def test_bf16_grouped_gemmw_wgrad(
+    def test_grouped_gemm_wgrad(
         self,
         G: int,
         M: int,
@@ -2237,21 +2251,7 @@ def test_bf16_grouped_gemmw_wgrad(
             (M, K), dtype=torch.bfloat16, device=torch.accelerator.current_accelerator()
         )
 
-        def generate_random_splits(G: int, M: int) -> torch.Tensor:
-            m_cumsums = torch.sort(
-                torch.randint(
-                    0,
-                    M,
-                    (G + 1,),
-                    dtype=torch.int32,
-                    device=torch.accelerator.current_accelerator(),
-                )
-            ).values
-            m_cumsums[0], m_cumsums[-1] = 0, M
-            m_sizes = m_cumsums[1:] - m_cumsums[:-1]
-            return m_sizes
-
-        m_sizes = generate_random_splits(G, M)
+        m_sizes = BF16Tests.generate_random_splits(G, M)
 
         # Test
         if output_accum:
@@ -2319,6 +2319,139 @@ def generate_random_splits(G: int, M: int) -> torch.Tensor:
             rtol=1e-2,
         )
 
+    @unittest.skipIf(
+        not torch.version.cuda,
+        "Skip on AMD: test_grouped_gemm_dgrad not yet suported.",
+    )
+    @settings(deadline=None)
+    @given(
+        G=st.sampled_from([2, 16]),
+        M=st.sampled_from([257, 2049]),
+        N=st.sampled_from([256, 2048]),
+        K=st.sampled_from([128, 1024]),
+    )
+    def test_grouped_gemm_dgrad(
+        self,
+        G: int,
+        M: int,
+        N: int,
+        K: int,
+    ) -> None:
+        torch.manual_seed(hash((G, M, N, K)))
+
+        # Inputs
+        dy_bf16 = torch.randn(
+            (M, N), dtype=torch.bfloat16, device=torch.accelerator.current_accelerator()
+        )
+        w_bf16 = torch.randn(
+            (G, N, K),
+            dtype=torch.bfloat16,
+            device=torch.accelerator.current_accelerator(),
+        )
+        m_sizes = BF16Tests.generate_random_splits(G, M)
+
+        y_bf16 = torch.ops.fbgemm.bf16bf16bf16_grouped_grad(
+            dy_bf16,
+            w_bf16.permute(0, 2, 1),
+            m_sizes.to(torch.int64),
+        )
+
+        Y_preallocated = torch.empty(
+            (M * K),
+            dtype=torch.bfloat16,
+            device=torch.accelerator.current_accelerator(),
+        )
+        y_bf16_preallocated = torch.ops.fbgemm.bf16bf16bf16_grouped_grad(
+            dy_bf16,
+            w_bf16.permute(0, 2, 1),
+            m_sizes.to(torch.int64),
+            Y_preallocated,
+        )
+
+        # Reference
+        dy_fp32 = dy_bf16.to(torch.float32)
+        w_fp32 = w_bf16.to(torch.float32)
+
+        ref_y_fp32 = torch.empty(
+            (M, K), dtype=torch.float32, device=torch.accelerator.current_accelerator()
+        )
+        m_start = 0
+        for g, m_size in enumerate(m_sizes.tolist()):
+            ref_y_fp32[m_start : m_start + m_size, :] = dy_fp32[
+                m_start : m_start + m_size, :
+            ] @ w_fp32[g, :, :].view(N, K)
+            m_start += m_size
+        ref_y_bf16 = ref_y_fp32.to(torch.bfloat16)
+
+        torch.testing.assert_close(y_bf16, ref_y_bf16, atol=1e-3, rtol=1.6e-2)
+        torch.testing.assert_close(
+            y_bf16_preallocated, ref_y_bf16, atol=1e-3, rtol=1.6e-2
+        )
+
+    @unittest.skipIf(
+        not torch.version.cuda,
+        "Skip on AMD: test_grouped_gemm_fprop not yet suported.",
+    )
+    @settings(deadline=None)
+    @given(
+        G=st.sampled_from([2, 16]),
+        M=st.sampled_from([257, 2049]),
+        N=st.sampled_from([256, 2048]),
+        K=st.sampled_from([128, 1024]),
+    )
+    def test_grouped_gemm_fprop(
+        self,
+        G: int,
+        M: int,
+        N: int,
+        K: int,
+    ) -> None:
+        torch.manual_seed(hash((G, M, N, K)))
+
+        # Inputs
+        x_bf16 = torch.randn(
+            (M, K), dtype=torch.bfloat16, device=torch.accelerator.current_accelerator()
+        )
+        w_bf16 = torch.randn(
+            (G, N, K),
+            dtype=torch.bfloat16,
+            device=torch.accelerator.current_accelerator(),
+        )
+        m_sizes = BF16Tests.generate_random_splits(G, M)
+
+        y_bf16 = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(
+            x_bf16, w_bf16, m_sizes.to(torch.int64)
+        )
+
+        Y_preallocated = torch.empty(
+            (M * N),
+            dtype=torch.bfloat16,
+            device=torch.accelerator.current_accelerator(),
+        )
+        y_bf16_Y_preallocated = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(
+            x_bf16, w_bf16, m_sizes.to(torch.int64), Y_preallocated
+        )
+
+        # Reference
+        x_fp32 = x_bf16.to(torch.float32)
+        w_fp32 = w_bf16.to(torch.float32)
+
+        ref_y_fp32 = torch.empty(
+            (M, N), dtype=torch.float32, device=torch.accelerator.current_accelerator()
+        )
+        m_start = 0
+        for g, m_size in enumerate(m_sizes.tolist()):
+            ref_y_fp32[m_start : m_start + m_size, :] = (
+                x_fp32[m_start : m_start + m_size, :] @ w_fp32[g, :, :].view(N, K).T
+            )
+            m_start += m_size
+        ref_y_bf16 = ref_y_fp32.to(torch.bfloat16)
+
+        torch.testing.assert_close(y_bf16, ref_y_bf16, atol=1e-3, rtol=1.6e-2)
+        torch.testing.assert_close(
+            y_bf16_Y_preallocated, ref_y_bf16, atol=1e-3, rtol=1.6e-2
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
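
generate_random_splits is now a shared helper on BF16Tests, used by the wgrad test above as well as the new dgrad and fprop tests; the sizes it returns always sum to M and may include zero-sized groups. A standalone copy of the helper with that invariant spelled out (CPU, for illustration only):

    import torch

    def generate_random_splits(G: int, M: int) -> torch.Tensor:
        # Standalone copy of the test helper.
        m_cumsums = torch.sort(torch.randint(0, M, (G + 1,), dtype=torch.int32)).values
        m_cumsums[0], m_cumsums[-1] = 0, M
        return m_cumsums[1:] - m_cumsums[:-1]

    m_sizes = generate_random_splits(G=8, M=257)
    assert int(m_sizes.sum()) == 257   # group sizes partition the stacked M rows
    assert bool((m_sizes >= 0).all())  # empty groups are possible and allowed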
