|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <ATen/ATen.h> |
| 10 | +#include <ATen/cuda/CUDAContext.h> |
| 11 | + |
| 12 | +#include "bf16bf16bf16_grouped_grad/bf16bf16bf16_grouped_grad_manifest.cuh" |
| 13 | +#include "fbgemm_gpu/quantize/tuning_cache.hpp" |
| 14 | +#include "fbgemm_gpu/quantize/utils.h" |
| 15 | + |
| 16 | +namespace fbgemm_gpu { |
| 17 | + |
| 18 | +#if CUDART_VERSION >= 12000 |
| 19 | + |
| 20 | +namespace { |
| 21 | +TuningCache& getTuningCache() { |
| 22 | + static TuningCache cache("bf16bf16bf16_grouped_grad"); |
| 23 | + return cache; |
| 24 | +} |
| 25 | +} // namespace |
| 26 | + |
| 27 | +Kernel_bf16bf16bf16_grouped_grad |
| 28 | +get_kernel_via_heuristic(int arch, int G, int total_M, int N, int K) { |
| 29 | + // Use heuristics to pick best kernel implementation. |
| 30 | + if (arch == 10) { |
| 31 | + // Llama4 shapes |
| 32 | + if ((N == 5120 && K == 1024) || (N == 2048 && K == 5120)) { |
| 33 | + if (total_M <= 256) { |
| 34 | + return bf16bf16bf16_grouped_grad_256_32_128_2_1_1_10_f; |
| 35 | + } else if (total_M <= 512) { |
| 36 | + return bf16bf16bf16_grouped_grad_256_64_128_2_1_1_10_f; |
| 37 | + } else if (total_M <= 1024) { |
| 38 | + return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_10_f; |
| 39 | + } else { |
| 40 | + return bf16bf16bf16_grouped_grad_256_256_128_2_1_1_10_f; |
| 41 | + } |
| 42 | + } |
| 43 | + |
| 44 | + // Fallback to legacy heuristic. |
| 45 | + if (total_M <= 64 || (total_M <= 256 and N <= 1024)) { |
| 46 | + if (K <= 4096) { |
| 47 | + return bf16bf16bf16_grouped_grad_256_32_128_2_1_1_10_f; |
| 48 | + } else { |
| 49 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_10_f; |
| 50 | + } |
| 51 | + } else if (total_M <= 512) { |
| 52 | + if (N <= 1024) { |
| 53 | + return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_10_f; |
| 54 | + } else if (N <= 8192) { |
| 55 | + if (K <= 2048) { |
| 56 | + return bf16bf16bf16_grouped_grad_256_32_128_2_1_1_10_f; |
| 57 | + } else if (K <= 4096) { |
| 58 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_10_f; |
| 59 | + } else { |
| 60 | + return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_10_f; |
| 61 | + } |
| 62 | + } |
| 63 | + } else if (total_M <= 1024) { |
| 64 | + if (N <= 1024) { |
| 65 | + return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_10_f; |
| 66 | + } else if (N <= 8192) { |
| 67 | + if (K <= 2048) { |
| 68 | + return bf16bf16bf16_grouped_grad_256_64_128_2_1_1_10_f; |
| 69 | + } else if (K <= 4096) { |
| 70 | + return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_10_f; |
| 71 | + } else { |
| 72 | + return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_10_f; |
| 73 | + } |
| 74 | + } |
| 75 | + } else if (total_M <= 2048) { |
| 76 | + if (N <= 1024) { |
| 77 | + return bf16bf16bf16_grouped_grad_256_256_128_2_1_1_10_f; |
| 78 | + } else if (N <= 8192) { |
| 79 | + if (K <= 2048) { |
| 80 | + return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_10_f; |
| 81 | + } else if (K <= 4096) { |
| 82 | + return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_10_f; |
| 83 | + } |
| 84 | + } |
| 85 | + } |
| 86 | + return bf16bf16bf16_grouped_grad_256_256_128_2_1_1_10_f; |
| 87 | + } else { |
| 88 | + // Llama4 128E |
| 89 | + if (G == 128) { |
| 90 | + if (N == 5120 && K == 1024) { |
| 91 | + if (total_M <= 128) { |
| 92 | + return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f; |
| 93 | + } else if (total_M <= 256) { |
| 94 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f; |
| 95 | + } else if (total_M <= 2048) { |
| 96 | + return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f; |
| 97 | + } else if (total_M <= 4096) { |
| 98 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f; |
| 99 | + } else if (total_M <= 8192) { |
| 100 | + return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f; |
| 101 | + } else if (total_M <= 16384) { |
| 102 | + return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t; |
| 103 | + } else { |
| 104 | + return bf16bf16bf16_grouped_grad_128_256_128_2_1_1_9_f; |
| 105 | + } |
| 106 | + } |
| 107 | + |
| 108 | + if (N == 2048 && K == 5120) { |
| 109 | + if (total_M <= 2048) { |
| 110 | + return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f; |
| 111 | + } else { |
| 112 | + return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t; |
| 113 | + } |
| 114 | + } |
| 115 | + } |
| 116 | + |
| 117 | + // Llama4 64E |
| 118 | + if (G == 16) { |
| 119 | + if (N == 5120 && K == 1024) { |
| 120 | + if (total_M <= 32) { |
| 121 | + return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f; |
| 122 | + } else if (total_M <= 64) { |
| 123 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f; |
| 124 | + } else if (total_M <= 256) { |
| 125 | + return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f; |
| 126 | + } else if (total_M <= 512) { |
| 127 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f; |
| 128 | + } else if (total_M <= 1024) { |
| 129 | + return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_9_f; |
| 130 | + } else { |
| 131 | + return bf16bf16bf16_grouped_grad_128_256_128_2_1_1_9_f; |
| 132 | + } |
| 133 | + } |
| 134 | + |
| 135 | + if (N == 2048 && K == 5120) { |
| 136 | + if (total_M <= 16) { |
| 137 | + return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f; |
| 138 | + } else if (total_M <= 64) { |
| 139 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f; |
| 140 | + } else if (total_M <= 256) { |
| 141 | + return bf16bf16bf16_grouped_grad_128_16_128_2_1_1_9_f; |
| 142 | + } else if (total_M <= 512) { |
| 143 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f; |
| 144 | + } else if (total_M <= 1024) { |
| 145 | + return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f; |
| 146 | + } else { |
| 147 | + return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t; |
| 148 | + } |
| 149 | + } |
| 150 | + } |
| 151 | + |
| 152 | + // Llama4.x pretraining |
| 153 | + if (N == 1280 && K == 5120) { |
| 154 | + if (total_M <= 256) { |
| 155 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f; |
| 156 | + } else if (total_M <= 1024) { |
| 157 | + return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f; |
| 158 | + } else if (total_M <= 4096) { |
| 159 | + return bf16bf16bf16_grouped_grad_128_128_128_2_2_1_9_t; |
| 160 | + } else { |
| 161 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t; |
| 162 | + } |
| 163 | + } else if (N == 2560 && K == 5120) { |
| 164 | + if (total_M <= 256) { |
| 165 | + return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f; |
| 166 | + } else if (total_M <= 1024) { |
| 167 | + return bf16bf16bf16_grouped_grad_128_64_128_2_2_1_9_f; |
| 168 | + } else { |
| 169 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t; |
| 170 | + } |
| 171 | + } else if (N == 1536 && K == 6144) { |
| 172 | + if (total_M <= 256) { |
| 173 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f; |
| 174 | + } else if (total_M <= 1024) { |
| 175 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f; |
| 176 | + } else if (total_M <= 4096) { |
| 177 | + return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t; |
| 178 | + } else { |
| 179 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t; |
| 180 | + } |
| 181 | + } else if (N == 3072 && K == 6144) { |
| 182 | + if (total_M <= 256) { |
| 183 | + return bf16bf16bf16_grouped_grad_128_64_128_2_1_1_9_f; |
| 184 | + } else if (total_M <= 4096) { |
| 185 | + return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_t; |
| 186 | + } else { |
| 187 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t; |
| 188 | + } |
| 189 | + } else if (N == 5120 && K == 2560) { |
| 190 | + if (total_M <= 256) { |
| 191 | + return bf16bf16bf16_grouped_grad_128_128_128_2_1_1_9_f; |
| 192 | + } else if (total_M <= 1024) { |
| 193 | + return bf16bf16bf16_grouped_grad_128_128_128_2_2_1_9_t; |
| 194 | + } else if (total_M <= 4096) { |
| 195 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t; |
| 196 | + } else { |
| 197 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t; |
| 198 | + } |
| 199 | + } else if (N == 5120 && K == 5120) { |
| 200 | + if (total_M <= 256) { |
| 201 | + return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t; |
| 202 | + } else if (total_M <= 1024) { |
| 203 | + return bf16bf16bf16_grouped_grad_128_128_128_2_2_1_9_t; |
| 204 | + } else { |
| 205 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t; |
| 206 | + } |
| 207 | + } else if (N == 6144 && K == 3072) { |
| 208 | + if (total_M <= 256) { |
| 209 | + return bf16bf16bf16_grouped_grad_128_32_128_2_1_1_9_f; |
| 210 | + } else if (total_M <= 1024) { |
| 211 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_f; |
| 212 | + } else if (total_M <= 4096) { |
| 213 | + return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_t; |
| 214 | + } else { |
| 215 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t; |
| 216 | + } |
| 217 | + } else if (N == 6144 && K == 6144) { |
| 218 | + return bf16bf16bf16_grouped_grad_128_128_128_1_2_1_9_t; |
| 219 | + } |
| 220 | + |
| 221 | + // Fallback to legacy heuristic for now. |
| 222 | + if (total_M <= 16) { |
| 223 | + return bf16bf16bf16_grouped_grad_128_16_128_1_1_1_9_f; |
| 224 | + } else if (total_M <= 32) { |
| 225 | + return bf16bf16bf16_grouped_grad_128_32_128_1_1_1_9_f; |
| 226 | + } else if (total_M <= 64) { |
| 227 | + return bf16bf16bf16_grouped_grad_128_64_128_1_1_1_9_f; |
| 228 | + } else if (total_M <= 128) { |
| 229 | + return bf16bf16bf16_grouped_grad_128_128_128_1_1_1_9_f; |
| 230 | + } else if (total_M <= 512) { |
| 231 | + return bf16bf16bf16_grouped_grad_256_128_128_2_1_1_9_f; |
| 232 | + } else { |
| 233 | + return bf16bf16bf16_grouped_grad_128_256_128_2_1_1_9_f; |
| 234 | + } |
| 235 | + } |
| 236 | +} |
| 237 | + |
| 238 | +Kernel_bf16bf16bf16_grouped_grad get_kernel_via_tuning( |
| 239 | + int arch, |
| 240 | + int G, |
| 241 | + int total_M, |
| 242 | + int N, |
| 243 | + int K, |
| 244 | + at::Tensor X, // BF16 |
| 245 | + at::Tensor W, // BF16 |
| 246 | + at::Tensor output, |
| 247 | + std::optional<at::Tensor> M_sizes = std::nullopt) { |
| 248 | + auto& cache = getTuningCache(); |
| 249 | + |
| 250 | + // Reducing amount of auto tuning by rounding up total_m to next power of 2. |
| 251 | + total_M = nextPowerOf2(total_M); |
| 252 | + // Use (total_M, N, K, G) shape as the key. |
| 253 | + const std::string shape_key = std::to_string(total_M) + "_" + |
| 254 | + std::to_string(N) + "_" + std::to_string(K) + "_" + std::to_string(G); |
| 255 | + const auto& kernels = get_bf16bf16bf16_grouped_grad_kernels(arch); |
| 256 | + auto kernel = cache.findBestKernelMaybeAutotune( |
| 257 | + shape_key, kernels, X, W, output, M_sizes); |
| 258 | + |
| 259 | + return kernel; |
| 260 | +} |
| 261 | + |
| 262 | +// BF16 grouped cutlass kernel dispatch. |
| 263 | +at::Tensor dispatch_bf16_grouped_kernel( |
| 264 | + int G, |
| 265 | + int total_M, |
| 266 | + int N, |
| 267 | + int K, |
| 268 | + at::Tensor X, // BF16 |
| 269 | + at::Tensor W, // BF16 |
| 270 | + at::Tensor output, |
| 271 | + std::optional<at::Tensor> M_sizes = std::nullopt) { |
| 272 | + static int arch = -1; |
| 273 | + // Avoid expensive cudaGetDeviceProperties call. |
| 274 | + if (arch < 0) { |
| 275 | + cudaDeviceProp prop; |
| 276 | + cudaGetDeviceProperties(&prop, 0); |
| 277 | + if (prop.major >= 10) { |
| 278 | + arch = 10; |
| 279 | + int runtimeVersion; |
| 280 | + C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion)); |
| 281 | + TORCH_CHECK( |
| 282 | + runtimeVersion >= 12080, |
| 283 | + "FP8 grouped GEMM on sm100a or above requires cuda >= 12.8"); |
| 284 | + } else { |
| 285 | + arch = 9; |
| 286 | + } |
| 287 | + } |
| 288 | + |
| 289 | + // Select kernel to run via heuristics or tuning. |
| 290 | + auto kernel = [&]() { |
| 291 | + if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) { |
| 292 | + return get_kernel_via_tuning( |
| 293 | + arch, G, total_M, N, K, X, W, output, M_sizes); |
| 294 | + } else { |
| 295 | + return get_kernel_via_heuristic(arch, G, total_M, N, K); |
| 296 | + } |
| 297 | + }(); |
| 298 | + // Invoke kernel |
| 299 | + return kernel(X, W, output, M_sizes); |
| 300 | +} |
| 301 | + |
| 302 | +at::Tensor |
| 303 | +bf16bf16bf16_grouped_grad(at::Tensor X, at::Tensor W, at::Tensor M_sizes) { |
| 304 | + int64_t total_M = X.size(0); |
| 305 | + int64_t N = W.size(1); |
| 306 | + int64_t K = W.size(2); |
| 307 | + int64_t G = M_sizes.size(0); |
| 308 | + TORCH_CHECK( |
| 309 | + M_sizes.device() == X.device(), |
| 310 | + "M_sizes must be on same device as inputs."); |
| 311 | + TORCH_CHECK( |
| 312 | + W.dim() == 3 && W.size(0) == G, "Weights should be shape [G, N, K].") |
| 313 | + |
| 314 | + TORCH_CHECK(X.stride(-1) == 1, "Activation memory layout must be row-major."); |
| 315 | + TORCH_CHECK(W.stride(-2) == 1, "Weight memory layout must be column-major."); |
| 316 | + |
| 317 | + at::Tensor Y = at::empty(total_M * N, X.options().dtype(at::kBFloat16)); |
| 318 | + // Early exit for empty inputs. |
| 319 | + if (total_M == 0) { |
| 320 | + return Y.view({total_M, N}); |
| 321 | + } |
| 322 | + // Return continuous view of output. |
| 323 | + at::Tensor out = |
| 324 | + dispatch_bf16_grouped_kernel(G, total_M, N, K, X, W, Y, M_sizes); |
| 325 | + return out.view({total_M, N}); |
| 326 | +} |
| 327 | + |
| 328 | +#else |
| 329 | + |
| 330 | +at::Tensor bf16bf16bf16_grouped_grad(at::Tensor, at::Tensor, at::Tensor) { |
| 331 | + throw std::runtime_error( |
| 332 | + "CUDA version is older than 12.0"); // requires CUDA>=12 |
| 333 | +} |
| 334 | + |
| 335 | +#endif |
| 336 | + |
| 337 | +} // namespace fbgemm_gpu |
0 commit comments