|
| 1 | +/* |
| 2 | + * Copyright (c) 2025 by FlashInfer team. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | +#include <flashinfer/cutlass_utils.cuh> |
| 17 | + |
| 18 | +#include "pytorch_extension_utils.h" |
| 19 | + |
| 20 | +using namespace flashinfer; |
| 21 | + |
| 22 | +#define DISPATCH_PYTORCH_INPUT_OUTPUT_DTYPE(input_dtype, output_dtype, c_type_in, c_type_out, ...) \ |
| 23 | + [&]() -> bool { \ |
| 24 | + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, c_type_out, [&] { \ |
| 25 | + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(input_dtype, c_type_in, \ |
| 26 | + [&] { return __VA_ARGS__(); }); \ |
| 27 | + }); \ |
| 28 | + }() |
| 29 | + |
| 30 | +#define DISPATCH_SCALE_GRANULARITY(scale_granularity_m, scale_granularity_n, scale_granularity_k, \ |
| 31 | + SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, \ |
| 32 | + ...) \ |
| 33 | + [&]() -> bool { \ |
| 34 | + /* SM120 Cooperative schedule uses 128x128x128 tile shape */ \ |
| 35 | + /* TODO (yongwww): PingPong schedule (64x128x128) will need additional dispatch logic */ \ |
| 36 | + constexpr int SCALE_GRANULARITY_K = 128; \ |
| 37 | + if (scale_granularity_k != 128) { \ |
| 38 | + TORCH_CHECK( \ |
| 39 | + false, \ |
| 40 | + "SM120 requires scale_granularity_k=128. CUTLASS enforces ScaleGranularityK must equal " \ |
| 41 | + "tile shape K dimension (128 for both Cooperative and PingPong schedules)."); \ |
| 42 | + return false; \ |
| 43 | + } \ |
| 44 | + /* Support (1,128,128) and (128,128,128) as per SM100's approach */ \ |
| 45 | + if (scale_granularity_m == 1 && scale_granularity_n == 128) { \ |
| 46 | + constexpr int SCALE_GRANULARITY_M = 1; \ |
| 47 | + constexpr int SCALE_GRANULARITY_N = 128; \ |
| 48 | + return __VA_ARGS__(); \ |
| 49 | + } else if (scale_granularity_m == 128 && scale_granularity_n == 128) { \ |
| 50 | + constexpr int SCALE_GRANULARITY_M = 128; \ |
| 51 | + constexpr int SCALE_GRANULARITY_N = 128; \ |
| 52 | + return __VA_ARGS__(); \ |
| 53 | + } \ |
| 54 | + TORCH_CHECK(false, "SM120: Unsupported scale granularity combination (", scale_granularity_m, \ |
| 55 | + ",", scale_granularity_n, ",", scale_granularity_k, ")"); \ |
| 56 | + return false; \ |
| 57 | + }() |
| 58 | + |
| 59 | +#define DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, ...) \ |
| 60 | + [&]() -> bool { \ |
| 61 | + if (scale_major_mode == "K") { \ |
| 62 | + constexpr bool SCALE_MAJOR_K = true; \ |
| 63 | + return __VA_ARGS__(); \ |
| 64 | + } else if (scale_major_mode == "MN") { \ |
| 65 | + constexpr bool SCALE_MAJOR_K = false; \ |
| 66 | + return __VA_ARGS__(); \ |
| 67 | + } \ |
| 68 | + TORCH_CHECK(false, "Unsupported Scale Major Mode"); \ |
| 69 | + return false; \ |
| 70 | + }() |
| 71 | + |
| 72 | +namespace flashinfer { |
| 73 | +namespace gemm { |
| 74 | + |
| 75 | +template <int ScaleGranularityM, int ScaleGranularityN, int ScaleGranularityK, bool ScaleMajorK, |
| 76 | + typename DTypeIn, typename DTypeOut> |
| 77 | +cudaError_t CutlassGroupwiseScaledGEMMSM120(void* float_buffer, size_t float_buffer_size_in_bytes, |
| 78 | + DTypeIn* A_ptr, DTypeIn* B_ptr, float* SFA_ptr, |
| 79 | + float* SFB_ptr, DTypeOut* C_ptr, int m, int n, int k, |
| 80 | + int l, cudaStream_t stream); |
| 81 | + |
| 82 | +} // namespace gemm |
| 83 | +} // namespace flashinfer |
| 84 | + |
| 85 | +void CutlassGemmGroupwiseScaledSM120(at::Tensor float_workspace_buffer, at::Tensor A, at::Tensor B, |
| 86 | + at::Tensor SFA, at::Tensor SFB, at::Tensor C, |
| 87 | + int64_t scale_granularity_m, int64_t scale_granularity_n, |
| 88 | + int64_t scale_granularity_k, std::string scale_major_mode) { |
| 89 | + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); |
| 90 | + auto stream = at::cuda::getCurrentCUDAStream(); |
| 91 | + |
| 92 | + // Ensure scales are contiguous |
| 93 | + // Note: We keep the original shape and let the kernel's layout handle interpretation |
| 94 | + at::Tensor SFA_contig = SFA.is_contiguous() ? SFA : SFA.contiguous(); |
| 95 | + at::Tensor SFB_contig = SFB.is_contiguous() ? SFB : SFB.contiguous(); |
| 96 | + |
| 97 | + DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, [&] { |
| 98 | + return DISPATCH_PYTORCH_INPUT_OUTPUT_DTYPE( |
| 99 | + A.scalar_type(), C.scalar_type(), c_type_in, c_type_out, [&] { |
| 100 | + return DISPATCH_SCALE_GRANULARITY( |
| 101 | + scale_granularity_m, scale_granularity_n, scale_granularity_k, SCALE_GRANULARITY_M, |
| 102 | + SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, [&] { |
| 103 | + using cutlass_t_in = cutlass_dtype_t<c_type_in>; |
| 104 | + using cutlass_t_out = cutlass_dtype_t<c_type_out>; |
| 105 | + |
| 106 | + // Handle both 2D and 3D tensors (BMM) |
| 107 | + int m, n, k, l; |
| 108 | + if (A.dim() == 2) { |
| 109 | + // 2D case: simple matrix multiplication |
| 110 | + m = A.size(0); |
| 111 | + k = A.size(1); |
| 112 | + n = B.size(0); |
| 113 | + l = 1; // no batch dimension |
| 114 | + } else if (A.dim() == 3) { |
| 115 | + // 3D case: batch matrix multiplication |
| 116 | + l = A.size(0); // batch size |
| 117 | + m = A.size(1); // per-batch m dimension |
| 118 | + k = A.size(2); // per-batch k dimension |
| 119 | + n = B.size(2); // per-batch n dimension (B is [batch, k, n] column-major) |
| 120 | + } else { |
| 121 | + return false; // Unsupported tensor dimension |
| 122 | + } |
| 123 | + |
| 124 | + auto status = flashinfer::gemm::CutlassGroupwiseScaledGEMMSM120< |
| 125 | + SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K>( |
| 126 | + static_cast<void*>(float_workspace_buffer.data_ptr()), |
| 127 | + float_workspace_buffer.element_size() * float_workspace_buffer.numel(), |
| 128 | + static_cast<cutlass_t_in*>(A.data_ptr()), |
| 129 | + static_cast<cutlass_t_in*>(B.data_ptr()), |
| 130 | + static_cast<float*>(SFA_contig.data_ptr()), |
| 131 | + static_cast<float*>(SFB_contig.data_ptr()), |
| 132 | + static_cast<cutlass_t_out*>(C.data_ptr()), m, n, k, l, |
| 133 | + stream); // C is the output (D) |
| 134 | + return status == cudaSuccess; |
| 135 | + }); |
| 136 | + }); |
| 137 | + }); |
| 138 | +} |
0 commit comments