diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
index 5d74555de3..898653d756 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -486,11 +486,19 @@ def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]:
     default=None,
     help="If set with grouped mode, repeat input shapes this many times. Comma separated list of groups to benchmark",
 )
+@click.option(
+    "--total-K",
+    default=None,
+    help="If set, adjusts the K values to sum to this number. "
+    "This can help simulate real grouped workloads in backward wgrad. "
+    "Comma separated list of total-K values to benchmark.",
+)
 @click.option(
     "--total-M",
     default=None,
-    help="If set, Adjusts the M values to sum to this number. "
-    "This can help simulate real grouped workloads.",
+    help="If set, adjusts the M values to sum to this number. "
+    "This can help simulate real grouped workloads. "
+    "Comma separated list of total-M values to benchmark.",
 )
 @click.option(
     "--no-cuda-graph",
@@ -542,6 +550,7 @@ def invoke_main(
     pair_nk: bool,
     grouped: bool,
     groups: Optional[str],
+    total_k: Optional[str],
     total_m: Optional[str],
     no_cuda_graph: bool,
     use_rotating_buffer_bench: bool,
@@ -553,6 +562,14 @@
 ):
     if enable_amd_env_vars:
         set_amd_env_vars()
+
+    # Validate that total_m and total_k are mutually exclusive
+    if total_m is not None and total_k is not None:
+        raise ValueError(
+            "total_m and total_k cannot be specified at the same time. "
+            "Please provide only one of them."
+        )
+
     # If kernel filter is provided, parse it. Else, benchmark all kernels.
     all_kernels = kernels.strip().split(",") if kernels else None
     quantize_ops = collect_kernels_to_profile(all_kernels)
@@ -619,16 +636,31 @@ def invoke_main(
     if groups:
         groups_list = [int(g) for g in groups.strip().split(",")]
         if total_m:
+            total_m_list = [int(tm) for tm in total_m.strip().split(",")]
             MNK = [
                 [
                     [b] * g,
-                    generate_group_tensor(g, int(total_m)),
+                    generate_group_tensor(g, tm),
                     [n] * g,
                     [k] * g,
                 ]
                 for g in groups_list
+                for tm in total_m_list
                 for b, _, n, k in MNK
             ]
+        elif total_k:
+            total_k_list = [int(tk) for tk in total_k.strip().split(",")]
+            MNK = [
+                [
+                    [b] * g,
+                    [m] * g,
+                    [n] * g,
+                    generate_group_tensor(g, tk),
+                ]
+                for g in groups_list
+                for tk in total_k_list
+                for b, m, n, _ in MNK
+            ]
         else:
             MNK = [
                 [[b] * g, [m] * g, [n] * g, [k] * g]
diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index d5abd183f9..cb0be6d43a 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -2084,7 +2084,7 @@ def cuda(self) -> bool:
 @register_quantize_op
 class BF16GroupedGrad(QuantizeOpBase):
     """
-    BF16 grouped matmul with grad inputs backed by cutlass
+    BF16 grouped matmul with dgrad inputs in pretraining backed by cutlass
     """
 
     def preprocess(self, x, w):
@@ -2126,6 +2126,52 @@ def cuda(self) -> bool:
         return True
 
 
+@register_quantize_op
+class BF16GroupedWGrad(QuantizeOpBase):
+    """
+    BF16 grouped matmul with wgrad inputs in pretraining backed by cutlass
+    """
+
+    def preprocess(self, x, w):
+        # Get K values for each group
+        k_values = [xi.shape[1] for xi in x]  # K dimension for each group
+
+        # Convert k_values into sizes tensor
+        k_sizes = torch.tensor(k_values).to(dtype=torch.int64, device=x[0].device)
+
+        x = torch.concat(x, dim=1).contiguous()  # shape: (M, G*K)
+        w = torch.concat(w, dim=1).contiguous()  # shape: (N, G*K)
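# Illustrative shape walk-through with hypothetical sizes (G = 2 groups,
# M = 4, N = 8, per-group K_i = 16 and 32):
#   x: [(4, 16), (4, 32)] --concat(dim=1)--> (4, 48)   # (M, total_K)
#   w: [(8, 16), (8, 32)] --concat(dim=1)--> (8, 48)   # (N, total_K)
# After the transposes below, x is (48, 4) and w is (48, 8); slicing dim 0 by
# k_sizes = [16, 32] yields one (K_i, M) x (K_i, N) wgrad-shaped GEMM per group.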
+
+        # Transpose the inputs to simulate wgrad shapes
+        x = x.t().contiguous()  # shape: (G*K, M)
+        w = w.t().contiguous()  # shape: (G*K, N)
+
+        # Return processed tensors
+        return x, w, k_sizes
+
+    def quantize(self, x, w, k_sizes):
+        return x, w, k_sizes
+
+    def compute(self, x, w, k_sizes):
+        return torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(x, w, k_sizes)
+
+    def quantize_and_compute(self, x, w, k_sizes):
+        x, w, k_sizes = self.quantize(x, w, k_sizes)
+        return self.compute(x, w, k_sizes)
+
+    @property
+    def name(self) -> str:
+        return "bf16_grouped_wgrad"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16GroupedStacked(QuantizeOpBase):
     """
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu
new file mode 100644
index 0000000000..9850da3cfb
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu
@@ -0,0 +1,3765 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+#include
+
+#include "bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh"
+#include "bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh"
+#include "fbgemm_gpu/quantize/tuning_cache.hpp"
+#include "fbgemm_gpu/quantize/utils.h"
+#include "fbgemm_gpu/quantize/utils_gpu.h"
+
+namespace fbgemm_gpu {
+
+#if CUDART_VERSION >= 12000
+
+namespace {
+TuningCache& getTuningCache() {
+  static TuningCache cache("bf16bf16bf16_grouped_wgrad");
+  return cache;
+}
+} // namespace
+
+// TODO: Re-enable the following when the fbgemm relocation issue is resolved.
+// Kernel_bf16bf16bf16_grouped_wgrad
+// get_wgrad_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) {
+//   // Use heuristics to pick best kernel implementation.
+//   if (arch == 10) {
+//     // Llama4 shapes
+//     if ((N == 5120 && K == 1024) || (N == 2048 && K == 5120)) {
+//       if (total_M <= 256) {
+//         return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f;
+//       } else if (total_M <= 512) {
+//         return bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f;
+//       } else if (total_M <= 1024) {
+//         return bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f;
+//       } else {
+//         return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f;
+//       }
+//     }
+//
+//     // Fallback to legacy heuristic.
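// Reading note on the commented-out table above and the dispatch further
// below: each kernel name appears to encode
// {TB_M}_{TB_N}_{TB_K}_{Cluster_M}_{Cluster_N}_{Cluster_K}_{arch}_{t|f},
// and, judging from the call sites, the trailing t/f corresponds to the final
// boolean template argument of bf16bf16bf16_grouped_wgrad_impl while
// output_accum selects the other boolean. This is inferred from the code, not
// a documented naming contract.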
+// if (total_M <= 64 || (total_M <= 256 and N <= 1024)) { +// if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f; +// } +// } else if (total_M <= 512) { +// if (N <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f; +// } else if (N <= 8192) { +// if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f; +// } +// } +// } else if (total_M <= 1024) { +// if (N <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f; +// } else if (N <= 8192) { +// if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f; +// } +// } +// } else if (total_M <= 2048) { +// if (N <= 1024) { +// return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f; +// } else if (N <= 8192) { +// if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f; +// } +// } +// } +// return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f; +// } else { // arch == 9 +// // Llama4.x pretraining +// if (total_M == 8192) { +// if (N == 2560) { +// if (K == 1280) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t; +// } else if (K == 5120) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } +// } else if (N == 3072) { +// if (K == 1536) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else if (K == 6144) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N == 5120) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (N == 6144) { +// if (K == 1536) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else if (K == 6144) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } +// } + +// if (total_M == 16384) { +// if (N == 2560 || N == 3072) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (N == 5120) { +// if (K == 1280) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K == 5120) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } +// } else if (N == 6144) { +// if (K == 1536) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else if (K == 6144) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; +// } +// } +// } + +// if (total_M == 65536) { +// if (N <= 512) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; +// } +// } else if (N <= 768) { +// if (K <= 384) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 768) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1280) { +// if (K <= 640) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 1280) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; +// } +// } else if (N <= 1536) { +// if (K <= 768) { 
+// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1536) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1792) { +// if (K <= 1792) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2560) { +// if (K <= 1280) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; +// } else if (K <= 2560) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 3072) { +// if (K <= 1536) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 3072) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; +// } +// } +// } + +// // Fallback to legacy heuristic +// if (total_M <= 128) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 2048) { +// if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return 
bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } +// } else if (total_M <= 256) { +// if (N <= 128) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 1024) { +// if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 2048) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 4096) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// 
return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } +// } else if (total_M <= 512) { +// if (N <= 128) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K 
<= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } +// } else if (total_M <= 1024) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } +// } else if (N <= 1024) { +// if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 2048) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return 
bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } +// } else if (total_M <= 2048) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return 
bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } +// } else if (total_M <= 4096) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return 
bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } +// } else if (total_M <= 8192) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } else { +// 
return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } +// } else { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; +// } else { +// return 
bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } +// } +// } +// } +// } + +Kernel_bf16bf16bf16_grouped_wgrad get_kernel_via_tuning( + int arch, + int G, + int total_M, + int N, + int K, + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + auto& cache = getTuningCache(); + + // Reducing amount of auto tuning by rounding up total_m to next power of 2. + total_M = nextPowerOf2(total_M); + // Use (total_M, N, K, G) shape as the key. + const std::string shape_key = std::to_string(total_M) + "_" + + std::to_string(N) + "_" + std::to_string(K) + "_" + std::to_string(G); + const auto& kernels = get_bf16bf16bf16_grouped_wgrad_kernels(arch); + auto kernel = cache.findBestKernelMaybeAutotune( + shape_key, kernels, X, W, M_sizes, output, output_accum); + + return kernel; +} + +// BF16 grouped cutlass kernel dispatch. +at::Tensor dispatch_bf16_grouped_kernel( + int G, + int total_M, + int N, + int K, + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + const int arch = getDeviceArch(); + + // Use heuristics to pick best kernel implementation. 
+ if (arch == 10) { + // Llama4 shapes + if ((N == 5120 && K == 1024) || (N == 2048 && K == 5120)) { + if (total_M <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (total_M <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (total_M <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } + + // Fallback to legacy heuristic. + if (total_M <= 64 || (total_M <= 256 && N <= 1024)) { + if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (total_M <= 512) { + if (N <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (N <= 8192) { + if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } + } else if (total_M <= 1024) { + if (N <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + 
return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (N <= 8192) { + if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } + } else if (total_M <= 2048) { + if (N <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (N <= 8192) { + if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } + } + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { // arch == 9 + // Llama4.x pretraining + if (total_M == 8192) { + if (N == 2560) { + if (K == 1280) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 4, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 4, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K == 5120) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N == 3072) { + if (K == 1536) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K == 6144) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + 
return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N == 5120) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (N == 6144) { + if (K == 1536) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K == 6144) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } + } + + if (total_M == 16384) { + if (N == 2560 || N == 3072) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (N == 5120) { + if (K == 1280) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K == 5120) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N == 6144) { + if (K == 1536) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K == 6144) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } + } + + if (total_M == 65536) { + if (N <= 512) { + if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 768) { + if (K <= 384) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 2, + 
1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 768) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 1024) { + if (K <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 1280) { + if (K <= 640) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K <= 1280) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 1536) { + if (K <= 768) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K <= 1536) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 1792) { + if (K <= 1792) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 2048) { + if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 2560) { + if (K <= 1280) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K <= 2560) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 3072) { + if (K <= 1536) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K <= 3072) { + if (output_accum) { 
+ return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } + } + + // Fallback to legacy heuristic + if (total_M <= 128) { + if (N <= 128) { + if (K <= 128) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 4, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 4, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + true>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 256) { + if (K <= 128) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return 
bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 8192) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 512) { + if (K <= 128) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 1024) { + if (K <= 128) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 4, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 4, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 8192) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, 
+ true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 2048) { + if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 8192) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 4096) { + if (K <= 128) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 8192) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 8192) { + if (K <= 8192) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else { + if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 8192) { + if 
(output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } + } else if (total_M <= 256) { + if (N <= 128) { + if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 256) { + if (K <= 128) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 
1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 512) { + if (K <= 128) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 1024) { + if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 1024) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 2048) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 8192) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 2048) { + if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, 
M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } + } else if (N <= 4096) { + if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 512) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 8192) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else if (N <= 8192) { + if (K <= 128) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 256) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else if (K <= 4096) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } + } else { + if (K <= 4096) { + if (output_accum) { + return 
bf16bf16bf16_grouped_wgrad_impl<
+                128,
+                128,
+                128,
+                2,
+                1,
+                1,
+                true,
+                false>(X, W, M_sizes, output);
+          } else {
+            return bf16bf16bf16_grouped_wgrad_impl<
+                128,
+                128,
+                128,
+                2,
+                1,
+                1,
+                false,
+                false>(X, W, M_sizes, output);
+          }
+        } else {
+          if (output_accum) {
+            return bf16bf16bf16_grouped_wgrad_impl<
+                128,
+                128,
+                128,
+                2,
+                1,
+                1,
+                true,
+                true>(X, W, M_sizes, output);
+          } else {
+            return bf16bf16bf16_grouped_wgrad_impl<
+                128,
+                128,
+                128,
+                2,
+                1,
+                1,
+                false,
+                true>(X, W, M_sizes, output);
+          }
+        }
+      }
+    }
+
+    // Default fallback
+    if (output_accum) {
+      return bf16bf16bf16_grouped_wgrad_impl<
+          128,
+          128,
+          128,
+          1,
+          1,
+          1,
+          true,
+          false>(X, W, M_sizes, output);
+    } else {
+      return bf16bf16bf16_grouped_wgrad_impl<
+          128,
+          128,
+          128,
+          1,
+          1,
+          1,
+          false,
+          false>(X, W, M_sizes, output);
+    }
+  }
+}
+
+at::Tensor bf16bf16bf16_grouped_wgrad(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> output,
+    bool output_accum) {
+  int64_t total_M = X.size(0);
+  int64_t N = X.size(1);
+  int64_t K = W.size(1);
+  int64_t G = M_sizes.size(0);
+  TORCH_CHECK(
+      M_sizes.device() == X.device(),
+      "M_sizes must be on same device as inputs.");
+  TORCH_CHECK(
+      X.dim() == 2 && W.dim() == 2 && W.size(0) == total_M,
+      "Activations should be shape [GM, N] and weights should be shape [GM, K]");
+
+  if (output_accum) {
+    TORCH_CHECK(
+        output.has_value(), "Must provide output tensor for output_accum=True");
+  }
+
+  at::Tensor Y;
+  if (output.has_value()) {
+    Y = output.value();
+    if (output_accum) {
+      TORCH_CHECK(
+          Y.dtype() == at::kFloat,
+          "Output tensor must be Float32 when output_accum=True");
+    } else {
+      TORCH_CHECK(
+          Y.dtype() == at::kBFloat16,
+          "Output tensor must be BFloat16 when output_accum=False");
+    }
+  } else {
+    Y = at::empty(G * N * K, X.options().dtype(at::kBFloat16));
+  }
+
+  // Early exit for empty inputs.
+  if (total_M == 0) {
+    return Y.view({G, N, K});
+  }
+  // Return contiguous view of the output.
+  at::Tensor out = dispatch_bf16_grouped_kernel(
+      G, total_M, N, K, X, W, M_sizes, Y, output_accum);
+  return out.view({G, N, K});
+}
+
+#else
+
+at::Tensor bf16bf16bf16_grouped_wgrad(
+    at::Tensor,
+    at::Tensor,
+    at::Tensor,
+    std::optional<at::Tensor>,
+    bool) {
+  throw std::runtime_error(
+      "CUDA version is older than 12.0"); // requires CUDA>=12
+}
+
+#endif
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh
new file mode 100644
index 0000000000..45ee9e22b4
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh
@@ -0,0 +1,649 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include +#include +#include +#include + +// clang-format off +// The fixed ordering of the headers is required for CUTLASS 3.2+ +#include +#include // @manual +#include // @manual +#include // @manual +// clang-format on + +namespace fbgemm_gpu { + +inline int64_t _byte_align(int64_t offset) { + int64_t remainder = offset % 16; + if (remainder != 0) { + offset += (16 - remainder); + } + return offset; +} + +template < + typename ProblemShape, + typename ElementA, + typename ElementB, + typename ElementC, + typename StrideA, + typename StrideB, + typename StrideC> +__global__ void set_stacked_kernel_args_kernel( + int64_t G, + int64_t N, + int64_t K, + ProblemShape* problem_shape_ptr, + ElementA* x, + const ElementA** x_ptr, + ElementB* w, + const ElementB** w_ptr, + ElementC* output, + ElementC** output_ptr, + StrideA* stride_a_ptr, + StrideB* stride_b_ptr, + StrideC* stride_c_ptr, + int64_t* M_sizes) { + uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x; + // If this thread corresponds to a valid group, write kernel args to device + // memory. + if (group_index < G) { + // Its possible that we're only writing a subset of the groups to + // kernel args. To do this, we need to set all groups initially to empty. + // and keep a problem counter for the number of non-empty groups. + __shared__ int non_zero_counter; + // Initialize counter in first group. + if (group_index == 0) { + non_zero_counter = 0; + } + // Set problem shapes to empty by default. + problem_shape_ptr[group_index] = ProblemShape(0, 0, 0); + // Sync threads to get consistent state in the block. + __syncthreads(); + + // Compute shape for this group. + // M for this group is pulled directly from M_sizes. + int M = M_sizes[group_index]; + // Only proceed to writing kernel args if this group is non-empty. + if (M > 0) { + // Get the index for this group atomically. + int non_zero_idx = atomicAdd(&non_zero_counter, 1); + // We compute the offset by getting the cumulative sum over + // prior groups. + int64_t offset_M = 0; + for (int i = 0; i < group_index; i++) { + offset_M += M_sizes[i]; + } + // Set the problem shape for this group. + problem_shape_ptr[non_zero_idx] = ProblemShape(int(N), int(K), int(M)); + // Set input pointers. + x_ptr[non_zero_idx] = x + (offset_M * N); + w_ptr[non_zero_idx] = w + (offset_M * K); + output_ptr[non_zero_idx] = output + (N * K * group_index); + stride_a_ptr[non_zero_idx] = cutlass::make_cute_packed_stride( + StrideA{}, cute::make_shape(int(N), int(M), 1)); + stride_b_ptr[non_zero_idx] = cutlass::make_cute_packed_stride( + StrideB{}, cute::make_shape(int(K), int(M), 1)); + stride_c_ptr[non_zero_idx] = cutlass::make_cute_packed_stride( + StrideC{}, cute::make_shape(int(N), int(K), 1)); + } + } +} + +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool OUTPUT_ACCUM, + bool PONG> +at::Tensor bf16bf16bf16_grouped_wgrad_impl( + at::Tensor X, + at::Tensor W, + at::Tensor M_sizes, + at::Tensor output) { + int64_t G; + at::TensorOptions options; + G = M_sizes.size(0); + options = X.options(); + + // Return early if there are no elements in the output. + if (output.numel() == 0) { + return output; + } + + // Define gemm configuration. 
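+  // Reviewer note (editorial comment, not part of the original patch): for
+  // wgrad, each group computes dW_g = dy_g^T @ x_g, as exercised by the unit
+  // test in this patch. The per-group problem shape is therefore set to
+  // (N, K, M) in set_stacked_kernel_args_kernel above, with the stacked
+  // [M, N] output-gradient slab consumed as a column-major [N, M] A operand
+  // (an implicit transpose) and the stacked [M, K] slab as the B operand.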
+ using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = + cute::conditional_t; + + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using TileShape = + cute::Shape, cute::Int, cute::Int>; + using ClusterShape = + cute::Shape, cute::Int, cute::Int>; + + using CooperativeSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using PongSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + + using CooperativeEpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using PongEpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + + using MainLoopSchedule = + cute::conditional_t; + using EpilogueSchedule = cute:: + conditional_t; + + using ComputeC = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + ElementC, + ElementAccumulator, + cutlass::FloatRoundStyle::round_to_nearest>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using C_src = cutlass::epilogue::fusion::Sm90SrcFetch; + using EVTAddC = cutlass::epilogue::fusion::Sm90EVT; + + using CollectiveEpilogueDefault = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, // Use source tensor for epilogue operations. + LayoutC*, + 128 / cutlass::sizeof_bits::value, + ElementC, + LayoutC*, + 128 / cutlass::sizeof_bits::value, + EpilogueSchedule>::CollectiveOp; + + using CollectiveEpilogueAccum = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, // Use source tensor for epilogue operations. + LayoutC*, + 128 / cutlass::sizeof_bits::value, + ElementC, + LayoutC*, + 128 / cutlass::sizeof_bits::value, + EpilogueSchedule, + EVTAddC>::CollectiveOp; + + using CollectiveEpilogue = cute::conditional_t< + OUTPUT_ACCUM, + CollectiveEpilogueAccum, + CollectiveEpilogueDefault>; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutA*, + 128 / cutlass::sizeof_bits::value, + ElementB, + LayoutB*, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel:: + GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideD; + + // Create a buffer for kernel arguments. We do this by first figuring out + // how much space each sub-argument requires and setting up corresponding + // pointers. 
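+  // Reviewer note (editorial comment, not part of the original patch): all
+  // per-group metadata (problem shapes, operand pointers, strides) is packed
+  // into one byte tensor; each sub-array below is padded to a 16-byte
+  // boundary via _byte_align so every derived pointer stays 16-byte aligned.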
+ const int64_t problem_size_offset = 0; + int64_t problem_size_buffer = + _byte_align(G * sizeof(ProblemShape::UnderlyingProblemShape)); + + // Next create space for X pointers. + const int64_t x_offset = problem_size_offset + problem_size_buffer; + int64_t x_size_buffer = _byte_align(G * sizeof(ElementA**)); + + // W Pointers. + const int64_t w_offset = x_offset + x_size_buffer; + int64_t w_size_buffer = _byte_align(G * sizeof(ElementB**)); + + // Outputs. + const int64_t output_offset = w_offset + w_size_buffer; + int64_t output_buffer = _byte_align(G * sizeof(ElementC**)); + + // A stride. + const int64_t stride_a_offset = output_offset + output_buffer; + int64_t stride_a_buffer = _byte_align(G * sizeof(StrideA)); + + // B stride; + const int64_t stride_b_offset = stride_a_offset + stride_a_buffer; + int64_t stride_b_buffer = _byte_align(G * sizeof(StrideB)); + + // C stride; + const int64_t stride_c_offset = stride_b_offset + stride_b_buffer; + int64_t stride_c_buffer = _byte_align(G * sizeof(StrideC)); + + // Compute total buffer size + int64_t total_buffer_size = stride_c_offset + stride_c_buffer; + + // Allocate space for gemm information. + at::Tensor kernel_args = + at::empty({total_buffer_size}, options.dtype(at::kByte)); + + // Get byte pointer to underlying data. + char* kernel_args_ptr = reinterpret_cast(kernel_args.data_ptr()); + + // Now use offsets to get appropriately typed pointers. + ProblemShape::UnderlyingProblemShape* problem_shape_ptr = + reinterpret_cast( + kernel_args_ptr + problem_size_offset); + const ElementA** x_ptr = + reinterpret_cast(kernel_args_ptr + x_offset); + const ElementB** w_ptr = + reinterpret_cast(kernel_args_ptr + w_offset); + ElementC** output_ptr = + reinterpret_cast(kernel_args_ptr + output_offset); + StrideA* stride_a_ptr = + reinterpret_cast(kernel_args_ptr + stride_a_offset); + StrideB* stride_b_ptr = + reinterpret_cast(kernel_args_ptr + stride_b_offset); + StrideC* stride_c_ptr = + reinterpret_cast(kernel_args_ptr + stride_c_offset); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK(M_sizes.dtype() == at::kLong, "M_sizes must be int64."); + int64_t total_M = X.size(0); + int64_t N = X.size(1); + int64_t K = W.size(1); + + int64_t* M_sizes_ptr = reinterpret_cast(M_sizes.data_ptr()); + set_stacked_kernel_args_kernel<<<1, G, 0, stream>>>( + G, + N, + K, + problem_shape_ptr, + reinterpret_cast(X.data_ptr()), + x_ptr, + reinterpret_cast(W.data_ptr()), + w_ptr, + reinterpret_cast(output.data_ptr()), + output_ptr, + stride_a_ptr, + stride_b_ptr, + stride_c_ptr, + M_sizes_ptr); + int kernel_groups = int(std::min(total_M, G)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {kernel_groups, problem_shape_ptr, nullptr}, + {x_ptr, stride_a_ptr, w_ptr, stride_b_ptr}, + {{}, + (const ElementC**)output_ptr, + stride_c_ptr, + output_ptr, + stride_c_ptr}}; + + int sm_count = at::cuda::getDeviceProperties(output.device().index()) + ->multiProcessorCount; + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + sm_count -= at::globalContext()._SMCarveout_EXPERIMENTAL().value(); + } + arguments.hw_info.sm_count = sm_count; + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + at::Tensor workspace = at::empty(workspace_size, options.dtype(at::kByte)); + + // Check the problem size is supported or not + cutlass::Status 
status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize( + arguments, reinterpret_cast(workspace.data_ptr())); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +#if CUDART_VERSION >= 12080 +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool OUTPUT_ACCUM, + bool PONG> +at::Tensor bf16bf16bf16_grouped_wgrad_sm100_impl( + at::Tensor X, + at::Tensor W, + at::Tensor M_sizes, + at::Tensor output) { + int64_t G; + at::TensorOptions options; + G = M_sizes.size(0); + options = X.options(); + + // Return early if there are no elements in the output. + if (output.numel() == 0) { + return output; + } + + // Define gemm configuration. + using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + using ElementC = + cute::conditional_t; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that + // supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using TileShape = + cute::Shape, cute::Int, cute::Int>; + using ClusterShape = + cute::Shape, cute::Int, cute::Int>; + + using MainLoopSchedule = cute::conditional_t< + (TBS_M % 2 == 0) || (TB_M == 256), + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100, + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100>; + using EpilogueSchedule = cute::conditional_t< + (TBS_M % 2 == 0) || (TB_M == 256), + cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm, + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm>; + + using ComputeC = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, + ElementC, + ElementAccumulator, + cutlass::FloatRoundStyle::round_to_nearest>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + using C_src = cutlass::epilogue::fusion::Sm90SrcFetch; + using EVTAddC = cutlass::epilogue::fusion::Sm90EVT; + + using CollectiveEpilogueDefault = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, // Use source tensor for epilogue operations. + LayoutC*, + 128 / cutlass::sizeof_bits::value, + ElementC, + LayoutC*, + 128 / cutlass::sizeof_bits::value, + EpilogueSchedule>::CollectiveOp; + + using CollectiveEpilogueAccum = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + cutlass::arch::OpClassTensorOp, + TileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementAccumulator, + ElementC, // Use source tensor for epilogue operations. 
+ LayoutC*, + 128 / cutlass::sizeof_bits::value, + ElementC, + LayoutC*, + 128 / cutlass::sizeof_bits::value, + EpilogueSchedule, + EVTAddC>::CollectiveOp; + + using CollectiveEpilogue = cute::conditional_t< + OUTPUT_ACCUM, + CollectiveEpilogueAccum, + CollectiveEpilogueDefault>; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + LayoutA*, + 128 / cutlass::sizeof_bits::value, + ElementB, + LayoutB*, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel:: + GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideD; + + // Create a buffer for kernel arguments. We do this by first figuring out + // how much space each sub-argument requires and setting up corresponding + // pointers. + const int64_t problem_size_offset = 0; + int64_t problem_size_buffer = + _byte_align(G * sizeof(ProblemShape::UnderlyingProblemShape)); + + // Next create space for X pointers. + const int64_t x_offset = problem_size_offset + problem_size_buffer; + int64_t x_size_buffer = _byte_align(G * sizeof(ElementA**)); + + // W Pointers. + const int64_t w_offset = x_offset + x_size_buffer; + int64_t w_size_buffer = _byte_align(G * sizeof(ElementB**)); + + // Outputs. + const int64_t output_offset = w_offset + w_size_buffer; + int64_t output_buffer = _byte_align(G * sizeof(ElementC**)); + + // A stride. + const int64_t stride_a_offset = output_offset + output_buffer; + int64_t stride_a_buffer = _byte_align(G * sizeof(StrideA)); + + // B stride; + const int64_t stride_b_offset = stride_a_offset + stride_a_buffer; + int64_t stride_b_buffer = _byte_align(G * sizeof(StrideB)); + + // C stride; + const int64_t stride_c_offset = stride_b_offset + stride_b_buffer; + int64_t stride_c_buffer = _byte_align(G * sizeof(StrideC)); + + // Compute total buffer size + int64_t total_buffer_size = stride_c_offset + stride_c_buffer; + + // Allocate space for gemm information. + at::Tensor kernel_args = + at::empty({total_buffer_size}, options.dtype(at::kByte)); + + // Get byte pointer to underlying data. + char* kernel_args_ptr = reinterpret_cast(kernel_args.data_ptr()); + + // Now use offsets to get appropriately typed pointers. 
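+  // Reviewer note (editorial comment, not part of the original patch): from
+  // here the SM100 path mirrors the SM90 implementation above: carve typed
+  // views out of kernel_args, launch set_stacked_kernel_args_kernel, then
+  // size, initialize, and run the CUTLASS grouped GEMM.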
+ ProblemShape::UnderlyingProblemShape* problem_shape_ptr = + reinterpret_cast( + kernel_args_ptr + problem_size_offset); + const ElementA** x_ptr = + reinterpret_cast(kernel_args_ptr + x_offset); + const ElementB** w_ptr = + reinterpret_cast(kernel_args_ptr + w_offset); + ElementC** output_ptr = + reinterpret_cast(kernel_args_ptr + output_offset); + StrideA* stride_a_ptr = + reinterpret_cast(kernel_args_ptr + stride_a_offset); + StrideB* stride_b_ptr = + reinterpret_cast(kernel_args_ptr + stride_b_offset); + StrideC* stride_c_ptr = + reinterpret_cast(kernel_args_ptr + stride_c_offset); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK(M_sizes.dtype() == at::kLong, "M_sizes must be int64."); + int64_t total_M = X.size(0); + int64_t N = X.size(1); + int64_t K = W.size(1); + + int64_t* M_sizes_ptr = reinterpret_cast(M_sizes.data_ptr()); + set_stacked_kernel_args_kernel<<<1, G, 0, stream>>>( + G, + N, + K, + problem_shape_ptr, + reinterpret_cast(X.data_ptr()), + x_ptr, + reinterpret_cast(W.data_ptr()), + w_ptr, + reinterpret_cast(output.data_ptr()), + output_ptr, + stride_a_ptr, + stride_b_ptr, + stride_c_ptr, + M_sizes_ptr); + int kernel_groups = int(std::min(total_M, G)); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {kernel_groups, problem_shape_ptr, nullptr}, + {x_ptr, stride_a_ptr, w_ptr, stride_b_ptr}, + {{}, + (const ElementC**)output_ptr, + stride_c_ptr, + output_ptr, + stride_c_ptr}}; + + int sm_count = at::cuda::getDeviceProperties(output.device().index()) + ->multiProcessorCount; + if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) { + sm_count -= at::globalContext()._SMCarveout_EXPERIMENTAL().value(); + } + arguments.hw_info.sm_count = sm_count; + + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix + // multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + at::Tensor workspace = at::empty(workspace_size, options.dtype(at::kByte)); + + // Check the problem size is supported or not + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot implement"); + } + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm.initialize( + arguments, reinterpret_cast(workspace.data_ptr())); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("cutlass cannot initialize"); + } + + status = gemm(at::cuda::getCurrentCUDAStream()); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error( + std::string("cutlass cannot run") + + cutlass::cutlassGetStatusString(status)); + } + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +#else +template < + int TB_M, + int TB_N, + int TB_K, + int TBS_M, + int TBS_N, + int TBS_K, + bool OUTPUT_ACCUM, + bool PONG> +at::Tensor bf16bf16bf16_grouped_wgrad_sm100_impl( + at::Tensor X, + at::Tensor W, + at::Tensor M_sizes, + at::Tensor output) { + return output; +} +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh new file mode 100644 index 0000000000..64375ea2e8 --- /dev/null +++ 
b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh @@ -0,0 +1,338 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace fbgemm_gpu { + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f( +// at::Tensor X, // 
BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +// at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f( +// at::Tensor X, // BF16 +// at::Tensor W, // BF16 +// at::Tensor M_sizes, +// at::Tensor output, +// bool output_accum); + +using Kernel_bf16bf16bf16_grouped_wgrad = + at::Tensor (*)(at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool); + +const std::unordered_map& +get_bf16bf16bf16_grouped_wgrad_kernels(int arch) { + // Return a static empty map + static const std:: + unordered_map + empty_map; + return empty_map; + + // static const std:: + // unordered_map + // kernelsSM90 = { + // 
{"bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t", + // bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t", + // bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t", + // bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t", + // bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t", + // bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t", + // bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t", + // bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t", + // bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t}, + // {"bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t", + // bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t}, + // {"bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f", + // bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f", + // bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f", + // bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f}, + // {"bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f", + // bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f}, + // }; + // static const std:: + // unordered_map + // kernelsSM100 = { + // {"bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f", + // bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f}, + // {"bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f", + // bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f}, + // {"bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f", + // bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f}, + // {"bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f", + // bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f}, + // {"bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f", + // 
bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f},
+  //         {"bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f",
+  //          bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f},
+  //         {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f",
+  //          bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f},
+  // };
+  // if (arch == 10) {
+  //   return kernelsSM100;
+  // } else {
+  //   return kernelsSM90;
+  // }
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index 2de3346c3d..7a148e87b0 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -80,6 +80,12 @@ at::Tensor bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes);
 at::Tensor bf16bf16bf16_grouped_grad(at::Tensor X, at::Tensor W, at::Tensor M_sizes);
+at::Tensor bf16bf16bf16_grouped_wgrad(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> output = std::nullopt,
+    bool output_accum = false);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -320,6 +326,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("f8i4bf16_shuffled_grouped", f8i4bf16_shuffled_grouped);
   m.impl("bf16i4bf16_shuffled_grouped", bf16i4bf16_shuffled_grouped);
   m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad);
+  m.impl("bf16bf16bf16_grouped_wgrad", bf16bf16bf16_grouped_wgrad);
   m.impl("preshuffle_i4", preshuffle_i4);
   m.impl("bf16i4bf16_shuffled_batched", bf16i4bf16_shuffled_batched);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
@@ -804,6 +811,19 @@ at::Tensor bf16bf16bf16_grouped_grad_meta(
   return Y;
 }
 
+at::Tensor bf16bf16bf16_grouped_wgrad_meta(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> /* output = std::nullopt */,
+    bool /* output_accum = false */) {
+  const at::SymInt G = M_sizes.size(0);
+  const at::SymInt N = X.sym_size(1);
+  const at::SymInt K = W.sym_size(1);
+  at::Tensor Y = at::empty_symint({G, N, K}, X.options().dtype(at::kBFloat16));
+  return Y;
+}
+
 at::Tensor f8f8bf16_rowwise_grouped_stacked_meta(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -845,6 +865,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("bf16i4bf16_shuffled_batched", bf16i4bf16_shuffled_batched_meta);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
   m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad_meta);
+  m.impl("bf16bf16bf16_grouped_wgrad", bf16bf16bf16_grouped_wgrad_meta);
   m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
   m.impl("scaled_fp4_quant", scaled_fp4_quant_meta);
   m.impl("preshuffle_i4", preshuffle_i4_meta);
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp
index 83b2106032..9f120073ad 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp
@@ -66,6 +66,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "bf16bf16bf16_grouped_stacked(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
   m.def(
       "bf16bf16bf16_grouped_grad(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
+  m.def(
+      "bf16bf16bf16_grouped_wgrad(Tensor X, Tensor W, Tensor M_sizes, Tensor(a!)? output=None, bool output_accum=False) -> Tensor");
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
   m.def(
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 469ccd0e2b..53f6cd12da 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -2198,5 +2198,127 @@ def test_fake_quantize_nvfp4_per_tensor(
         torch.testing.assert_close(fake_quant_y, y_ref, atol=0.1, rtol=0.1)
+
+@unittest.skipIf(
+    not torch.cuda.is_available()
+    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
+    "Skip when MI300 or H100 is not available",
+)
+class BF16Tests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.device = torch.accelerator.current_accelerator()
+
+    @unittest.skipIf(
+        not torch.version.cuda,
+        "Skip on AMD: test_bf16_grouped_gemmw_wgrad not yet supported.",
+    )
+    @settings(deadline=None)
+    @given(
+        G=st.sampled_from([2, 16]),
+        M=st.sampled_from([257, 2049]),
+        N=st.sampled_from([256, 2048]),
+        K=st.sampled_from([128, 1024]),
+        output_accum=st.booleans(),
+    )
+    def test_bf16_grouped_gemmw_wgrad(
+        self,
+        G: int,
+        M: int,
+        N: int,
+        K: int,
+        output_accum: bool,
+    ) -> None:
+        torch.manual_seed(hash((G, M, N, K)))
+        # Inputs
+        dy_bf16 = torch.randn(
+            (M, N), dtype=torch.bfloat16, device=torch.accelerator.current_accelerator()
+        )
+        x_bf16 = torch.randn(
+            (M, K), dtype=torch.bfloat16, device=torch.accelerator.current_accelerator()
+        )
+
+        def generate_random_splits(G: int, M: int) -> torch.Tensor:
+            m_cumsums = torch.sort(
+                torch.randint(
+                    0,
+                    M,
+                    (G + 1,),
+                    dtype=torch.int32,
+                    device=torch.accelerator.current_accelerator(),
+                )
+            ).values
+            m_cumsums[0], m_cumsums[-1] = 0, M
+            m_sizes = m_cumsums[1:] - m_cumsums[:-1]
+            return m_sizes
+
+        m_sizes = generate_random_splits(G, M)
+
+        # Test
+        if output_accum:
+            wgrad_accum = torch.randn(
+                (G, N, K),
+                dtype=torch.float32,
+                device=torch.accelerator.current_accelerator(),
+            )
+        else:
+            wgrad_accum = None
+
+        test_wgrad = torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(
+            dy_bf16,
+            x_bf16,
+            m_sizes.to(torch.int64),
+            output=wgrad_accum.clone() if output_accum else None,
+            output_accum=output_accum,
+        )
+
+        if output_accum:
+            assert test_wgrad.dtype == torch.float32
+
+        # Reference
+        dy_fp32 = dy_bf16.to(torch.float32)
+        x_fp32 = x_bf16.to(torch.float32)
+        ref_wgrad = torch.empty(
+            (G, N, K),
+            dtype=torch.float32,
+            device=torch.accelerator.current_accelerator(),
+        )
+
+        # Track which groups have non-zero size for comparison
+        non_zero_groups = []
+        m_start = 0
+        for g, m_size in enumerate(m_sizes.tolist()):
+            if m_size > 0:
+                # Actual slice - compute matrix multiplication
+                ref_wgrad[g, :, :] = (
+                    dy_fp32[m_start : m_start + m_size, :].T
+                    @ x_fp32[m_start : m_start + m_size, :]
+                )
+                non_zero_groups.append(g)
+            m_start += m_size
+
+        if output_accum:
+            assert wgrad_accum is not None
+            ref_wgrad += wgrad_accum
+
+        ref_wgrad = ref_wgrad.to(test_wgrad.dtype)
+
+        # Compare groups with non-zero m_size
+        if non_zero_groups:
+            if test_wgrad.dtype == torch.float32:
+                torch.testing.assert_close(
+                    test_wgrad[non_zero_groups],
+                    ref_wgrad[non_zero_groups],
+                    atol=1e-4,
+                    rtol=1e-4,
+                )
+            else:
+                torch.testing.assert_close(
+                    test_wgrad[non_zero_groups],
+                    ref_wgrad[non_zero_groups],
+                    atol=1e-4,
+                    rtol=1e-2,
+                )
+
+
 if __name__ == "__main__":
     unittest.main()
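
Reviewer note: to make the intended call pattern concrete, below is a minimal Python sketch of the new op, modeled on the unit test above. It is illustrative only and not part of the patch; it assumes a CUDA 12+ FBGEMM gen_ai build on an SM90/SM100 GPU whose import registers torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad, and the shapes and tolerances below are arbitrary choices.

import torch

# Assumed to register the fbgemm ops on import in a gen_ai build.
import fbgemm_gpu.experimental.gen_ai  # noqa: F401

G, M, N, K = 4, 512, 256, 128
dy = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")  # stacked grad-output, [sum(M_g), N]
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")   # stacked activations, [sum(M_g), K]
m_sizes = torch.full((G,), M // G, dtype=torch.int64, device="cuda")  # rows per group, sums to M

# One fused call yields all G weight gradients: out[g] = dy_g.T @ x_g, shape [G, N, K].
wgrad = torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(dy, x, m_sizes)

# Reference check against plain PyTorch, as in the unit test.
ref = torch.stack(
    [d.t().float() @ a.float() for d, a in zip(dy.split(M // G), x.split(M // G))]
).to(wgrad.dtype)
torch.testing.assert_close(wgrad, ref, atol=1e-4, rtol=1e-2)

# With output_accum=True the kernel adds into a caller-provided fp32 buffer,
# the pattern a training loop would use to accumulate wgrad across microbatches.
accum = torch.zeros(G, N, K, dtype=torch.float32, device="cuda")
accum = torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(
    dy, x, m_sizes, output=accum, output_accum=True
)

The stacked [sum(M_g), ...] input layout matches the existing bf16bf16bf16_grouped_stacked and bf16bf16bf16_grouped_grad ops, which is what lets variable group sizes run in one kernel without padding.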