From d674436672eaa9543e13960a73e92780937a5cf2 Mon Sep 17 00:00:00 2001 From: jiawenl Date: Mon, 22 Sep 2025 17:27:09 -0700 Subject: [PATCH 1/4] Enable CUTLASS grouped GEMM for pretraining wgrad on GB200 and H100 (resubmit) Differential Revision: D83001505 --- .../gen_ai/bench/quantize_bench.py | 28 +- .../experimental/gen_ai/bench/quantize_ops.py | 48 +- .../bf16bf16bf16_grouped_wgrad.cu | 340 +++++++++ ...f16_grouped_wgrad_128_128_128_1_1_1_9_f.cu | 35 + ...f16_grouped_wgrad_128_128_128_1_2_1_9_f.cu | 35 + ...f16_grouped_wgrad_128_128_128_1_2_1_9_t.cu | 28 + ...f16_grouped_wgrad_128_128_128_1_4_1_9_t.cu | 28 + ...16_grouped_wgrad_128_128_128_2_1_1_10_f.cu | 42 ++ ...f16_grouped_wgrad_128_128_128_2_1_1_9_t.cu | 28 + ...f16_grouped_wgrad_128_128_128_2_2_1_9_t.cu | 28 + ...f16_grouped_wgrad_128_128_128_2_4_1_9_t.cu | 28 + ...f16_grouped_wgrad_128_128_128_4_2_1_9_t.cu | 28 + ...f16_grouped_wgrad_128_128_128_4_2_2_9_f.cu | 35 + ...f16_grouped_wgrad_128_128_128_4_2_2_9_t.cu | 28 + ...f16_grouped_wgrad_128_128_128_4_4_1_9_f.cu | 35 + ...f16_grouped_wgrad_128_128_128_4_4_1_9_t.cu | 28 + ...f16_grouped_wgrad_128_128_128_4_4_2_9_f.cu | 35 + ...f16_grouped_wgrad_128_128_128_4_4_2_9_t.cu | 28 + ...f16_grouped_wgrad_128_128_128_4_4_4_9_f.cu | 35 + ...f16_grouped_wgrad_128_128_128_4_4_4_9_t.cu | 28 + ...f16_grouped_wgrad_128_256_128_1_1_1_9_f.cu | 35 + ...f16_grouped_wgrad_128_256_128_1_2_1_9_f.cu | 35 + ...bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu | 28 + ...bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu | 28 + ...f16_grouped_wgrad_128_32_128_2_1_1_10_f.cu | 42 ++ ...bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu | 28 + ...bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu | 28 + ...bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu | 28 + ...bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu | 28 + ...f16_grouped_wgrad_128_64_128_2_1_1_10_f.cu | 42 ++ ...bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu | 28 + ...bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu | 28 + ...f16_grouped_wgrad_256_128_128_1_2_1_9_f.cu | 35 + ...16_grouped_wgrad_256_128_128_2_1_1_10_f.cu | 42 ++ ...16_grouped_wgrad_256_256_128_2_1_1_10_f.cu | 42 ++ ...bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu | 28 + ...f16_grouped_wgrad_256_32_128_2_1_1_10_f.cu | 42 ++ ...f16_grouped_wgrad_256_64_128_2_1_1_10_f.cu | 42 ++ .../bf16bf16bf16_grouped_wgrad_common.cuh | 646 ++++++++++++++++++ .../bf16bf16bf16_grouped_wgrad_manifest.cuh | 260 +++++++ .../gen_ai/src/quantize/quantize.cpp | 21 + .../gen_ai/src/quantize/quantize_defs.cpp | 2 + .../gen_ai/test/quantize/quantize_test.py | 111 +++ 43 files changed, 2595 insertions(+), 2 deletions(-) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f.cu create mode 100644 
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f.cu create mode 100644 
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py index 5d74555de3..c58908d101 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py @@ -486,10 +486,16 @@ def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]: default=None, help="If set with grouped mode, repeat input shapes this many times. Comma separated list of groups to benchmark", ) +@click.option( + "--total-K", + default=None, + help="If set, adjusts the K values to sum to this number. " + "This can help simulate real grouped workloads in backward wgrad.", +) @click.option( "--total-M", default=None, - help="If set, Adjusts the M values to sum to this number. " + help="If set, adjusts the M values to sum to this number. " "This can help simulate real grouped workloads.", ) @click.option( @@ -542,6 +548,7 @@ def invoke_main( pair_nk: bool, grouped: bool, groups: Optional[str], + total_k: Optional[str], total_m: Optional[str], no_cuda_graph: bool, use_rotating_buffer_bench: bool, @@ -553,6 +560,14 @@ def invoke_main( ): if enable_amd_env_vars: set_amd_env_vars() + + # Validate that total_m and total_k are mutually exclusive + if total_m is not None and total_k is not None: + raise ValueError( + "total_m and total_k cannot be specified at the same time. " + "Please provide only one of them." + ) + # If kernel filter is provided, parse it. Else, benchmark all kernels. 
all_kernels = kernels.strip().split(",") if kernels else None quantize_ops = collect_kernels_to_profile(all_kernels) @@ -629,6 +644,17 @@ def invoke_main( for g in groups_list for b, _, n, k in MNK ] + elif total_k: + MNK = [ + [ + [b] * g, + [m] * g, + [n] * g, + generate_group_tensor(g, int(total_k)), + ] + for g in groups_list + for b, m, n, _ in MNK + ] else: MNK = [ [[b] * g, [m] * g, [n] * g, [k] * g] diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index d5abd183f9..cb0be6d43a 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -2084,7 +2084,7 @@ def cuda(self) -> bool: @register_quantize_op class BF16GroupedGrad(QuantizeOpBase): """ - BF16 grouped matmul with grad inputs backed by cutlass + BF16 grouped matmul with dgrad inputs in pretraining backed by CUTLASS """ def preprocess(self, x, w): @@ -2126,6 +2126,52 @@ def cuda(self) -> bool: return True +@register_quantize_op +class BF16GroupedWGrad(QuantizeOpBase): + """ + BF16 grouped matmul with wgrad inputs in pretraining backed by CUTLASS + """ + + def preprocess(self, x, w): + # Get K values for each group + k_values = [xi.shape[1] for xi in x] # K dimension for each group + + # Convert k_values into sizes tensor + k_sizes = torch.tensor(k_values).to(dtype=torch.int64, device=x[0].device) + + x = torch.concat(x, dim=1).contiguous() # shape: (M, G*K) + w = torch.concat(w, dim=1).contiguous() # shape: (N, G*K) + + # Transpose both tensors to simulate wgrad shapes + x = x.t().contiguous() # shape: (G*K, M) + w = w.t().contiguous() # shape: (G*K, N) + + # Return processed tensors + return x, w, k_sizes + + def quantize(self, x, w, k_sizes): + return x, w, k_sizes + + def compute(self, x, w, k_sizes): + return torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(x, w, k_sizes) + + def quantize_and_compute(self, x, w, k_sizes): + x, w, k_sizes = self.quantize(x, w, k_sizes) + return self.compute(x, w, k_sizes) + + @property + def name(self) -> str: + return "bf16_grouped_wgrad" + + @property + def hip(self) -> bool: + return False + + @property + def cuda(self) -> bool: + return True + + @register_quantize_op class BF16GroupedStacked(QuantizeOpBase): """ diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu new file mode 100644 index 0000000000..f6aa4adbba --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu @@ -0,0 +1,340 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include "bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh" +#include "fbgemm_gpu/quantize/tuning_cache.hpp" +#include "fbgemm_gpu/quantize/utils.h" +#include "fbgemm_gpu/quantize/utils_gpu.h" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +namespace { +TuningCache& getTuningCache() { + static TuningCache cache("bf16bf16bf16_grouped_wgrad"); + return cache; +} +} // namespace + +Kernel_bf16bf16bf16_grouped_wgrad +get_wgrad_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) { + // Use heuristics to pick best kernel implementation.
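+  // Kernel names encode the template instantiation as
+  // <tileM>_<tileN>_<tileK>_<clusterM>_<clusterN>_<clusterK>_<arch>_<t|f>:
+  // arch 9 variants dispatch to bf16bf16bf16_grouped_wgrad_impl (SM90/H100),
+  // arch 10 variants to bf16bf16bf16_grouped_wgrad_sm100_impl (SM100/GB200),
+  // and the trailing t/f selects the final boolean template parameter of the
+  // impl, while output_accum supplies the other boolean. See the per-kernel
+  // .cu files below.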
+ if (arch == 10) { + // Llama4 shapes + if ((N == 5120 && K == 1024) || (N == 2048 && K == 5120)) { + if (total_M <= 256) { + return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f; + } else if (total_M <= 512) { + return bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f; + } else if (total_M <= 1024) { + return bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f; + } else { + return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f; + } + } + + // Fallback to legacy heuristic. + if (total_M <= 64 || (total_M <= 256 && N <= 1024)) { + if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f; + } + } else if (total_M <= 512) { + if (N <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f; + } else if (N <= 8192) { + if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f; + } + } + } else if (total_M <= 1024) { + if (N <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f; + } else if (N <= 8192) { + if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f; + } + } + } else if (total_M <= 2048) { + if (N <= 1024) { + return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f; + } else if (N <= 8192) { + if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f; + } + } + } + return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f; + } else { // arch == 9 + // Llama4.x pretraining + if (total_M == 8192) { + if (N == 2560) { + if (K == 1280) { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t; + } else if (K == 5120) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } + } else if (N == 3072) { + if (K == 1536) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } else if (K == 6144) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N == 5120) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (N == 6144) { + if (K == 1536) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } else if (K == 6144) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } + } + + if (total_M == 16384) { + if (N == 2560 || N == 3072) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (N == 5120) { + if (K == 1280) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K == 5120) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } + } else if (N == 6144) { + if (K == 1536) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } else if (K == 6144) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + } + } + } + + if (total_M == 65536) { + if (N <= 512) { + if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + } + } else if (N <= 768) { + if (K <= 384) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + } else if (K <= 768) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 1024) { + if (K <= 1024) { + return
bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 1280) { + if (K <= 640) { + return bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f; + } else if (K <= 1280) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; + } + } else if (N <= 1536) { + if (K <= 768) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f; + } else if (K <= 1536) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 1792) { + if (K <= 1792) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + } + } else if (N <= 2048) { + if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 2560) { + if (K <= 1280) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; + } else if (K <= 2560) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 3072) { + if (K <= 1536) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 3072) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; + } + } + } + + // Fallback to legacy heuristic + if (total_M <= 256) { + if (N <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (N <= 1024) { + return bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f; + } else if (N <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f; + } + } else if (total_M <= 1024) { + if (N <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f; + } else if (N <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (N <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (total_M <= 4096) { + if (N <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (total_M <= 16384) { + if (N <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f; + } else if (N <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (N <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + } + } else { + if (N <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + } else if (N <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (N <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + } + } + } +} + +Kernel_bf16bf16bf16_grouped_wgrad get_kernel_via_tuning( + int arch, + int G, + int total_M, + int N, + int K, + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + auto& cache = getTuningCache(); + + // Reduce the amount of autotuning by rounding total_M up to the next power of 2. + total_M = nextPowerOf2(total_M); + // Use (total_M, N, K, G) shape as the key. + const std::string shape_key = std::to_string(total_M) + "_" + + std::to_string(N) + "_" + std::to_string(K) + "_" + std::to_string(G); + const auto& kernels = get_bf16bf16bf16_grouped_wgrad_kernels(arch); + auto kernel = cache.findBestKernelMaybeAutotune( + shape_key, kernels, X, W, M_sizes, output, output_accum); + + return kernel; +} + +// BF16 grouped CUTLASS kernel dispatch.
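+// When the FBGEMM_AUTOTUNE_ENABLE environment variable is set, dispatch
+// benchmarks every registered wgrad kernel for the shape key above and
+// caches the winner; otherwise it falls back to the static heuristic.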
+at::Tensor dispatch_bf16_grouped_kernel( + int G, + int total_M, + int N, + int K, + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + const int arch = getDeviceArch(); + + // Select kernel to run via heuristics or tuning. + auto kernel = [&]() { + if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) { + return get_kernel_via_tuning( + arch, G, total_M, N, K, X, W, M_sizes, output, output_accum); + } else { + return get_wgrad_kernel_via_heuristic(arch, G, total_M, N, K); + } + }(); + // Invoke kernel + return kernel(X, W, M_sizes, output, output_accum); +} + +at::Tensor bf16bf16bf16_grouped_wgrad( + at::Tensor X, + at::Tensor W, + at::Tensor M_sizes, + std::optional<at::Tensor> output, + bool output_accum) { + int64_t total_M = X.size(0); + int64_t N = X.size(1); + int64_t K = W.size(1); + int64_t G = M_sizes.size(0); + TORCH_CHECK( + M_sizes.device() == X.device(), + "M_sizes must be on same device as inputs."); + TORCH_CHECK( + X.dim() == 2 && W.dim() == 2 && W.size(0) == total_M, + "Activations should be shape [GM, N] and weights should be shape [GM, K]"); + + if (output_accum) { + TORCH_CHECK( + output.has_value(), "Must provide output tensor for output_accum=True"); + } + + at::Tensor Y; + if (output.has_value()) { + Y = output.value(); + TORCH_CHECK(Y.dtype() == at::kBFloat16); + } else { + Y = at::empty(G * N * K, X.options().dtype(at::kBFloat16)); + } + + // Early exit for empty inputs. + if (total_M == 0) { + return Y.view({G, N, K}); + } + // Return contiguous view of output. + at::Tensor out = dispatch_bf16_grouped_kernel( + G, total_M, N, K, X, W, M_sizes, Y, output_accum); + return out.view({G, N, K}); +} + +#else + +at::Tensor bf16bf16bf16_grouped_wgrad( + at::Tensor, + at::Tensor, + at::Tensor, + std::optional<at::Tensor>, + bool) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} + +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f.cu new file mode 100644 index 0000000000..b841bf2f93 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree.
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 1, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f.cu new file mode 100644 index 0000000000..1d2103be34 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 2, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t.cu new file mode 100644 index 0000000000..a98ebb5bea --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 2, 1, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 2, 1, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t.cu new file mode 100644 index 0000000000..82cd72f5b4 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 4, 1, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 4, 1, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f.cu new file mode 100644 index 0000000000..2548259526 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t.cu new file mode 100644 index 0000000000..b567c8598d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 1, 1, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 1, 1, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t.cu new file mode 100644 index 0000000000..97e38d4806 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 2, 1, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 2, 1, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t.cu new file mode 100644 index 0000000000..18119354dd --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 4, 1, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 4, 1, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t.cu new file mode 100644 index 0000000000..a813cda85c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 1, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 1, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f.cu new file mode 100644 index 0000000000..6360182d90 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 2, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 4, + 2, + 2, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t.cu new file mode 100644 index 0000000000..0dcc2876ec --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 2, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 2, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f.cu new file mode 100644 index 0000000000..e88019f79f --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 4, + 4, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t.cu new file mode 100644 index 0000000000..881b71a103 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 1, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 1, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f.cu new file mode 100644 index 0000000000..322da70276 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 2, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 4, + 4, + 2, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t.cu new file mode 100644 index 0000000000..24a645af43 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 2, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 2, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f.cu new file mode 100644 index 0000000000..0334a4c0c2 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 4, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 4, + 4, + 4, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t.cu new file mode 100644 index 0000000000..39e24fa1de --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 4, true, true>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 4, false, true>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f.cu new file mode 100644 index 0000000000..b023f51d9a --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 256, 128, 1, 1, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 256, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f.cu new file mode 100644 index 0000000000..c1f5bb81cd --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 256, 128, 1, 2, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 256, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu new file mode 100644 index 0000000000..b8df243c35 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 1, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 1, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu new file mode 100644 index 0000000000..40c04995b3 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 4, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 4, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f.cu new file mode 100644 index 0000000000..67f97db4c5 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu new file mode 100644 index 0000000000..8c21b4f611 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 2, 2, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 2, 2, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu new file mode 100644 index 0000000000..d3a9fb8e22 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 2, 4, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 2, 4, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu new file mode 100644 index 0000000000..dc93f62074 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 2, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 2, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu new file mode 100644 index 0000000000..bfaf4f6f96 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 4, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 4, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f.cu new file mode 100644 index 0000000000..af414928a5 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu new file mode 100644 index 0000000000..266c300f4d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 2, 2, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 2, 2, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu new file mode 100644 index 0000000000..acf12a2ff4 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 4, 2, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 4, 2, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f.cu new file mode 100644 index 0000000000..bda558d4be --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<256, 128, 128, 1, 2, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f.cu new file mode 100644 index 0000000000..5cb7c5401d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f.cu new file mode 100644 index 0000000000..9fff6b24e8 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu new file mode 100644 index 0000000000..aacf411946 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<256, 32, 128, 1, 1, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<256, 32, 128, 1, 1, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f.cu new file mode 100644 index 0000000000..453ddeaae9 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f.cu new file mode 100644 index 0000000000..743f359f95 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_sm100_impl<
+        256,
+        64,
+        128,
+        2,
+        1,
+        1,
+        true,
+        false>(X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_sm100_impl<
+        256,
+        64,
+        128,
+        2,
+        1,
+        1,
+        false,
+        false>(X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh
new file mode 100644
index 0000000000..2a18078013
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh
@@ -0,0 +1,646 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cutlass/util/packed_stride.hpp>
+
+// clang-format off
+// The fixed ordering of the headers is required for CUTLASS 3.2+
+#include <cute/tensor.hpp>
+#include <cutlass/gemm/collective/collective_builder.hpp> // @manual
+#include <cutlass/gemm/device/gemm_universal_adapter.h> // @manual
+#include <cutlass/epilogue/collective/collective_builder.hpp> // @manual
+// clang-format on
+
+namespace fbgemm_gpu {
+
+inline int64_t _byte_align(int64_t offset) {
+  int64_t remainder = offset % 16;
+  if (remainder != 0) {
+    offset += (16 - remainder);
+  }
+  return offset;
+}
+
+template <
+    typename ProblemShape,
+    typename ElementA,
+    typename ElementB,
+    typename ElementC,
+    typename StrideA,
+    typename StrideB,
+    typename StrideC>
+__global__ void set_stacked_kernel_args_kernel(
+    int64_t G,
+    int64_t N,
+    int64_t K,
+    ProblemShape* problem_shape_ptr,
+    ElementA* x,
+    const ElementA** x_ptr,
+    ElementB* w,
+    const ElementB** w_ptr,
+    ElementC* output,
+    ElementC** output_ptr,
+    StrideA* stride_a_ptr,
+    StrideB* stride_b_ptr,
+    StrideC* stride_c_ptr,
+    int64_t* M_sizes) {
+  uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x;
+  // If this thread corresponds to a valid group, write kernel args to device
+  // memory.
+  if (group_index < G) {
+    // It's possible that we're only writing a subset of the groups to
+    // kernel args. To do this, we need to set all groups initially to empty
+    // and keep a problem counter for the number of non-empty groups.
+    __shared__ int non_zero_counter;
+    // Initialize counter in first group.
+    if (group_index == 0) {
+      non_zero_counter = 0;
+    }
+    // Set problem shapes to empty by default.
+    problem_shape_ptr[group_index] = ProblemShape(0, 0, 0);
+    // Sync threads to get consistent state in the block.
+    __syncthreads();
+
+    // Compute shape for this group.
+    // M for this group is pulled directly from M_sizes.
+    int M = M_sizes[group_index];
+    // Only proceed to writing kernel args if this group is non-empty.
+    if (M > 0) {
+      // Get the index for this group atomically.
+      int non_zero_idx = atomicAdd(&non_zero_counter, 1);
+      // We compute the offset by getting the cumulative sum over
+      // prior groups.
+      int64_t offset_M = 0;
+      for (int i = 0; i < group_index; i++) {
+        offset_M += M_sizes[i];
+      }
+      // Set the problem shape for this group.
+      problem_shape_ptr[non_zero_idx] = ProblemShape(int(N), int(K), int(M));
+      // Set input pointers.
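+      // For wgrad, dy is stacked as [total_M, N] and x as [total_M, K], so
+      // both input pointers advance by whole rows of the stacked tensors
+      // (offset_M * N and offset_M * K), while each group's [N, K] output
+      // tile sits at a fixed N * K * group_index offset.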
+      x_ptr[non_zero_idx] = x + (offset_M * N);
+      w_ptr[non_zero_idx] = w + (offset_M * K);
+      output_ptr[non_zero_idx] = output + (N * K * group_index);
+      stride_a_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideA{}, cute::make_shape(int(N), int(M), 1));
+      stride_b_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideB{}, cute::make_shape(int(K), int(M), 1));
+      stride_c_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideC{}, cute::make_shape(int(N), int(K), 1));
+    }
+  }
+}
+
+template <
+    int TB_M,
+    int TB_N,
+    int TB_K,
+    int TBS_M,
+    int TBS_N,
+    int TBS_K,
+    bool OUTPUT_ACCUM,
+    bool PONG>
+at::Tensor bf16bf16bf16_grouped_wgrad_impl(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    at::Tensor output) {
+  int64_t G;
+  at::TensorOptions options;
+  G = M_sizes.size(0);
+  options = X.options();
+
+  // Return early if there are no elements in the output.
+  if (output.numel() == 0) {
+    return output;
+  }
+
+  // Define gemm configuration.
+  using ProblemShape =
+      cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
+  using ElementA = cutlass::bfloat16_t;
+  using ElementB = cutlass::bfloat16_t;
+  using ElementC = cutlass::bfloat16_t;
+  using LayoutA = cutlass::layout::ColumnMajor;
+  using LayoutB = cutlass::layout::RowMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using ElementAccumulator = float;
+  using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that
+                                       // supports the intended feature
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using StageCountType = cutlass::gemm::collective::StageCountAuto;
+  using TileShape =
+      cute::Shape<cute::Int<TB_M>, cute::Int<TB_N>, cute::Int<TB_K>>;
+  using ClusterShape =
+      cute::Shape<cute::Int<TBS_M>, cute::Int<TBS_N>, cute::Int<TBS_K>>;
+
+  using CooperativeSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
+  using PongSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
+
+  using CooperativeEpilogueSchedule =
+      cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
+  using PongEpilogueSchedule =
+      cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
+
+  using MainLoopSchedule =
+      cute::conditional_t<PONG, PongSchedule, CooperativeSchedule>;
+  using EpilogueSchedule = cute::
+      conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
+
+  using ComputeC = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::plus,
+      ElementC,
+      ElementAccumulator,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+  using C_src = cutlass::epilogue::fusion::Sm90SrcFetch<ElementC>;
+  using EVTAddC = cutlass::epilogue::fusion::Sm90EVT<ComputeC, Accum, C_src>;
+
+  using CollectiveEpilogueDefault =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          ArchTag,
+          cutlass::arch::OpClassTensorOp,
+          TileShape,
+          ClusterShape,
+          cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAccumulator,
+          ElementAccumulator,
+          ElementC, // Use source tensor for epilogue operations.
+          LayoutC*,
+          128 / cutlass::sizeof_bits<ElementC>::value,
+          ElementC,
+          LayoutC*,
+          128 / cutlass::sizeof_bits<ElementC>::value,
+          EpilogueSchedule>::CollectiveOp;
+
+  using CollectiveEpilogueAccum =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          ArchTag,
+          cutlass::arch::OpClassTensorOp,
+          TileShape,
+          ClusterShape,
+          cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAccumulator,
+          ElementAccumulator,
+          ElementC, // Use source tensor for epilogue operations.
+          LayoutC*,
+          128 / cutlass::sizeof_bits<ElementC>::value,
+          ElementC,
+          LayoutC*,
+          128 / cutlass::sizeof_bits<ElementC>::value,
+          EpilogueSchedule,
+          EVTAddC>::CollectiveOp;
+
+  using CollectiveEpilogue = cute::conditional_t<
+      OUTPUT_ACCUM,
+      CollectiveEpilogueAccum,
+      CollectiveEpilogueDefault>;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          ElementA,
+          LayoutA*,
+          128 / cutlass::sizeof_bits<ElementA>::value,
+          ElementB,
+          LayoutB*,
+          128 / cutlass::sizeof_bits<ElementB>::value,
+          ElementAccumulator,
+          TileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          MainLoopSchedule>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::
+      GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+  using StrideC = typename Gemm::GemmKernel::InternalStrideD;
+
+  // Create a buffer for kernel arguments. We do this by first figuring out
+  // how much space each sub-argument requires and setting up corresponding
+  // pointers.
+  const int64_t problem_size_offset = 0;
+  int64_t problem_size_buffer =
+      _byte_align(G * sizeof(ProblemShape::UnderlyingProblemShape));
+
+  // Next create space for X pointers.
+  const int64_t x_offset = problem_size_offset + problem_size_buffer;
+  int64_t x_size_buffer = _byte_align(G * sizeof(ElementA**));
+
+  // W pointers.
+  const int64_t w_offset = x_offset + x_size_buffer;
+  int64_t w_size_buffer = _byte_align(G * sizeof(ElementB**));
+
+  // Outputs.
+  const int64_t output_offset = w_offset + w_size_buffer;
+  int64_t output_buffer = _byte_align(G * sizeof(ElementC**));
+
+  // A stride.
+  const int64_t stride_a_offset = output_offset + output_buffer;
+  int64_t stride_a_buffer = _byte_align(G * sizeof(StrideA));
+
+  // B stride.
+  const int64_t stride_b_offset = stride_a_offset + stride_a_buffer;
+  int64_t stride_b_buffer = _byte_align(G * sizeof(StrideB));
+
+  // C stride.
+  const int64_t stride_c_offset = stride_b_offset + stride_b_buffer;
+  int64_t stride_c_buffer = _byte_align(G * sizeof(StrideC));
+
+  // Compute total buffer size.
+  int64_t total_buffer_size = stride_c_offset + stride_c_buffer;
+
+  // Allocate space for gemm information.
+  at::Tensor kernel_args =
+      at::empty({total_buffer_size}, options.dtype(at::kByte));
+
+  // Get byte pointer to underlying data.
+  char* kernel_args_ptr = reinterpret_cast<char*>(kernel_args.data_ptr());
+
+  // Now use offsets to get appropriately typed pointers.
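+  // Each cast below is just a typed view into the single kernel_args byte
+  // tensor at the 16-byte-aligned offsets computed above; nothing is copied
+  // here, and the device-side setup kernel fills these arrays in place.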
+  ProblemShape::UnderlyingProblemShape* problem_shape_ptr =
+      reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
+          kernel_args_ptr + problem_size_offset);
+  const ElementA** x_ptr =
+      reinterpret_cast<const ElementA**>(kernel_args_ptr + x_offset);
+  const ElementB** w_ptr =
+      reinterpret_cast<const ElementB**>(kernel_args_ptr + w_offset);
+  ElementC** output_ptr =
+      reinterpret_cast<ElementC**>(kernel_args_ptr + output_offset);
+  StrideA* stride_a_ptr =
+      reinterpret_cast<StrideA*>(kernel_args_ptr + stride_a_offset);
+  StrideB* stride_b_ptr =
+      reinterpret_cast<StrideB*>(kernel_args_ptr + stride_b_offset);
+  StrideC* stride_c_ptr =
+      reinterpret_cast<StrideC*>(kernel_args_ptr + stride_c_offset);
+
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+  TORCH_CHECK(M_sizes.dtype() == at::kLong, "M_sizes must be int64.");
+  int64_t total_M = X.size(0);
+  int64_t N = X.size(1);
+  int64_t K = W.size(1);
+
+  int64_t* M_sizes_ptr = reinterpret_cast<int64_t*>(M_sizes.data_ptr());
+  set_stacked_kernel_args_kernel<<<1, G, 0, stream>>>(
+      G,
+      N,
+      K,
+      problem_shape_ptr,
+      reinterpret_cast<ElementA*>(X.data_ptr()),
+      x_ptr,
+      reinterpret_cast<ElementB*>(W.data_ptr()),
+      w_ptr,
+      reinterpret_cast<ElementC*>(output.data_ptr()),
+      output_ptr,
+      stride_a_ptr,
+      stride_b_ptr,
+      stride_c_ptr,
+      M_sizes_ptr);
+  int kernel_groups = int(std::min(total_M, G));
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGrouped,
+      {kernel_groups, problem_shape_ptr, nullptr},
+      {x_ptr, stride_a_ptr, w_ptr, stride_b_ptr},
+      {{},
+       (const ElementC**)output_ptr,
+       stride_c_ptr,
+       output_ptr,
+       stride_c_ptr}};
+
+  int sm_count = at::cuda::getDeviceProperties(output.device().index())
+                     ->multiProcessorCount;
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    sm_count -= at::globalContext()._SMCarveout_EXPERIMENTAL().value();
+  }
+  arguments.hw_info.sm_count = sm_count;
+
+  Gemm gemm;
+
+  // Using the arguments, query for extra workspace required for matrix
+  // multiplication computation.
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+  // Allocate workspace memory.
+  at::Tensor workspace = at::empty(workspace_size, options.dtype(at::kByte));
+
+  // Check whether the problem size is supported.
+  cutlass::Status status = gemm.can_implement(arguments);
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot implement");
+  }
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer.
+  status = gemm.initialize(
+      arguments, reinterpret_cast<uint8_t*>(workspace.data_ptr()));
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot initialize");
+  }
+
+  status = gemm(at::cuda::getCurrentCUDAStream());
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error(
+        std::string("cutlass cannot run") +
+        cutlass::cutlassGetStatusString(status));
+  }
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return output;
+}
+
+#if CUDART_VERSION >= 12080
+template <
+    int TB_M,
+    int TB_N,
+    int TB_K,
+    int TBS_M,
+    int TBS_N,
+    int TBS_K,
+    bool OUTPUT_ACCUM,
+    bool PONG>
+at::Tensor bf16bf16bf16_grouped_wgrad_sm100_impl(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    at::Tensor output) {
+  int64_t G;
+  at::TensorOptions options;
+  G = M_sizes.size(0);
+  options = X.options();
+
+  // Return early if there are no elements in the output.
+  if (output.numel() == 0) {
+    return output;
+  }
+
+  // Define gemm configuration.
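+  // Relative to the SM90 path above, this SM100 (Blackwell) configuration
+  // differs mainly in schedule selection: clusters with an even TBS_M or a
+  // 256-wide tile M pick the 2-SM ptr-array schedules, all others the 1-SM
+  // variants.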
+  using ProblemShape =
+      cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
+  using ElementA = cutlass::bfloat16_t;
+  using ElementB = cutlass::bfloat16_t;
+  using ElementC = cutlass::bfloat16_t;
+  using LayoutA = cutlass::layout::ColumnMajor;
+  using LayoutB = cutlass::layout::RowMajor;
+  using LayoutC = cutlass::layout::RowMajor;
+  using ElementAccumulator = float;
+  using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that
+                                        // supports the intended feature
+  using OperatorClass = cutlass::arch::OpClassTensorOp;
+  using StageCountType = cutlass::gemm::collective::StageCountAuto;
+  using TileShape =
+      cute::Shape<cute::Int<TB_M>, cute::Int<TB_N>, cute::Int<TB_K>>;
+  using ClusterShape =
+      cute::Shape<cute::Int<TBS_M>, cute::Int<TBS_N>, cute::Int<TBS_K>>;
+
+  using MainLoopSchedule = cute::conditional_t<
+      (TBS_M % 2 == 0) || (TB_M == 256),
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100,
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100>;
+  using EpilogueSchedule = cute::conditional_t<
+      (TBS_M % 2 == 0) || (TB_M == 256),
+      cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm,
+      cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm>;
+
+  using ComputeC = cutlass::epilogue::fusion::Sm90Compute<
+      cutlass::plus,
+      ElementC,
+      ElementAccumulator,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+  using C_src = cutlass::epilogue::fusion::Sm90SrcFetch<ElementC>;
+  using EVTAddC = cutlass::epilogue::fusion::Sm90EVT<ComputeC, Accum, C_src>;
+
+  using CollectiveEpilogueDefault =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          ArchTag,
+          cutlass::arch::OpClassTensorOp,
+          TileShape,
+          ClusterShape,
+          cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAccumulator,
+          ElementAccumulator,
+          ElementC, // Use source tensor for epilogue operations.
+          LayoutC*,
+          128 / cutlass::sizeof_bits<ElementC>::value,
+          ElementC,
+          LayoutC*,
+          128 / cutlass::sizeof_bits<ElementC>::value,
+          EpilogueSchedule>::CollectiveOp;
+
+  using CollectiveEpilogueAccum =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          ArchTag,
+          cutlass::arch::OpClassTensorOp,
+          TileShape,
+          ClusterShape,
+          cutlass::epilogue::collective::EpilogueTileAuto,
+          ElementAccumulator,
+          ElementAccumulator,
+          ElementC, // Use source tensor for epilogue operations.
+          LayoutC*,
+          128 / cutlass::sizeof_bits<ElementC>::value,
+          ElementC,
+          LayoutC*,
+          128 / cutlass::sizeof_bits<ElementC>::value,
+          EpilogueSchedule,
+          EVTAddC>::CollectiveOp;
+
+  using CollectiveEpilogue = cute::conditional_t<
+      OUTPUT_ACCUM,
+      CollectiveEpilogueAccum,
+      CollectiveEpilogueDefault>;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          ElementA,
+          LayoutA*,
+          128 / cutlass::sizeof_bits<ElementA>::value,
+          ElementB,
+          LayoutB*,
+          128 / cutlass::sizeof_bits<ElementB>::value,
+          ElementAccumulator,
+          TileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          MainLoopSchedule>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::
+      GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+  using StrideC = typename Gemm::GemmKernel::InternalStrideD;
+
+  // Create a buffer for kernel arguments. We do this by first figuring out
+  // how much space each sub-argument requires and setting up corresponding
+  // pointers.
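+  // The kernel-argument staging below is identical to the SM90 path; only
+  // the CUTLASS kernel and adapter types assembled above differ.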
+  const int64_t problem_size_offset = 0;
+  int64_t problem_size_buffer =
+      _byte_align(G * sizeof(ProblemShape::UnderlyingProblemShape));
+
+  // Next create space for X pointers.
+  const int64_t x_offset = problem_size_offset + problem_size_buffer;
+  int64_t x_size_buffer = _byte_align(G * sizeof(ElementA**));
+
+  // W pointers.
+  const int64_t w_offset = x_offset + x_size_buffer;
+  int64_t w_size_buffer = _byte_align(G * sizeof(ElementB**));
+
+  // Outputs.
+  const int64_t output_offset = w_offset + w_size_buffer;
+  int64_t output_buffer = _byte_align(G * sizeof(ElementC**));
+
+  // A stride.
+  const int64_t stride_a_offset = output_offset + output_buffer;
+  int64_t stride_a_buffer = _byte_align(G * sizeof(StrideA));
+
+  // B stride.
+  const int64_t stride_b_offset = stride_a_offset + stride_a_buffer;
+  int64_t stride_b_buffer = _byte_align(G * sizeof(StrideB));
+
+  // C stride.
+  const int64_t stride_c_offset = stride_b_offset + stride_b_buffer;
+  int64_t stride_c_buffer = _byte_align(G * sizeof(StrideC));
+
+  // Compute total buffer size.
+  int64_t total_buffer_size = stride_c_offset + stride_c_buffer;
+
+  // Allocate space for gemm information.
+  at::Tensor kernel_args =
+      at::empty({total_buffer_size}, options.dtype(at::kByte));
+
+  // Get byte pointer to underlying data.
+  char* kernel_args_ptr = reinterpret_cast<char*>(kernel_args.data_ptr());
+
+  // Now use offsets to get appropriately typed pointers.
+  ProblemShape::UnderlyingProblemShape* problem_shape_ptr =
+      reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
+          kernel_args_ptr + problem_size_offset);
+  const ElementA** x_ptr =
+      reinterpret_cast<const ElementA**>(kernel_args_ptr + x_offset);
+  const ElementB** w_ptr =
+      reinterpret_cast<const ElementB**>(kernel_args_ptr + w_offset);
+  ElementC** output_ptr =
+      reinterpret_cast<ElementC**>(kernel_args_ptr + output_offset);
+  StrideA* stride_a_ptr =
+      reinterpret_cast<StrideA*>(kernel_args_ptr + stride_a_offset);
+  StrideB* stride_b_ptr =
+      reinterpret_cast<StrideB*>(kernel_args_ptr + stride_b_offset);
+  StrideC* stride_c_ptr =
+      reinterpret_cast<StrideC*>(kernel_args_ptr + stride_c_offset);
+
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+  TORCH_CHECK(M_sizes.dtype() == at::kLong, "M_sizes must be int64.");
+  int64_t total_M = X.size(0);
+  int64_t N = X.size(1);
+  int64_t K = W.size(1);
+
+  int64_t* M_sizes_ptr = reinterpret_cast<int64_t*>(M_sizes.data_ptr());
+  set_stacked_kernel_args_kernel<<<1, G, 0, stream>>>(
+      G,
+      N,
+      K,
+      problem_shape_ptr,
+      reinterpret_cast<ElementA*>(X.data_ptr()),
+      x_ptr,
+      reinterpret_cast<ElementB*>(W.data_ptr()),
+      w_ptr,
+      reinterpret_cast<ElementC*>(output.data_ptr()),
+      output_ptr,
+      stride_a_ptr,
+      stride_b_ptr,
+      stride_c_ptr,
+      M_sizes_ptr);
+  int kernel_groups = int(std::min(total_M, G));
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGrouped,
+      {kernel_groups, problem_shape_ptr, nullptr},
+      {x_ptr, stride_a_ptr, w_ptr, stride_b_ptr},
+      {{},
+       (const ElementC**)output_ptr,
+       stride_c_ptr,
+       output_ptr,
+       stride_c_ptr}};
+
+  int sm_count = at::cuda::getDeviceProperties(output.device().index())
+                     ->multiProcessorCount;
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    sm_count -= at::globalContext()._SMCarveout_EXPERIMENTAL().value();
+  }
+  arguments.hw_info.sm_count = sm_count;
+
+  Gemm gemm;
+
+  // Using the arguments, query for extra workspace required for matrix
+  // multiplication computation.
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+
+  // Allocate workspace memory.
+  at::Tensor workspace = at::empty(workspace_size, options.dtype(at::kByte));
+
+  // Check whether the problem size is supported.
+  cutlass::Status status = gemm.can_implement(arguments);
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot implement");
+  }
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer.
+  status = gemm.initialize(
+      arguments, reinterpret_cast<uint8_t*>(workspace.data_ptr()));
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error("cutlass cannot initialize");
+  }
+
+  status = gemm(at::cuda::getCurrentCUDAStream());
+  if (status != cutlass::Status::kSuccess) {
+    throw std::runtime_error(
+        std::string("cutlass cannot run") +
+        cutlass::cutlassGetStatusString(status));
+  }
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+  return output;
+}
+
+#else
+template <
+    int TB_M,
+    int TB_N,
+    int TB_K,
+    int TBS_M,
+    int TBS_N,
+    int TBS_K,
+    bool OUTPUT_ACCUM,
+    bool PONG>
+at::Tensor bf16bf16bf16_grouped_wgrad_sm100_impl(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    at::Tensor output) {
+  return output;
+}
+#endif
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh
new file mode 100644
index 0000000000..15b2786a7d
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh
@@ -0,0 +1,260 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum);
+
+using Kernel_bf16bf16bf16_grouped_wgrad =
+    at::Tensor (*)(at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool);
+
+const std::unordered_map<std::string, Kernel_bf16bf16bf16_grouped_wgrad>&
+get_bf16bf16bf16_grouped_wgrad_kernels(int arch) {
+  static const std::
+      unordered_map<std::string, Kernel_bf16bf16bf16_grouped_wgrad>
+          kernelsSM90 = {
+              {"bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t",
+               bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t},
+              {"bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t",
+               bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t},
+              {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t",
+               bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t},
+              {"bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t",
+               bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t},
+              {"bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t",
+               bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t},
+              {"bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t",
+               bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t},
+              {"bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f",
+               bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f},
+              {"bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f",
+               bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f},
+          };
+  static const std::
+      unordered_map<std::string, Kernel_bf16bf16bf16_grouped_wgrad>
+          kernelsSM100 = {
+              {"bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f",
+               bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f},
+              {"bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f",
+               bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f},
+              {"bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f",
+               bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f},
+              {"bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f",
+               bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f},
+              {"bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f",
+               bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f},
+              {"bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f",
+               bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f},
+              {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f",
+               bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f},
+          };
+  if (arch == 10) {
+    return kernelsSM100;
+  } else {
+    return kernelsSM90;
+  }
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index 2de3346c3d..7a148e87b0 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -80,6 +80,12 @@ at::Tensor
 bf16bf16bf16_grouped_stacked(at::Tensor X, at::Tensor W, at::Tensor M_sizes);
 at::Tensor
 bf16bf16bf16_grouped_grad(at::Tensor X, at::Tensor W, at::Tensor M_sizes);
+at::Tensor bf16bf16bf16_grouped_wgrad(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> output = std::nullopt,
+    bool output_accum = false);
 at::Tensor f8f8bf16_rowwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -320,6 +326,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("f8i4bf16_shuffled_grouped", f8i4bf16_shuffled_grouped);
   m.impl("bf16i4bf16_shuffled_grouped", bf16i4bf16_shuffled_grouped);
   m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad);
+  m.impl("bf16bf16bf16_grouped_wgrad", bf16bf16bf16_grouped_wgrad);
   m.impl("preshuffle_i4", preshuffle_i4);
   m.impl("bf16i4bf16_shuffled_batched", bf16i4bf16_shuffled_batched);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
@@ -804,6 +811,19 @@ at::Tensor bf16bf16bf16_grouped_grad_meta(
   return Y;
 }
 
+at::Tensor bf16bf16bf16_grouped_wgrad_meta(
+    at::Tensor X,
+    at::Tensor W,
+    at::Tensor M_sizes,
+    std::optional<at::Tensor> /* output = std::nullopt */,
+    bool /* output_accum = false */) {
+  const at::SymInt G = M_sizes.size(0);
+  const at::SymInt N = X.sym_size(1);
+  const at::SymInt K = W.sym_size(1);
+  at::Tensor Y = at::empty_symint({G, N, K}, X.options().dtype(at::kBFloat16));
+  return Y;
+}
+
 at::Tensor f8f8bf16_rowwise_grouped_stacked_meta(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -845,6 +865,7 @@
   m.impl("bf16i4bf16_shuffled_batched", bf16i4bf16_shuffled_batched_meta);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
   m.impl("bf16bf16bf16_grouped_grad", bf16bf16bf16_grouped_grad_meta);
+  m.impl("bf16bf16bf16_grouped_wgrad", bf16bf16bf16_grouped_wgrad_meta);
   m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
   m.impl("scaled_fp4_quant", scaled_fp4_quant_meta);
   m.impl("preshuffle_i4", preshuffle_i4_meta);
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp
index 83b2106032..9f120073ad 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize_defs.cpp
@@ -66,6 +66,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "bf16bf16bf16_grouped_stacked(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
   m.def(
       "bf16bf16bf16_grouped_grad(Tensor X, Tensor W, Tensor M_sizes) -> Tensor");
+  m.def(
+      "bf16bf16bf16_grouped_wgrad(Tensor X, Tensor W, Tensor M_sizes, Tensor(a!)? output=None, bool output_accum=False) -> Tensor");
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
   m.def(
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 469ccd0e2b..fb92b2f3ae 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -2198,5 +2198,116 @@ def test_fake_quantize_nvfp4_per_tensor(
         torch.testing.assert_close(fake_quant_y, y_ref, atol=0.1, rtol=0.1)
 
 
+@unittest.skipIf(
+    not torch.cuda.is_available()
+    or torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
+    "Skip when MI300 or H100 is not available",
+)
+class BF16Tests(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.device = torch.accelerator.current_accelerator()
+
+    @unittest.skipIf(
+        not torch.version.cuda,
+        "Skip on AMD: test_bf16_grouped_gemmw_wgrad not yet supported.",
+    )
+    @settings(deadline=None)
+    @given(
+        G=st.sampled_from([2, 16]),
+        M=st.sampled_from([257, 2049]),
+        N=st.sampled_from([256, 2048]),
+        K=st.sampled_from([128, 1024]),
+        output_accum=st.booleans(),
+    )
+    def test_bf16_grouped_gemmw_wgrad(
+        self,
+        G: int,
+        M: int,
+        N: int,
+        K: int,
+        output_accum: bool,
+    ) -> None:
+        torch.manual_seed(hash((G, M, N, K)))
+        # Inputs
+        dy_bf16 = torch.randn(
+            (M, N),
+            dtype=torch.bfloat16,
+            device=torch.accelerator.current_accelerator(),
+        )
+        x_bf16 = torch.randn(
+            (M, K),
+            dtype=torch.bfloat16,
+            device=torch.accelerator.current_accelerator(),
+        )
+
+        def generate_random_splits(G: int, M: int) -> torch.Tensor:
+            m_cumsums = torch.sort(
+                torch.randint(
+                    0,
+                    M,
+                    (G + 1,),
+                    dtype=torch.int32,
+                    device=torch.accelerator.current_accelerator(),
+                )
+            ).values
+            m_cumsums[0], m_cumsums[-1] = 0, M
+            m_sizes = m_cumsums[1:] - m_cumsums[:-1]
+            return m_sizes
+
+        m_sizes = generate_random_splits(G, M)
+
+        # Test
+        if output_accum:
+            wgrad_accum = torch.randn(
+                (G, N, K),
+                dtype=torch.bfloat16,
+                device=torch.accelerator.current_accelerator(),
+            )
+        else:
+            wgrad_accum = None
+
+        test_wgrad = torch.ops.fbgemm.bf16bf16bf16_grouped_wgrad(
+            dy_bf16,
+            x_bf16,
+            m_sizes.to(torch.int64),
+            output=wgrad_accum.clone() if output_accum else None,
+            output_accum=output_accum,
+        )
+
+        # Reference
+        dy_fp32 = dy_bf16.to(torch.float32)
+        x_fp32 = x_bf16.to(torch.float32)
+        ref_wgrad = torch.empty(
+            (G, N, K),
+            dtype=torch.float32,
+            device=torch.accelerator.current_accelerator(),
+        )
+
+        # Track which groups have non-zero size for comparison
+        non_zero_groups = []
+        m_start = 0
+        for g, m_size in enumerate(m_sizes.tolist()):
+            if m_size > 0:
+                # Actual slice - compute matrix multiplication
+                ref_wgrad[g, :, :] = (
+                    dy_fp32[m_start : m_start + m_size, :].T
+                    @ x_fp32[m_start : m_start + m_size, :]
+                )
+                non_zero_groups.append(g)
+            m_start += m_size
+
+        if output_accum:
+            assert wgrad_accum is not None
+            ref_wgrad += wgrad_accum.to(torch.float32)
+
+        ref_wgrad = ref_wgrad.to(torch.bfloat16)
+
+        # Compare groups with non-zero m_size
+        if non_zero_groups:
+            torch.testing.assert_close(
+                test_wgrad[non_zero_groups],
+                ref_wgrad[non_zero_groups],
+                atol=1e-4,
+                rtol=1e-2,
+            )
+
+
 if __name__ == "__main__":
     unittest.main()

From cc0fd3cfd315127482073dc23a6d92a3d0c99a4f Mon Sep 17 00:00:00 2001
From: jiawenl
Date: Mon, 22 Sep 2025 17:27:09 -0700
Subject: [PATCH 2/4] Support multiple total-k and total-m in quantize bench

Differential Revision: D82700396
---
 .../experimental/gen_ai/bench/quantize_bench.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
index c58908d101..898653d756 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -490,13 +490,15 @@ def print_kernels(kernels: Optional[List[str]]) -> List[QuantizeOpBase]:
     "--total-K",
     default=None,
     help="If set, adjusts the K values to sum to this number. "
-    "This can help simulate real grouped workloads in backward wgrad.",
+    "This can help simulate real grouped workloads in backward wgrad. "
+    "Comma separated list of total-K values to benchmark.",
 )
 @click.option(
     "--total-M",
     default=None,
     help="If set, adjusts the M values to sum to this number. "
-    "This can help simulate real grouped workloads.",
+    "This can help simulate real grouped workloads. "
+ "Comma separated list of total-M values to benchmark.", ) @click.option( "--no-cuda-graph", @@ -634,25 +636,29 @@ def invoke_main( if groups: groups_list = [int(g) for g in groups.strip().split(",")] if total_m: + total_m_list = [int(tm) for tm in total_m.strip().split(",")] MNK = [ [ [b] * g, - generate_group_tensor(g, int(total_m)), + generate_group_tensor(g, tm), [n] * g, [k] * g, ] for g in groups_list + for tm in total_m_list for b, _, n, k in MNK ] elif total_k: + total_k_list = [int(tk) for tk in total_k.strip().split(",")] MNK = [ [ [b] * g, [m] * g, [n] * g, - generate_group_tensor(g, int(total_k)), + generate_group_tensor(g, tk), ] for g in groups_list + for tk in total_k_list for b, m, n, _ in MNK ] else: From 73629a61c3a602f38331fd6c4051d9fee44a37ca Mon Sep 17 00:00:00 2001 From: jiawenl Date: Mon, 22 Sep 2025 17:27:09 -0700 Subject: [PATCH 3/4] Optimize wgrad CUTLASS grouped gemm Differential Revision: D82700455 --- .../bf16bf16bf16_grouped_wgrad.cu | 814 +++++++++++++++++- ...16_grouped_wgrad_128_128_128_1_1_1_9_t.cu} | 15 +- ...16_grouped_wgrad_128_128_128_1_4_1_9_f.cu} | 8 +- ...16_grouped_wgrad_128_128_128_2_1_1_9_f.cu} | 8 +- ...f16_grouped_wgrad_128_128_128_4_1_1_9_t.cu | 28 + ...bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu | 28 + ...bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu | 28 + ...bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu | 28 + ...bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu | 28 + ...bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu | 28 + .../bf16bf16bf16_grouped_wgrad_common.cuh | 7 +- .../bf16bf16bf16_grouped_wgrad_manifest.cuh | 104 ++- .../gen_ai/test/quantize/quantize_test.py | 29 +- 13 files changed, 1078 insertions(+), 75 deletions(-) rename fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/{bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f.cu => bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t.cu} (62%) rename fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/{bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f.cu => bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f.cu} (82%) rename fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/{bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f.cu => bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f.cu} (83%) create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu index f6aa4adbba..cc314fdc61 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu @@ -132,7 +132,7 @@ get_wgrad_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) { if (total_M == 65536) { if (N <= 512) { if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f; + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; } else if (K <= 512) { return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; } @@ -148,13 +148,13 @@ get_wgrad_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) { } } else if (N <= 1280) { if (K <= 640) { - return bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f; + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; } else if (K <= 1280) { return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; } } else if (N <= 1536) { if (K <= 768) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f; + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; } else if (K <= 1536) { return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; } @@ -182,51 +182,799 @@ get_wgrad_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) { } // Fallback to legacy heuristic - if (total_M <= 256) { - if (N <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + if (total_M <= 128) { + if (N <= 128) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 256) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } + } else if (N <= 512) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } + } else if (N <= 1024) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } + } else if (N <= 2048) { + if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 8192) 
{ + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 4096) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 8192) { + if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else { + if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } + } + } else if (total_M <= 256) { + if (N <= 128) { + if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 256) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 512) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 1024) { + if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 2048) { + if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 4096) { + if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 8192) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { 
+ return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else { + if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } + } + } else if (total_M <= 512) { + if (N <= 128) { + if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 256) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } + } else if (N <= 512) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } } else if (N <= 1024) { - return bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } + } else if (N <= 2048) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } } else if (N <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 8192) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 
2048) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } } else { - return bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f; + if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } } } else if (total_M <= 1024) { - if (N <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f; + if (N <= 128) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 256) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 512) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } } else if (N <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } + } else if (N <= 2048) { + if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } } else if (N <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 512) { + return 
bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } else if (N <= 8192) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } + } + } else if (total_M <= 2048) { + if (N <= 128) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f; + } + } else if (N <= 256) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 512) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 1024) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + 
return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 2048) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 4096) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } + } else if (N <= 8192) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } + } else { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } } } else if (total_M <= 4096) { - if (N <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + if (N <= 128) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 256) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } + } else if (N <= 512) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 1024) { + if (K <= 128) { + return 
bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 2048) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 4096) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } + } else if (N <= 8192) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } } - } else if (total_M <= 16384) { - if (N <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f; + } else if (total_M <= 8192) { + if (N <= 128) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 256) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 512) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } } else if (N <= 1024) { - return 
bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 2048) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } } else if (N <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 8192) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; + } } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } } } else { - if (N <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + if (N <= 128) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 256) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 1024) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 512) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 512) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } } 
else if (N <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } + } else if (N <= 2048) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; + } } else if (N <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; + } + } else if (N <= 8192) { + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; + } } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + if (K <= 128) { + return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + } else if (K <= 256) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 2048) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } else if (K <= 4096) { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + } else if (K <= 8192) { + return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; + } else { + return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + } } } } @@ -308,7 +1056,15 @@ at::Tensor bf16bf16bf16_grouped_wgrad( at::Tensor Y; if (output.has_value()) { Y = output.value(); - TORCH_CHECK(Y.dtype() == at::kBFloat16); + if (output_accum) { + TORCH_CHECK( + Y.dtype() == at::kFloat, + "Output tensor must be Float32 when output_accum=True"); + } else { + TORCH_CHECK( + Y.dtype() == at::kBFloat16, + "Output tensor must be BFloat16 when output_accum=False"); + } } else { Y = at::empty(G * N * K, X.options().dtype(at::kBFloat16)); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t.cu similarity index 62% rename from fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f.cu rename to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t.cu index bda558d4be..2f267b651a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t.cu @@ -10,25 +10,18 @@ namespace fbgemm_gpu { -at::Tensor 
bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f( +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum) { if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<256, 128, 128, 1, 2, 1, true, false>( + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 1, 1, true, true>( X, W, M_sizes, output); } else { - return bf16bf16bf16_grouped_wgrad_impl< - 256, - 128, - 128, - 1, - 2, - 1, - false, - false>(X, W, M_sizes, output); + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 1, 1, false, true>( + X, W, M_sizes, output); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f.cu similarity index 82% rename from fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f.cu rename to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f.cu index b023f51d9a..871f991d9b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f.cu @@ -10,22 +10,22 @@ namespace fbgemm_gpu { -at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f( +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum) { if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 256, 128, 1, 1, 1, true, false>( + return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 4, 1, true, false>( X, W, M_sizes, output); } else { return bf16bf16bf16_grouped_wgrad_impl< 128, - 256, + 128, 128, 1, - 1, + 4, 1, false, false>(X, W, M_sizes, output); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f.cu similarity index 83% rename from fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f.cu rename to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f.cu index c1f5bb81cd..313f53caa1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f.cu +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f.cu @@ -10,23 +10,23 @@ namespace fbgemm_gpu { -at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f( +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum) { if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 256, 128, 1, 2, 1, true, false>( + return 
bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 1, 1, true, false>(
         X, W, M_sizes, output);
   } else {
     return bf16bf16bf16_grouped_wgrad_impl<
         128,
-        256,
         128,
-        1,
+        128,
         2,
         1,
+        1,
         false,
         false>(X, W, M_sizes, output);
   }
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t.cu
new file mode 100644
index 0000000000..a5e68f7289
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t.cu
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 1, 1, true, true>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 1, 1, false, true>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu
new file mode 100644
index 0000000000..7ec4ac7170
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 2, 1, true, false>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 2, 1, false, false>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
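Each wrapper file above pins exactly one configuration of the shared template, presumably so that each instantiation compiles in its own translation unit. The file and function suffix encodes the template arguments of bf16bf16bf16_grouped_wgrad_impl: the tile shape (M, N, K), then the cluster shape, then the target arch (9 for SM90, 10 for SM100, matching the arch checks in the heuristic), and a trailing t/f for the final boolean template parameter, while output_accum stays a runtime argument that selects the accumulating instantiation. A minimal sketch of that naming convention, as a hypothetical helper that is not part of the patch:

def decode_wgrad_kernel_variant(name: str) -> dict:
    # e.g. "bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t"
    fields = name.removeprefix("bf16bf16bf16_grouped_wgrad_").split("_")
    tile_m, tile_n, tile_k, cluster_m, cluster_n, cluster_k, arch, flag = fields
    return {
        "tile_shape": (int(tile_m), int(tile_n), int(tile_k)),
        "cluster_shape": (int(cluster_m), int(cluster_n), int(cluster_k)),
        "arch": "SM90" if arch == "9" else "SM100",
        # Final bool template argument; the diff does not name it here.
        "last_template_flag": flag == "t",
    }

assert decode_wgrad_kernel_variant(
    "bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t"
) == {
    "tile_shape": (128, 128, 128),
    "cluster_shape": (4, 1, 1),
    "arch": "SM90",
    "last_template_flag": True,
}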
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu
new file mode 100644
index 0000000000..ad54498384
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 1, 1, true, false>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 1, 1, false, false>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu
new file mode 100644
index 0000000000..2087b7a542
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "bf16bf16bf16_grouped_wgrad_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f(
+    at::Tensor X, // BF16
+    at::Tensor W, // BF16
+    at::Tensor M_sizes,
+    at::Tensor output,
+    bool output_accum) {
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 1, 1, true, false>(
+        X, W, M_sizes, output);
+  } else {
+    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 1, 1, false, false>(
+        X, W, M_sizes, output);
+  }
+}
+
+} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu
new file mode 100644
index 0000000000..95383b1797
--- /dev/null
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 2, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 2, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu new file mode 100644 index 0000000000..66d1e80e85 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "bf16bf16bf16_grouped_wgrad_common.cuh" + +namespace fbgemm_gpu { + +at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 4, 1, true, false>( + X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 4, 1, false, false>( + X, W, M_sizes, output); + } +} + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh index 2a18078013..45ee9e22b4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh @@ -127,7 +127,9 @@ at::Tensor bf16bf16bf16_grouped_wgrad_impl( cutlass::gemm::GroupProblemShape>; using ElementA = cutlass::bfloat16_t; using ElementB = cutlass::bfloat16_t; - using ElementC = cutlass::bfloat16_t; + using ElementC = + cute::conditional_t; + using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; @@ -393,7 +395,8 @@ at::Tensor bf16bf16bf16_grouped_wgrad_sm100_impl( cutlass::gemm::GroupProblemShape>; using ElementA = cutlass::bfloat16_t; using ElementB = cutlass::bfloat16_t; - using ElementC = cutlass::bfloat16_t; + using ElementC = + cute::conditional_t; using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::RowMajor; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh index 15b2786a7d..0a1b28c8bf 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh @@ -117,70 +117,126 @@ at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t( at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f( +at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f( +at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f( +at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f( +at::Tensor bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f( +at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f( +at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f( +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f( +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f( +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, at::Tensor output, bool output_accum); -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f( +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum); + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum); + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum); + +at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum); + +at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum); + +at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool 
output_accum); + +at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum); + +at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f( + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum); + +at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f( at::Tensor X, // BF16 at::Tensor W, // BF16 at::Tensor M_sizes, @@ -225,12 +281,28 @@ get_bf16bf16bf16_grouped_wgrad_kernels(int arch) { bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t}, {"bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t", bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t}, - {"bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f", - bf16bf16bf16_grouped_wgrad_128_256_128_1_1_1_9_f}, - {"bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f", - bf16bf16bf16_grouped_wgrad_128_256_128_1_2_1_9_f}, - {"bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f", - bf16bf16bf16_grouped_wgrad_256_128_128_1_2_1_9_f}, + {"bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f", + bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f}, + {"bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t", + bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t}, + {"bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f", + bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f}, + {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f", + bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f}, + {"bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t", + bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t}, + {"bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t", + bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t}, + {"bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f", + bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f}, + {"bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f", + bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f}, + {"bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f", + bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f}, + {"bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f", + bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f}, + {"bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f", + bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f}, }; static const std:: unordered_map diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index fb92b2f3ae..53f6cd12da 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -2257,7 +2257,7 @@ def generate_random_splits(G: int, M: int) -> torch.Tensor: if output_accum: wgrad_accum = torch.randn( (G, N, K), - dtype=torch.bfloat16, + dtype=torch.float32, device=torch.accelerator.current_accelerator(), ) else: @@ -2271,6 +2271,9 @@ def generate_random_splits(G: int, M: int) -> torch.Tensor: output_accum=output_accum, ) + if output_accum: + assert test_wgrad.dtype == torch.float32 + # Reference dy_fp32 = dy_bf16.to(torch.float32) x_fp32 = x_bf16.to(torch.float32) @@ -2295,18 +2298,26 @@ def generate_random_splits(G: int, M: int) -> torch.Tensor: if output_accum: assert wgrad_accum is not None - ref_wgrad += wgrad_accum.to(torch.float32) + ref_wgrad += wgrad_accum - ref_wgrad = ref_wgrad.to(torch.bfloat16) + ref_wgrad = ref_wgrad.to(test_wgrad.dtype) # Compare groups with non-zero m_size if non_zero_groups: - torch.testing.assert_close( - test_wgrad[non_zero_groups], - ref_wgrad[non_zero_groups], - atol=1e-4, - rtol=1e-2, - ) + if 
test_wgrad.dtype == torch.float32:
+            torch.testing.assert_close(
+                test_wgrad[non_zero_groups],
+                ref_wgrad[non_zero_groups],
+                atol=1e-4,
+                rtol=1e-4,
+            )
+        else:
+            torch.testing.assert_close(
+                test_wgrad[non_zero_groups],
+                ref_wgrad[non_zero_groups],
+                atol=1e-4,
+                rtol=1e-2,
+            )
 
 if __name__ == "__main__":
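The test above pins down the contract of the new output_accum path: the op computes one weight gradient per group, dW_g = dy_g^T @ x_g, accumulated in FP32; with output_accum=True it adds a caller-provided Float32 buffer and keeps the result in Float32 (hence the tighter rtol of 1e-4), while output_accum=False casts the result back to BFloat16. A plain-PyTorch sketch of these reference semantics follows; the shapes and argument names are assumptions inferred from the test, not the op's exact signature:

import torch

def grouped_wgrad_reference(dy, x, m_sizes, wgrad_accum=None):
    # dy: (total_M, N) BF16, x: (total_M, K) BF16, m_sizes: rows per group.
    G = m_sizes.numel()
    N, K = dy.shape[1], x.shape[1]
    out = torch.zeros(G, N, K, dtype=torch.float32, device=dy.device)
    start = 0
    for g in range(G):
        m = int(m_sizes[g])
        # Per-group weight gradient, accumulated in FP32: dW_g = dy_g^T @ x_g.
        out[g] = dy[start:start + m].T.float() @ x[start:start + m].float()
        start += m
    if wgrad_accum is not None:
        # output_accum=True: add the running FP32 buffer and stay in Float32,
        # matching the new TORCH_CHECK that requires a Float32 output.
        return out + wgrad_accum
    # output_accum=False: the kernel returns BFloat16.
    return out.to(torch.bfloat16)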
From 8aaf3a7eaae595b56b508bc6334a916e372fd133 Mon Sep 17 00:00:00 2001
From: Jiawen Liu
Date: Mon, 22 Sep 2025 18:32:08 -0700
Subject: [PATCH 4/4] Resolve wgrad grouped gemm relocation issue in fbgemm

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1939

This Diff resolves the wgrad CUTLASS grouped GEMM relocation issue in fbgemm
in the short term. For the long term, we might need to refactor the fbgemm
GenAI quantize ops by splitting targets (or another, more elegant solution)
to reduce fbgemm binary sizes for internal Buck builds for xfn usage,
avoiding issues such as T238469849.

Differential Revision: D83020887
---
 .../bf16bf16bf16_grouped_wgrad.cu             | 4273 +++++++++++++----
 ...f16_grouped_wgrad_128_128_128_1_1_1_9_f.cu |   35 -
 ...f16_grouped_wgrad_128_128_128_1_1_1_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_1_2_1_9_f.cu |   35 -
 ...f16_grouped_wgrad_128_128_128_1_2_1_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_1_4_1_9_f.cu |   35 -
 ...f16_grouped_wgrad_128_128_128_1_4_1_9_t.cu |   28 -
 ...16_grouped_wgrad_128_128_128_2_1_1_10_f.cu |   42 -
 ...f16_grouped_wgrad_128_128_128_2_1_1_9_f.cu |   35 -
 ...f16_grouped_wgrad_128_128_128_2_1_1_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_2_2_1_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_2_4_1_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_4_1_1_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_4_2_1_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_4_2_2_9_f.cu |   35 -
 ...f16_grouped_wgrad_128_128_128_4_2_2_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_4_4_1_9_f.cu |   35 -
 ...f16_grouped_wgrad_128_128_128_4_4_1_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_4_4_2_9_f.cu |   35 -
 ...f16_grouped_wgrad_128_128_128_4_4_2_9_t.cu |   28 -
 ...f16_grouped_wgrad_128_128_128_4_4_4_9_f.cu |   35 -
 ...f16_grouped_wgrad_128_128_128_4_4_4_9_t.cu |   28 -
 ...bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu |   28 -
 ...bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu |   28 -
 ...bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu |   28 -
 ...f16_grouped_wgrad_128_32_128_2_1_1_10_f.cu |   42 -
 ...bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu |   28 -
 ...bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu |   28 -
 ...bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu |   28 -
 ...bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu |   28 -
 ...bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu |   28 -
 ...f16_grouped_wgrad_128_64_128_2_1_1_10_f.cu |   42 -
 ...bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu |   28 -
 ...bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu |   28 -
 ...16_grouped_wgrad_256_128_128_2_1_1_10_f.cu |   42 -
 ...16_grouped_wgrad_256_256_128_2_1_1_10_f.cu |   42 -
 ...bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu |   28 -
 ...f16_grouped_wgrad_256_32_128_2_1_1_10_f.cu |   42 -
 ...bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu |   28 -
 ...bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu |   28 -
 ...bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu |   28 -
 ...f16_grouped_wgrad_256_64_128_2_1_1_10_f.cu |   42 -
 .../bf16bf16bf16_grouped_wgrad_manifest.cuh   |  620 +--
 43 files changed, 3784 insertions(+), 2411 deletions(-)
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu
 delete mode 100644
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu
 delete mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f.cu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu
index cc314fdc61..9850da3cfb 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad.cu
@@ -9,6 +9,7 @@
 #include
 #include
+#include "bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_common.cuh"
 #include "bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh"
 #include "fbgemm_gpu/quantize/tuning_cache.hpp"
 #include "fbgemm_gpu/quantize/utils.h"
@@ -25,106 +26,1698 @@ TuningCache& getTuningCache() {
 }
 } // namespace
 
-Kernel_bf16bf16bf16_grouped_wgrad
-get_wgrad_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) {
+// TODO: Re-enable the following when the fbgemm relocation issue is resolved
+// Kernel_bf16bf16bf16_grouped_wgrad
+// get_wgrad_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) {
+//   // Use heuristics to pick best kernel implementation.
+//   if (arch == 10) {
+//     // Llama4 shapes
+//     if ((N == 5120 && K == 1024) || (N == 2048 && K == 5120)) {
+//       if (total_M <= 256) {
+//         return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f;
+//       } else if (total_M <= 512) {
+//         return bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f;
+//       } else if (total_M <= 1024) {
+//         return bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f;
+//       } else {
+//         return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f;
+//       }
+//     }
+
+//     // Fallback to legacy heuristic.
+//     if (total_M <= 64 || (total_M <= 256 and N <= 1024)) {
+//       if (K <= 4096) {
+//         return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f;
+//       } else {
+//         return bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f;
+//       }
+//     } else if (total_M <= 512) {
+//       if (N <= 1024) {
+//         return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f;
+//       } else if (N <= 8192) {
+//         if (K <= 2048) {
+//           return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f;
+//         } else if (K <= 4096) {
+//           return bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f;
+//         } else {
+//           return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f;
+//         }
+//       }
+//     } else if (total_M <= 1024) {
+//       if (N <= 1024) {
+//         return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f;
+//       } else if (N <= 8192) {
+//         if (K <= 2048) {
+//           return bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f;
+//         } else if (K <= 4096) {
+//           return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f;
+//         } else {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f;
+//         }
+//       }
+//     } else if (total_M <= 2048) {
+//       if (N <= 1024) {
+//         return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f;
+//       } else if (N <= 8192) {
+//         if (K <= 2048) {
+//           return bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f;
+//         } else if (K <= 4096) {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f;
+//         }
+//       }
+//     }
+//     return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f;
+//   } else { // arch == 9
+//     // Llama4.x pretraining
+//     if (total_M == 8192) {
+//       if (N == 2560) {
+//         if (K == 1280) {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t;
+//         } else if (K == 5120) {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t;
+//         }
+//       } else if (N == 3072) {
+//         if (K == 1536) {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t;
+//         } else if (K == 6144) {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t;
+//         }
+//       } else if (N == 5120) {
+//         return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t;
+//       } else if (N == 6144) {
+//         if (K == 1536) {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t;
+//         } else if (K == 6144) {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t;
+//         }
+//       }
+//     }
+
+//     if (total_M == 16384) {
+//       if (N == 2560 || N == 3072) {
+//         return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t;
+//       } else if (N == 5120) {
+//         if (K == 1280) {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t;
+//         } else if (K == 5120) {
+//           return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t;
+//         }
+//       } else if (N == 6144) {
+//         if
(K == 1536) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else if (K == 6144) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; +// } +// } +// } + +// if (total_M == 65536) { +// if (N <= 512) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; +// } +// } else if (N <= 768) { +// if (K <= 384) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 768) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1280) { +// if (K <= 640) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 1280) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; +// } +// } else if (N <= 1536) { +// if (K <= 768) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1536) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1792) { +// if (K <= 1792) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2560) { +// if (K <= 1280) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; +// } else if (K <= 2560) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 3072) { +// if (K <= 1536) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 3072) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; +// } +// } +// } + +// // Fallback to legacy heuristic +// if (total_M <= 128) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 256) { +// return 
bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 2048) { +// if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } +// } else if (total_M <= 256) { +// if (N <= 128) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 1024) { +// if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 2048) { +// if (K <= 
256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 4096) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } +// } else if (total_M <= 512) { +// if (N <= 128) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// 
return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } +// } else if (total_M <= 1024) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } +// } else if (N <= 1024) { +// if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) 
{ +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 2048) { +// if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } +// } +// } else if (total_M <= 2048) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// 
return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } +// } else if (total_M <= 4096) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// 
return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } +// } else if (total_M <= 8192) { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; +// } else if (K <= 256) { 
+// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } +// } +// } else { +// if (N <= 128) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 4096) { +// return 
bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 256) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 1024) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 512) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 512) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 1024) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } +// } else if (N <= 2048) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } +// } else if (N <= 4096) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } +// } else if (N <= 8192) { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } +// } else { +// if (K <= 128) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; +// } else if (K <= 256) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 2048) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } else if (K <= 4096) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; +// } else if (K <= 8192) { +// return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; +// } else { +// return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; +// } +// } +// } +// } +// } + +Kernel_bf16bf16bf16_grouped_wgrad get_kernel_via_tuning( + int arch, + int G, + int total_M, + int N, + int K, + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + auto& cache = getTuningCache(); + + // Reducing amount of auto tuning by rounding up total_m to next power of 2. + total_M = nextPowerOf2(total_M); + // Use (total_M, N, K, G) shape as the key. 
+ const std::string shape_key = std::to_string(total_M) + "_" + + std::to_string(N) + "_" + std::to_string(K) + "_" + std::to_string(G); + const auto& kernels = get_bf16bf16bf16_grouped_wgrad_kernels(arch); + auto kernel = cache.findBestKernelMaybeAutotune( + shape_key, kernels, X, W, M_sizes, output, output_accum); + + return kernel; +} + +// BF16 grouped cutlass kernel dispatch. +at::Tensor dispatch_bf16_grouped_kernel( + int G, + int total_M, + int N, + int K, + at::Tensor X, // BF16 + at::Tensor W, // BF16 + at::Tensor M_sizes, + at::Tensor output, + bool output_accum) { + const int arch = getDeviceArch(); + // Use heuristics to pick best kernel implementation. if (arch == 10) { // Llama4 shapes if ((N == 5120 && K == 1024) || (N == 2048 && K == 5120)) { if (total_M <= 256) { - return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (total_M <= 512) { - return bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (total_M <= 1024) { - return bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else { - return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } } // Fallback to legacy heuristic. 
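get_kernel_via_tuning above bounds the autotuning search space by rounding total_M up to the next power of two before composing the cache key, so every runtime shape in one power-of-two bucket reuses a single benchmarked kernel choice instead of triggering its own sweep. A minimal sketch of that bucketing, with an illustrative nextPowerOf2 (the patch relies on a helper of that name defined elsewhere in fbgemm) and a hypothetical make_shape_key wrapper that is not part of this change:

    #include <cstdint>
    #include <string>

    // Illustrative stand-in for fbgemm's helper: smallest power of two >= n,
    // for n >= 1. The real definition lives elsewhere in the tree.
    inline int64_t nextPowerOf2(int64_t n) {
      int64_t p = 1;
      while (p < n) {
        p <<= 1;
      }
      return p;
    }

    // Hypothetical wrapper mirroring the key construction above. Every
    // total_M in (2048, 4096], for example, maps to "4096_<N>_<K>_<G>" and
    // shares one cached kernel selection.
    std::string make_shape_key(int64_t total_M, int64_t N, int64_t K, int64_t G) {
      total_M = nextPowerOf2(total_M);
      return std::to_string(total_M) + "_" + std::to_string(N) + "_" +
          std::to_string(K) + "_" + std::to_string(G);
    }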
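Within dispatch_bf16_grouped_kernel, output_accum only selects between two otherwise-identical instantiations, which is why each heuristic leaf in this hunk expands into an if (output_accum) pair. One way the duplication could be folded into a single helper, sketched under the assumption of only the bf16bf16bf16_grouped_wgrad_sm100_impl signature visible in this hunk (tile M/N/K, cluster M/N/K, an accumulate flag, and the trailing flag that the _f/_t kernel-name suffix appears to encode); the wrapper name is hypothetical:

    #include <ATen/ATen.h>

    // Assumed to be declared by bf16bf16bf16_grouped_wgrad_common.cuh in this
    // patch; repeated here only so the sketch is self-contained.
    template <int TB_M, int TB_N, int TB_K, int CL_M, int CL_N, int CL_K,
              bool ACCUM, bool FLAG>
    at::Tensor bf16bf16bf16_grouped_wgrad_sm100_impl(
        at::Tensor X, at::Tensor W, at::Tensor M_sizes, at::Tensor output);

    // Hypothetical helper: lift the runtime output_accum flag into the
    // compile-time ACCUM slot exactly once instead of at every leaf.
    template <int TB_M, int TB_N, int TB_K, int CL_M, int CL_N, int CL_K, bool FLAG>
    at::Tensor sm100_wgrad_with_accum(
        at::Tensor X, at::Tensor W, at::Tensor M_sizes, at::Tensor output,
        bool output_accum) {
      if (output_accum) {
        return bf16bf16bf16_grouped_wgrad_sm100_impl<
            TB_M, TB_N, TB_K, CL_M, CL_N, CL_K, true, FLAG>(X, W, M_sizes, output);
      }
      return bf16bf16bf16_grouped_wgrad_sm100_impl<
          TB_M, TB_N, TB_K, CL_M, CL_N, CL_K, false, FLAG>(X, W, M_sizes, output);
    }

    // A leaf such as the total_M <= 256 case above would then collapse to:
    //   return sm100_wgrad_with_accum<256, 32, 128, 2, 1, 1, false>(
    //       X, W, M_sizes, output, output_accum);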
- if (total_M <= 64 || (total_M <= 256 and N <= 1024)) { + if (total_M <= 64 || (total_M <= 256 && N <= 1024)) { if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else { - return bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (total_M <= 512) { if (N <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (N <= 8192) { if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 32, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } } } else if (total_M <= 1024) { if (N <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (N <= 8192) { if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 64, + 128, + 2, + 1, + 1, + false, + false>(X, W, 
M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } } } else if (total_M <= 2048) { if (N <= 1024) { - return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (N <= 8192) { if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } } } - return bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_sm100_impl< + 256, + 256, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else { // arch == 9 // Llama4.x pretraining if (total_M == 8192) { if (N == 2560) { if (K == 1280) { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 4, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 4, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K == 5120) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N == 3072) { if (K == 1536) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K == 6144) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N == 5120) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, 
+ true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (N == 6144) { if (K == 1536) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K == 6144) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } } if (total_M == 16384) { if (N == 2560 || N == 3072) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (N == 5120) { if (K == 1280) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K == 5120) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N == 6144) { if (K == 1536) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K == 6144) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } } } } @@ -132,51 +1725,351 @@ get_wgrad_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) { if (total_M == 65536) { if (N <= 512) { if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, 
+ 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N <= 768) { if (K <= 384) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 768) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N <= 1024) { if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N <= 1280) { if (K <= 640) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K <= 1280) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N <= 1536) { if (K <= 768) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K <= 1536) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N <= 1792) { if (K <= 1792) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 4, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N <= 2048) { if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N <= 2560) { if (K <= 1280) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 
128, + 2, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K <= 2560) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N <= 3072) { if (K <= 1536) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K <= 3072) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } } @@ -185,850 +2078,1626 @@ get_wgrad_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) { if (total_M <= 128) { if (N <= 128) { if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 4, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 4, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + true>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, 
output); + } } } else if (N <= 256) { if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (N <= 512) { if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, 
output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (N <= 1024) { if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 4, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 4, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (N <= 2048) { if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 
128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (N <= 4096) { if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else if (N <= 8192) { if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else { if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, 
+ 2, + 1, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + true>(X, W, M_sizes, output); + } } } } else if (total_M <= 256) { if (N <= 128) { if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (N <= 256) { if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 32, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if 
(output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (N <= 512) { if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 64, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (N <= 1024) { if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 8192) { - 
return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (N <= 2048) { if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } } else if (N <= 4096) { if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, 
+ true>(X, W, M_sizes, output); + } } } else if (N <= 8192) { if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 256, + 64, + 128, + 1, + 2, + 1, + false, + false>(X, W, M_sizes, output); + } } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 1, + 2, + 1, + false, + true>(X, W, M_sizes, output); + } } } else { if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } - } - } else if (total_M <= 512) { - if (N <= 128) { - if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } - } else if (N <= 256) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } - } else if (N <= 512) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } - } else if (N <= 1024) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } - } else if (N <= 2048) { - if (K <= 128) { - return 
bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } - } else if (N <= 4096) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } - } else if (N <= 8192) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else { - if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } - } - } else if (total_M <= 1024) { - if (N <= 128) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } - } else if (N <= 256) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } - } else if (N <= 512) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } - } else if (N <= 1024) { - if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 2048) { - return 
bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } - } else if (N <= 2048) { - if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } - } else if (N <= 4096) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } - } else if (N <= 8192) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } - } - } else if (total_M <= 2048) { - if (N <= 128) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f; - } - } else if (N <= 256) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 512) { - if (K <= 128) { - return 
bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 1024) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 2048) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 4096) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } - } else if (N <= 8192) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } - } else { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } - } else if (total_M <= 4096) { - if (N <= 128) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 256) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 512) { - return 
bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } - } else if (N <= 512) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 1024) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 2048) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 4096) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } - } else if (N <= 8192) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } - } else if (total_M <= 8192) { - if (N <= 128) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 256) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 4096) { - return 
bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 512) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 1024) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 2048) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 4096) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 8192) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; - } - } else { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } - } - } else { - if (N <= 128) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 256) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 1024) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 2048) { - return 
bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 512) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 512) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 1024) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } - } else if (N <= 2048) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; - } - } else if (N <= 4096) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; - } - } else if (N <= 8192) { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; - } - } else { - if (K <= 128) { - return bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f; - } else if (K <= 256) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 2048) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; - } else if (K <= 4096) { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t; - } else if (K <= 8192) { - return bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t; - } else { - return bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t; + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + false>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + false>(X, W, M_sizes, output); + } + } else { + if (output_accum) { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + true, + true>(X, W, M_sizes, output); + } else { + return bf16bf16bf16_grouped_wgrad_impl< + 128, + 128, + 128, + 2, + 1, + 1, + false, + true>(X, W, M_sizes, output); + } } } } - } -} -Kernel_bf16bf16bf16_grouped_wgrad get_kernel_via_tuning( - int arch, - int G, - int total_M, - int N, - int K, - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - auto& cache = getTuningCache(); - - // Reducing amount of auto tuning by rounding up total_m to next power of 2. - total_M = nextPowerOf2(total_M); - // Use (total_M, N, K, G) shape as the key. 
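Note on the tuning key above: the removed `get_kernel_via_tuning` rounds `total_M` up to the next power of two before building the `(total_M, N, K, G)` cache key, so nearby token counts collapse into a single autotuned entry. `nextPowerOf2` is only referenced, not defined, in this hunk; a minimal sketch under the conventional reading:

inline int nextPowerOf2(int v) {
  // Smallest power of two >= v (assumed semantics; not shown in this patch).
  // e.g. total_M = 1500 is keyed as 2048, so every shape in (1024, 2048]
  // shares one cache entry and one autotune run.
  int p = 1;
  while (p < v) {
    p <<= 1;
  }
  return p;
}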
-  const std::string shape_key = std::to_string(total_M) + "_" +
-      std::to_string(N) + "_" + std::to_string(K) + "_" + std::to_string(G);
-  const auto& kernels = get_bf16bf16bf16_grouped_wgrad_kernels(arch);
-  auto kernel = cache.findBestKernelMaybeAutotune(
-      shape_key, kernels, X, W, M_sizes, output, output_accum);
-
-  return kernel;
-}
-
-// BF16 grouped cutlass kernel dispatch.
-at::Tensor dispatch_bf16_grouped_kernel(
-    int G,
-    int total_M,
-    int N,
-    int K,
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  const int arch = getDeviceArch();
-
-  // Select kernel to run via heuristics or tuning.
-  auto kernel = [&]() {
-    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
-      return get_kernel_via_tuning(
-          arch, G, total_M, N, K, X, W, M_sizes, output, output_accum);
+  // Default fallback
+  if (output_accum) {
+    return bf16bf16bf16_grouped_wgrad_impl<
+        128,
+        128,
+        128,
+        1,
+        1,
+        1,
+        true,
+        false>(X, W, M_sizes, output);
   } else {
-      return get_wgrad_kernel_via_heuristic(arch, G, total_M, N, K);
+    return bf16bf16bf16_grouped_wgrad_impl<
+        128,
+        128,
+        128,
+        1,
+        1,
+        1,
+        false,
+        false>(X, W, M_sizes, output);
   }
-  }();
-  // Invoke kernel
-  return kernel(X, W, M_sizes, output, output_accum);
+  }
 }
 
 at::Tensor bf16bf16bf16_grouped_wgrad(
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f.cu
deleted file mode 100644
index b841bf2f93..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f.cu
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 1, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<
-        128,
-        128,
-        128,
-        1,
-        1,
-        1,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t.cu
deleted file mode 100644
index 2f267b651a..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
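With the precompiled per-config wrappers deleted below, the heuristic now instantiates `bf16bf16bf16_grouped_wgrad_impl` directly, and every branch repeats the same `output_accum` if/else around an otherwise identical template argument list. A possible cleanup, not part of this patch (`dispatch_accum` is a hypothetical name), would hoist that branch into one helper:

template <int TB_M, int TB_N, int TB_K, int CL_M, int CL_N, int CL_K, bool COOP>
at::Tensor dispatch_accum(
    at::Tensor X,
    at::Tensor W,
    at::Tensor M_sizes,
    at::Tensor output,
    bool output_accum) {
  // Only the 7th template argument of bf16bf16bf16_grouped_wgrad_impl
  // depends on output_accum; everything else is forwarded unchanged.
  if (output_accum) {
    return bf16bf16bf16_grouped_wgrad_impl<
        TB_M, TB_N, TB_K, CL_M, CL_N, CL_K, true, COOP>(X, W, M_sizes, output);
  }
  return bf16bf16bf16_grouped_wgrad_impl<
      TB_M, TB_N, TB_K, CL_M, CL_N, CL_K, false, COOP>(X, W, M_sizes, output);
}

Each heuristic branch would then shrink to one line, e.g. `return dispatch_accum<128, 128, 128, 2, 1, 1, false>(X, W, M_sizes, output, output_accum);`.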
- */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 1, 1, true, true>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 1, 1, false, true>( - X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f.cu deleted file mode 100644 index 1d2103be34..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f.cu +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 2, 1, true, false>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl< - 128, - 128, - 128, - 1, - 2, - 1, - false, - false>(X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t.cu deleted file mode 100644 index a98ebb5bea..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 2, 1, true, true>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 2, 1, false, true>( - X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f.cu deleted file mode 100644 index 871f991d9b..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f.cu +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 4, 1, true, false>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl< - 128, - 128, - 128, - 1, - 4, - 1, - false, - false>(X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t.cu deleted file mode 100644 index 82cd72f5b4..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 4, 1, true, true>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 1, 4, 1, false, true>( - X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f.cu deleted file mode 100644 index 2548259526..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_sm100_impl< - 128, - 128, - 128, - 2, - 1, - 1, - true, - false>(X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_sm100_impl< - 128, - 128, - 128, - 2, - 1, - 1, - false, - false>(X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f.cu deleted file mode 100644 index 313f53caa1..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f.cu +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
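The deleted wrappers also make the kernel naming scheme legible. The `_10_f` variant above is the only one in this section that routes to `bf16bf16bf16_grouped_wgrad_sm100_impl` (SM100 / GB200), while the `_9_*` variants all route to the SM90 (H100) `bf16bf16bf16_grouped_wgrad_impl`. Decoding the fields, an inference from these deleted sources rather than anything the patch states:

// bf16bf16bf16_grouped_wgrad_<TBM>_<TBN>_<TBK>_<CM>_<CN>_<CK>_<SM>_<t|f>
//   TBM/TBN/TBK : CTA tile shape, the first three template arguments
//   CM/CN/CK    : cluster shape, the next three
//   SM          : 9 selects the SM90 impl, 10 the sm100 impl
//   t | f       : the trailing bool template argument; what it toggles
//                 (e.g. a cooperative vs. pingpong schedule) is not
//                 visible in this hunk
// so _128_128_128_2_1_1_10_f wrapped:
//   bf16bf16bf16_grouped_wgrad_sm100_impl<128, 128, 128, 2, 1, 1, accum, false>
// where accum is the runtime output_accum flag.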
- */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 1, 1, true, false>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl< - 128, - 128, - 128, - 2, - 1, - 1, - false, - false>(X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t.cu deleted file mode 100644 index b567c8598d..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 1, 1, true, true>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 1, 1, false, true>( - X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t.cu deleted file mode 100644 index 97e38d4806..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 2, 1, true, true>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 2, 1, false, true>( - X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t.cu deleted file mode 100644 index 18119354dd..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 4, 1, true, true>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 2, 4, 1, false, true>( - X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t.cu deleted file mode 100644 index a5e68f7289..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 1, 1, true, true>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 1, 1, false, true>( - X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t.cu deleted file mode 100644 index a813cda85c..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include "bf16bf16bf16_grouped_wgrad_common.cuh" - -namespace fbgemm_gpu { - -at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t( - at::Tensor X, // BF16 - at::Tensor W, // BF16 - at::Tensor M_sizes, - at::Tensor output, - bool output_accum) { - if (output_accum) { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 1, true, true>( - X, W, M_sizes, output); - } else { - return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 1, false, true>( - X, W, M_sizes, output); - } -} - -} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f.cu deleted file mode 100644 index 6360182d90..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f.cu +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 2, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<
-        128,
-        128,
-        128,
-        4,
-        2,
-        2,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t.cu
deleted file mode 100644
index 0dcc2876ec..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_2_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 2, true, true>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 2, 2, false, true>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f.cu
deleted file mode 100644
index e88019f79f..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f.cu
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<
-        128,
-        128,
-        128,
-        4,
-        4,
-        1,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t.cu
deleted file mode 100644
index 881b71a103..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 1, true, true>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 1, false, true>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f.cu
deleted file mode 100644
index 322da70276..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f.cu
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 2, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<
-        128,
-        128,
-        128,
-        4,
-        4,
-        2,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t.cu
deleted file mode 100644
index 24a645af43..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_2_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 2, true, true>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 2, false, true>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f.cu
deleted file mode 100644
index 0334a4c0c2..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f.cu
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 4, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<
-        128,
-        128,
-        128,
-        4,
-        4,
-        4,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t.cu
deleted file mode 100644
index 39e24fa1de..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_4_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 4, true, true>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 128, 128, 4, 4, 4, false, true>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu
deleted file mode 100644
index b8df243c35..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 1, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 1, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu
deleted file mode 100644
index 7ec4ac7170..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 2, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 2, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu
deleted file mode 100644
index 40c04995b3..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 4, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 1, 4, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f.cu
deleted file mode 100644
index 67f97db4c5..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f.cu
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        128,
-        32,
-        128,
-        2,
-        1,
-        1,
-        true,
-        false>(X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        128,
-        32,
-        128,
-        2,
-        1,
-        1,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu
deleted file mode 100644
index 8c21b4f611..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 2, 2, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 2, 2, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu
deleted file mode 100644
index d3a9fb8e22..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 2, 4, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 32, 128, 2, 4, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu
deleted file mode 100644
index ad54498384..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 1, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 1, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu
deleted file mode 100644
index dc93f62074..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 2, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 2, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu
deleted file mode 100644
index bfaf4f6f96..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 4, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 1, 4, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f.cu
deleted file mode 100644
index af414928a5..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f.cu
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        128,
-        64,
-        128,
-        2,
-        1,
-        1,
-        true,
-        false>(X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        128,
-        64,
-        128,
-        2,
-        1,
-        1,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu
deleted file mode 100644
index 266c300f4d..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 2, 2, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 2, 2, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu
deleted file mode 100644
index acf12a2ff4..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 4, 2, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<128, 64, 128, 4, 2, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f.cu
deleted file mode 100644
index 5cb7c5401d..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f.cu
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        256,
-        128,
-        128,
-        2,
-        1,
-        1,
-        true,
-        false>(X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        256,
-        128,
-        128,
-        2,
-        1,
-        1,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f.cu
deleted file mode 100644
index 9fff6b24e8..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f.cu
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        256,
-        256,
-        128,
-        2,
-        1,
-        1,
-        true,
-        false>(X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        256,
-        256,
-        128,
-        2,
-        1,
-        1,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu
deleted file mode 100644
index aacf411946..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_1_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<256, 32, 128, 1, 1, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<256, 32, 128, 1, 1, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f.cu
deleted file mode 100644
index 453ddeaae9..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f.cu
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        256,
-        32,
-        128,
-        2,
-        1,
-        1,
-        true,
-        false>(X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        256,
-        32,
-        128,
-        2,
-        1,
-        1,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu
deleted file mode 100644
index 2087b7a542..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 1, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 1, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu
deleted file mode 100644
index 95383b1797..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 2, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 2, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu
deleted file mode 100644
index 66d1e80e85..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 4, 1, true, false>(
-        X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_impl<256, 64, 128, 1, 4, 1, false, false>(
-        X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f.cu
deleted file mode 100644
index 743f359f95..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f.cu
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#include "bf16bf16bf16_grouped_wgrad_common.cuh"
-
-namespace fbgemm_gpu {
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum) {
-  if (output_accum) {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        256,
-        64,
-        128,
-        2,
-        1,
-        1,
-        true,
-        false>(X, W, M_sizes, output);
-  } else {
-    return bf16bf16bf16_grouped_wgrad_sm100_impl<
-        256,
-        64,
-        128,
-        2,
-        1,
-        1,
-        false,
-        false>(X, W, M_sizes, output);
-  }
-}
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh
index 0a1b28c8bf..64375ea2e8 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/bf16bf16bf16_grouped_wgrad/bf16bf16bf16_grouped_wgrad_manifest.cuh
@@ -12,321 +12,327 @@
 
 namespace fbgemm_gpu {
 
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
-
-at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f(
-    at::Tensor X, // BF16
-    at::Tensor W, // BF16
-    at::Tensor M_sizes,
-    at::Tensor output,
-    bool output_accum);
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
+
+// at::Tensor bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f(
+//     at::Tensor X, // BF16
+//     at::Tensor W, // BF16
+//     at::Tensor M_sizes,
+//     at::Tensor output,
+//     bool output_accum);
 
 using Kernel_bf16bf16bf16_grouped_wgrad =
     at::Tensor (*)(at::Tensor, at::Tensor, at::Tensor, at::Tensor, bool);
 
 const std::unordered_map<std::string, Kernel_bf16bf16bf16_grouped_wgrad>&
 get_bf16bf16bf16_grouped_wgrad_kernels(int arch) {
+  // Return a static empty map
   static const std::
       unordered_map<std::string, Kernel_bf16bf16bf16_grouped_wgrad>
-          kernelsSM90 = {
-              {"bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t",
-               bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t",
-               bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t",
-               bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t",
-               bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t",
-               bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t",
-               bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t",
-               bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t",
-               bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t",
-               bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t},
-              {"bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f",
-               bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f",
-               bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f",
-               bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f},
-              {"bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f",
-               bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f},
-          };
-  static const std::
-      unordered_map<std::string, Kernel_bf16bf16bf16_grouped_wgrad>
-          kernelsSM100 = {
-              {"bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f",
-               bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f},
-              {"bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f",
-               bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f},
-              {"bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f",
-               bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f},
-              {"bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f",
-               bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f},
-              {"bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f",
-               bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f},
-              {"bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f",
-               bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f},
-              {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f",
-               bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f},
-          };
-  if (arch == 10) {
-    return kernelsSM100;
-  } else {
-    return kernelsSM90;
-  }
+          empty_map;
+  return empty_map;
+
+  // static const std::
+  //     unordered_map<std::string, Kernel_bf16bf16bf16_grouped_wgrad>
+  //         kernelsSM90 = {
+  //             {"bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_32_128_1_1_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_32_128_1_4_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_32_128_2_2_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_32_128_2_4_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_64_128_1_2_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_64_128_1_4_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_64_128_2_2_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_64_128_4_2_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_1_2_1_9_t},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_t},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_t},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_2_2_1_9_t},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_2_4_1_9_t},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_4_2_1_9_t},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_1_1_1_9_t},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_1_4_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_4_1_1_9_t},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_4_4_1_9_t},
+  //             {"bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_32_128_1_2_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_128_64_128_1_1_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_256_64_128_1_1_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_256_64_128_1_2_1_9_f},
+  //             {"bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f",
+  //              bf16bf16bf16_grouped_wgrad_256_64_128_1_4_1_9_f},
+  //         };
+  // static const std::
+  //     unordered_map<std::string, Kernel_bf16bf16bf16_grouped_wgrad>
+  //         kernelsSM100 = {
+  //             {"bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f",
+  //              bf16bf16bf16_grouped_wgrad_256_32_128_2_1_1_10_f},
+  //             {"bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f",
+  //              bf16bf16bf16_grouped_wgrad_256_64_128_2_1_1_10_f},
+  //             {"bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f",
+  //              bf16bf16bf16_grouped_wgrad_256_128_128_2_1_1_10_f},
+  //             {"bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f",
+  //              bf16bf16bf16_grouped_wgrad_256_256_128_2_1_1_10_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f",
+  //              bf16bf16bf16_grouped_wgrad_128_32_128_2_1_1_10_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f",
+  //              bf16bf16bf16_grouped_wgrad_128_64_128_2_1_1_10_f},
+  //             {"bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f",
+  //              bf16bf16bf16_grouped_wgrad_128_128_128_2_1_1_10_f},
+  //         };
+  // if (arch == 10) {
+  //   return kernelsSM100;
+  // } else {
+  //   return kernelsSM90;
+  // }
 }
 
 } // namespace fbgemm_gpu