diff --git a/benchmarks/bench_mm_fp8.py b/benchmarks/bench_mm_fp8.py new file mode 100644 index 0000000000..a4df76ebd9 --- /dev/null +++ b/benchmarks/bench_mm_fp8.py @@ -0,0 +1,98 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Dict +from flashinfer.autotuner import autotune +from flashinfer.trtllm_low_latency_gemm import prepare_low_latency_gemm_weights +import numpy as np +import torch + +from flashinfer import mm_fp8 +from flashinfer.testing.utils import bench_gpu_time + +_cache_permute_indices: Dict[torch.Size, torch.Tensor] = {} + + +def to_float8( + x: torch.Tensor, dtype=torch.float8_e4m3fn +) -> tuple[torch.Tensor, torch.Tensor]: + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +def bench_mm_fp8(m, n, k, in_dtype, out_dtype): + torch.manual_seed(123) + input_tensor = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) + input_fp8, input_inv_s = to_float8(input_tensor, dtype=in_dtype) + + # mat2 row major -> column major + mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) + mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=in_dtype) + + res = torch.zeros([m, n], device="cuda", dtype=out_dtype) + global_scale = input_inv_s * mat2_inv_s + + # Do row shuffling. 
+ prepared_weights = prepare_low_latency_gemm_weights( + mat2_fp8, _cache_permute_indices + ) + + with autotune(True): + mm_fp8( + input_fp8, + prepared_weights, + global_scale, + out=res, + ) + + measurements = bench_gpu_time( + lambda: mm_fp8( + input_fp8, + prepared_weights, + global_scale, + out=res, + ), + dry_run_time_ms=500, + repeat_time_ms=2500, + use_cuda_graph=True, + ) + ms = np.median(measurements) + tflops_per_second = 2 * m * n * k * 1e-9 / ms + + bandwidth = ( + ( + input_fp8.numel() * input_fp8.element_size() + + prepared_weights.numel() * prepared_weights.element_size() + + res.numel() * res.element_size() + ) + / ms + / 1e9 + ) + + print( + f"mm_fp8 m={m} n={n} k={k} in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s over {ms:.6f} ms, {bandwidth:.2f} TB/s" + ) + + +if __name__ == "__main__": + for m in [1, 2, 4, 8, 16, 32, 64]: + for n in [2560, 5120, 8192]: + for k in [16384, 32768]: + bench_mm_fp8(m, n, k, torch.float8_e4m3fn, torch.bfloat16) diff --git a/csrc/trtllm_gemm_runner.cu b/csrc/trtllm_gemm_runner.cu index c4ec9d5cff..d4bdc48cef 100644 --- a/csrc/trtllm_gemm_runner.cu +++ b/csrc/trtllm_gemm_runner.cu @@ -43,19 +43,20 @@ struct TrtllmGenGemmRunnerOptions { int64_t select_kernel_fp8(int32_t M, int32_t N, int32_t K, const gemm::gemm::GemmInterface& interface) { static constexpr const char* KERNEL_NAME_HIGH_N_K_RATIO = - "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_cga1x1x1_16dp256b_TN_transOut_" + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_cga1x1x1_16dp256b_rM_TN_" + "transOut_" "noShflA_dsFp8_schedP2x2x1x3_sm100f"; static constexpr const char* KERNEL_NAME_LOW_N_K_RATIO = - "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_TN_" + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_rM_TN_" "transOut_noShflA_dsFp8_schedS_sm100f"; static constexpr const char* KERNEL_NAME_LARGE_N = - "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_TN_" + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_rM_TN_" "transOut_noShflA_dsFp8_schedP2x2x1x3_sm100f"; static constexpr const char* KERNEL_NAME_DEFAULT = - "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_cga1x1x1_16dp256b_TN_" + "gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_cga1x1x1_16dp256b_rM_TN_" "transOut_noShflA_dsFp8_schedS_sm100f"; double const n_k_ratio = static_cast<double>(N) / static_cast<double>(K); diff --git a/csrc/trtllm_low_latency_gemm_runner.cu b/csrc/trtllm_low_latency_gemm_runner.cu new file mode 100644 index 0000000000..d0df47cb65 --- /dev/null +++ b/csrc/trtllm_low_latency_gemm_runner.cu @@ -0,0 +1,324 @@ +/* + * Copyright (c) 2020-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include +#include +#include +#include + +#include + +#include "flashinfer/exception.h" +#include "flashinfer/trtllm/common.h" +#include "flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h" +#include "flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h" +#include "flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/DtypeDecl.h" +#include "flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/SfLayoutDecl.h" + +namespace { +static thread_local gemm::gemm::GemmInterface::ModuleCache globalTrtllmLowLatencyGemmModuleCache; +} // namespace + +namespace flashinfer { + +using tvm::ffi::Array; +using tvm::ffi::Optional; + +struct TrtllmLowLatencyGemmRunnerOptions { + gemm::trtllm::gen::Dtype eltType; + gemm::trtllm::gen::Dtype outputType; +}; + +gemm::gemm::GemmData createGemmData(int64_t m, int64_t n, int64_t k) { + gemm::gemm::GemmData gemmData{}; + + // Dims + gemmData.mProblemDimensions.mM = n; + gemmData.mProblemDimensions.mN = m; + gemmData.mProblemDimensions.mK = k; + gemmData.mProblemDimensions.mRank = 0; + gemmData.mProblemDimensions.mWorldSize = 1; + + return gemmData; +} + +/** + * Very rough heuristic for selecting a kernel. Prefer using auto-tuning. + */ +int64_t select_kernel(int32_t m, int32_t n, int32_t k, const gemm::gemm::GemmInterface& interface) { + static constexpr const char* KERNEL_MMAN_8_TILEK_128_SPLITK_2 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128_s7_et128x8_m128x8x32_cga1x1x2_16dp256b_splitK2_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_8_TILEK_128_SPLITK_3 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128_s7_et128x8_m128x8x32_cga1x1x3_16dp256b_splitK3_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_8_TILEK_256_SPLITK_2 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x256_s4_et128x8_m128x8x32_cga1x1x2_16dp256b_splitK2_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_8_TILEK_256_SPLITK_3 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x256_s4_et128x8_m128x8x32_cga1x1x3_16dp256b_splitK3_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_16_TILEK_128_SPLITK_2 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128_s7_et128x16_m128x16x32_cga1x1x2_16dp256b_splitK2_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_16_TILEK_128_SPLITK_3 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128_s7_et128x16_m128x16x32_cga1x1x3_16dp256b_splitK3_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_16_TILEK_256_SPLITK_2 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x256_s5_et128x16_m128x16x32_cga1x1x2_16dp256b_splitK2_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_16_TILEK_256_SPLITK_3 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x256_s5_et128x16_m128x16x32_cga1x1x3_16dp256b_splitK3_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_32_TILEK_128_SPLITK_2 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128_s9_et128x32_m128x32x32_cga1x1x2_16dp256b_splitK2_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_32_TILEK_128_SPLITK_3 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128_s9_et128x32_m128x32x32_cga1x1x3_16dp256b_splitK3_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_32_TILEK_256_SPLITK_2 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x256_s5_et128x32_m128x32x32_cga1x1x2_16dp256b_splitK2_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_32_TILEK_256_SPLITK_3 = + 
"Gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x256_s5_et128x32_m128x32x32_cga1x1x3_16dp256b_splitK3_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_64_TILEK_128_SPLITK_2 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s7_et128x64_m128x64x32_cga1x1x2_16dp256b_splitK2_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_64_TILEK_128_SPLITK_3 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s7_et128x64_m128x64x32_cga1x1x3_16dp256b_splitK3_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_64_TILEK_256_SPLITK_2 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x256_s3_et128x64_m128x64x32_cga1x1x2_16dp256b_splitK2_BN_" + "transOut_schedS_sm100f"; + static constexpr const char* KERNEL_MMAN_64_TILEK_256_SPLITK_3 = + "Gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x256_s3_et128x64_m128x64x32_cga1x1x3_16dp256b_splitK3_BN_" + "transOut_schedS_sm100f"; + + std::string kernel_name; + if (m <= 8) { + kernel_name = KERNEL_MMAN_8_TILEK_128_SPLITK_2; + } else if (m <= 16) { + kernel_name = KERNEL_MMAN_16_TILEK_128_SPLITK_2; + } else if (m <= 32) { + kernel_name = KERNEL_MMAN_32_TILEK_128_SPLITK_2; + } else { + kernel_name = KERNEL_MMAN_64_TILEK_128_SPLITK_2; + } + + auto const& configs = interface.getGemmConfigs(); + size_t const num_configs = interface.getNumGemmConfigs(); + + for (size_t i = 0; i < num_configs; ++i) { + if (std::string(configs[i].mFunctionName) == kernel_name) { + return static_cast(i); + } + } + + TVM_FFI_LOG_AND_THROW(RuntimeError) + << "No kernel was found heuristically for the given problem size"; +} + +int64_t getWorkspaceSizeInBytes(int64_t m, int64_t n, int64_t k, int64_t tactic) { + auto gemm = gemm::gemm::GemmInterface(); + + if (tactic == -1) { + tactic = select_kernel(m, n, k, gemm); + } + + auto const configs = gemm.getGemmConfigs(); + FLASHINFER_CHECK(tactic >= 0 && tactic < gemm.getNumGemmConfigs(), + "Invalid tactic in getWorkspaceSizeInBytes"); + auto const config = configs[tactic]; + + auto const gemmData = createGemmData(m, n, k); + + return gemm.getWorkspaceSizeInBytes(config, gemmData); +} + +class TrtllmLowLatencyGemmRunner { + public: + explicit TrtllmLowLatencyGemmRunner(TrtllmLowLatencyGemmRunnerOptions const& options) + : mOptions(options) { + // Select a GEMM kernel config to use + auto const gemm = gemm::gemm::GemmInterface(); + auto const configs = gemm.getGemmConfigs(); + + mPassingConfigIndices.clear(); + + for (size_t i = 0; i < gemm.getNumGemmConfigs(); ++i) { + auto const configOptions = configs[i].mOptions; + + if (configOptions.mDtypeA == mOptions.eltType && + configOptions.mDtypeC == mOptions.outputType && + configOptions.mTransposeMmaOutput == true && + configOptions.mLayoutA == gemm::gemm::MatrixLayout::BlockMajorK && + configOptions.mUseShuffledMatrixA) { + mPassingConfigIndices.push_back(i); + } + } + + FLASHINFER_CHECK( + mPassingConfigIndices.size() > 0, + "No valid low latency TRTLLM-GEN GEMM kernel was found for the given data types."); + } + + void run(int64_t m, int64_t n, int64_t k, void const* a, void const* b, void* c, void* cScale, + void* workspace, CUstream stream, int32_t device_index, int64_t tactic) { + auto gemm = gemm::gemm::GemmInterface(); + auto const configs = gemm.getGemmConfigs(); + TVM_FFI_ICHECK(tactic >= 0 && tactic < gemm.getNumGemmConfigs()) << "Invalid tactic id in run"; + auto const& config = configs[tactic]; + + gemm::gemm::GemmData gemmData = createGemmData(m, n, k); + + // Inputs + gemmData.mInputBuffers.mPtrA = b; + gemmData.mInputBuffers.mPtrB = a; + 
gemmData.mInputBuffers.mPtrScaleC = cScale; + + // Outputs + gemmData.mOutputBuffers.mPtrC = c; + + TVM_FFI_ICHECK(gemm.isValidConfig(config, gemmData)) + << "The selected tactic points to a TRTLLM-GEN low latency GEMM kernel that is not valid " + "for " + "the given problem size."; + + int32_t const multiProcessorCount = [device_index]() { + static thread_local int32_t cached_multi_processor_count = -1; + static thread_local int cached_device_index = -1; + + if (device_index == cached_device_index && cached_multi_processor_count != -1) { + return cached_multi_processor_count; + } else { + int32_t count; + cudaError_t cudaStatus = + cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_index); + TVM_FFI_ICHECK(cudaStatus == cudaSuccess) + << "Failed to get device attribute: " << cudaGetErrorString(cudaStatus); + cached_multi_processor_count = count; + cached_device_index = device_index; + return count; + } + }(); + + TVM_FFI_ICHECK(gemm.run(config, workspace, gemmData, static_cast(stream), + multiProcessorCount, true, globalTrtllmLowLatencyGemmModuleCache) == 0) + << "Error occurred when running low latency TRTLLM-GEN GEMM!"; + } + + std::vector getValidTactics(int64_t m, int64_t n, int64_t k) const { + auto const gemm = gemm::gemm::GemmInterface(); + auto const configs = gemm.getGemmConfigs(); + + auto const gemmData = createGemmData(m, n, k); + + std::vector validTactics{}; + for (auto const& configIndex : mPassingConfigIndices) { + auto const& config = configs[configIndex]; + if (gemm.isValidConfig(config, gemmData)) { + validTactics.push_back(configIndex); + } + } + return validTactics; + } + + private: + TrtllmLowLatencyGemmRunnerOptions mOptions; + std::vector mPassingConfigIndices; +}; + +void trtllm_low_latency_gemm(Tensor workspace_buffer, Tensor a, Tensor b, Tensor globalScale, + Tensor out, int64_t tactic) { + CHECK_DEVICE(a, b); + CHECK_DEVICE(a, out); + CHECK_INPUT(a); + CHECK_INPUT(b); + CHECK_INPUT(out); + CHECK_INPUT(workspace_buffer); + CHECK_DIM(2, a); + TVM_FFI_ICHECK(b->ndim == 3) << "b must be a block layout matrix (3D tensor with " + "dims [K/BLOCK_SIZE, N, BLOCK_SIZE])"; + TVM_FFI_ICHECK_EQ(a->dtype, b->dtype); + TVM_FFI_ICHECK(a->dtype == dl_float8_e4m3fn) << "a must be a Float8 tensor"; + + int32_t m = a->shape[0]; + int32_t k = a->shape[1]; + int32_t n = b->shape[1]; + auto const blockSize = b->shape[2]; + auto const kFromB = b->shape[0] * blockSize; + TVM_FFI_ICHECK(kFromB == a->shape[1]) << "Matrix dimensions don't match for multiplication"; + TVM_FFI_ICHECK(out->shape[0] == m && out->shape[1] == n) << "Output tensor has wrong dimensions"; + + if (tactic == -1) { + tactic = select_kernel(m, n, k, gemm::gemm::GemmInterface()); + } + + auto runner = + flashinfer::TrtllmLowLatencyGemmRunner(flashinfer::TrtllmLowLatencyGemmRunnerOptions{ + .eltType = gemm::trtllm::gen::Dtype::E4m3, + .outputType = gemm::trtllm::gen::Dtype::Bfloat16, + }); + + auto stream = get_stream(a->device); + + int64_t const required_workspace_size = getWorkspaceSizeInBytes(m, n, k, tactic); + int64_t const provided_workspace_size = + get_numel(workspace_buffer) * get_element_size(workspace_buffer); + if (provided_workspace_size < required_workspace_size) { + TVM_FFI_LOG_AND_THROW(RuntimeError) + << "The size of the provided workspace to the TRTLLM-GEN low latency GEMM is too small.
" + "Please use the provided workspace sizing function to pre-allocate an adequate " + "workspace."; + } + + runner.run(m, n, k, a->data, b->data, out->data, globalScale->data, workspace_buffer->data, + stream, a->device.device_id, tactic); +} + +enum class Dtype : int64_t { + E2m1 = 0, + E4m3 = 1, + Bfloat16 = 2, +}; + +Array trtllm_low_latency_gemm_tactics(int64_t m, int64_t n, int64_t k, int64_t input_dtype, + int64_t output_dtype) { + TVM_FFI_ICHECK(input_dtype == static_cast(Dtype::E4m3)) << "Unsupported input dtype"; + TVM_FFI_ICHECK_EQ(output_dtype, static_cast(Dtype::Bfloat16)) + << "Unsupported output dtype"; + + auto runner = + flashinfer::TrtllmLowLatencyGemmRunner(flashinfer::TrtllmLowLatencyGemmRunnerOptions{ + .eltType = gemm::trtllm::gen::Dtype::E4m3, + .outputType = gemm::trtllm::gen::Dtype::Bfloat16, + }); + + return runner.getValidTactics(m, n, k); +} + +namespace trtllm_cubin_loader { +#include +} + +} // namespace flashinfer + +// Exposes low latency optimized GEMMs that require some pre-processing of the inputs. +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_low_latency_gemm, flashinfer::trtllm_low_latency_gemm); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_low_latency_gemm_tactics, + flashinfer::trtllm_low_latency_gemm_tactics); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(get_workspace_size_in_bytes, flashinfer::getWorkspaceSizeInBytes); diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index ccfae46ee7..2f30039a47 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -88,6 +88,7 @@ from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper from .gemm import bmm_fp8 as bmm_fp8 from .gemm import mm_fp4 as mm_fp4 +from .gemm import mm_fp8 as mm_fp8 from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100 from .mla import BatchMLAPagedAttentionWrapper as BatchMLAPagedAttentionWrapper from .norm import fused_add_rmsnorm as fused_add_rmsnorm @@ -143,5 +144,8 @@ from .sparse import ( VariableBlockSparseAttentionWrapper as VariableBlockSparseAttentionWrapper, ) +from .trtllm_low_latency_gemm import ( + prepare_low_latency_gemm_weights as prepare_low_latency_gemm_weights, +) from .utils import next_positive_power_of_2 as next_positive_power_of_2 from .xqa import xqa as xqa diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 5062e7c576..69c78ec00c 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -57,6 +57,7 @@ gen_gemm_sm120_module, gen_gemm_sm120_module_cutlass_fp4, gen_trtllm_gen_gemm_module, + gen_trtllm_low_latency_gemm_module, ) from .jit.spdlog import gen_spdlog_module from .jit.mla import gen_mla_module @@ -460,6 +461,7 @@ def gen_all_modules( ) jit_specs.append(gen_mxfp8_quantization_sm100_module()) jit_specs.append(gen_trtllm_gen_gemm_module()) + jit_specs.append(gen_trtllm_low_latency_gemm_module()) jit_specs.append(gen_trtllm_gen_fused_moe_sm100_module()) if has_sm100f: # Add TGV GEMM modules compiled with SM100f flags for both bf16 and fp16 diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 4d344d02bb..523fae8fa6 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -79,7 +79,7 @@ class ArtifactPath: "e6f22dcc3fdeb29ff87af2f4a2cb3d30b8d273e0/batched_gemm-45beda1-ee6a802" ) TRTLLM_GEN_GEMM: str = ( - "037e528e719ec3456a7d7d654f26b805e44c63b1/gemm-8704aa4-f91dc9e" + "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" ) CUDNN_SDPA: str = "4c623163877c8fef5751c9c7a59940cd2baae02e/fmha/cudnn" DEEPGEMM: str = "51d730202c9eef782f06ecc950005331d85c5d4b/deep-gemm" @@ -95,7 +95,7 @@ class MetaInfoHash: ) DEEPGEMM: str = 
"b4374f857c3066089c4ec6b5e79e785559fa2c05ce2623710b0b04bf86414a48" TRTLLM_GEN_GEMM: str = ( - "0345358c916d990709f9670e113e93f35c76aa22715e2d5128ec2ca8740be5ba" + "bd5c3227bec4f8d7a7d3a27fd7628e010d99a5c42651d0a6b97e146803e63340" ) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 364d4182f1..705de6917e 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -192,7 +192,7 @@ def _maybe_get_cached_w3_w1_permute_indices( return permute_indices -def _maybe_get_cached_w2_permute_indices( +def get_w2_permute_indices_with_cache( _cache_permute_indices, dst_w2_weight: torch.Tensor, epilogue_tile_m: int, diff --git a/flashinfer/fused_moe/utils.py b/flashinfer/fused_moe/utils.py index 2bb196858a..17764b8599 100644 --- a/flashinfer/fused_moe/utils.py +++ b/flashinfer/fused_moe/utils.py @@ -205,7 +205,7 @@ def get_power_of_2_num_tokens_buckets(max_num_tokens) -> Tuple[int]: def get_last_power_of_2_num_tokens_buckets( max_num_tokens, min_num_tokens=1 -) -> Tuple[int]: +) -> Tuple[int, ...]: max_num_tokens = last_positive_power_of_2(max_num_tokens) num_token_buckets = [] m = max_num_tokens diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index ffd6d8ce56..de6fd275bf 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -19,6 +19,7 @@ from types import SimpleNamespace from typing import List, Literal, Optional, Tuple +from flashinfer.trtllm_low_latency_gemm import trtllm_low_latency_gemm import torch from .autotuner import ( @@ -34,6 +35,7 @@ last_positive_power_of_2, ) from .utils import ( + get_native_fp4_dtype, is_sm100a_supported, is_sm100f_supported, is_sm120a_supported, @@ -1153,14 +1155,6 @@ def _is_cublas_fp4_available_in_cudnn(): ) -def _get_native_fp4_dtype(): - """get native fp4 datatype if supported in the torch, otherwise return uint8.""" - if hasattr(torch, "float4_e2m1fn_x2"): - return torch.float4_e2m1fn_x2 - else: - return torch.uint8 - - # Global cudnn handle. need to make it per device in future _cudnn_handle = None @@ -1299,8 +1293,8 @@ def execute_cudnn_gemm_fp4_graph( workspace_buffer, ): variant_pack = { - UIDs.A_UID.value: a.view(_get_native_fp4_dtype()), - UIDs.B_UID.value: b.view(_get_native_fp4_dtype()), + UIDs.A_UID.value: a.view(get_native_fp4_dtype()), + UIDs.B_UID.value: b.view(get_native_fp4_dtype()), UIDs.BLOCK_DESCALE_A_UID.value: a_descale, UIDs.BLOCK_DESCALE_B_UID.value: b_descale, UIDs.O_UID.value: c_final, @@ -1540,6 +1534,116 @@ def _expand_block_scale_tensor_shape(block_scale_tensor, batch_size): return (tuple(block_scale_shape), tuple(block_scale_stride)) +def mm_fp8( + a: torch.Tensor, + b: torch.Tensor, + alpha: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + backend: Literal["trtllm_low_latency"] = "trtllm_low_latency", +): + r"""FP8 matrix multiplication. + + Parameters + ---------- + a: torch.Tensor + Input tensor, shape (m, k), fp8 e4m3. + + b: torch.Tensor + - When using "trtllm_low_latency" backend, + Weight tensor, shape (k // block_size, n, block_size), fp8 e4m3 + B needs to be pre-processed using `prepare_low_latency_gemm_weights`. + block_size is 128 for e4m3. + + alpha: Optional[torch.Tensor] + Scale tensor for the output, float. If None, defaults to 1.0 for no scaling. + + out_dtype: torch.dtype + Output tensor data type. Default is torch.bfloat16. + + out: Optional[torch.Tensor] + Output tensor, shape (m, n). If None, a new tensor will be allocated. 
+ + backend: Literal["trtllm_low_latency"] + Backend to use for computation. Default is "trtllm_low_latency". + - "trtllm_low_latency": optimized for small M dimension. + + Returns + ------- + torch.Tensor + Output tensor of shape (m, n) with dtype `out_dtype`. + + Examples + -------- + >>> import torch + >>> from flashinfer import mm_fp8, prepare_low_latency_gemm_weights + >>> m = 16 + >>> n = 2560 + >>> k = 32768 + >>> a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) + >>> a_fp8, a_inv_s = to_float8(a, dtype=torch.float8_e4m3fn) + >>> b = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) + >>> b_fp8, b_inv_s = to_float8(b, dtype=torch.float8_e4m3fn) + >>> _cache_permute_indices = {} + >>> prepared_b = prepare_low_latency_gemm_weights(b_fp8, _cache_permute_indices) + >>> alpha = a_inv_s * b_inv_s + >>> out = mm_fp8(a_fp8, prepared_b, alpha) + >>> out.shape + torch.Size([16, 2560]) + """ + + supported_out_dtypes = (torch.bfloat16,) + supported_backends = ("trtllm_low_latency",) + + if backend == "trtllm_low_latency": + m = a.shape[0] + n = b.shape[1] + else: + raise ValueError( + f"Unsupported backend: {backend}. " + f"Only {supported_backends} are supported for FP8 GEMM operations." + ) + + # allocate the output tensor if not provided + if out is None: + if out_dtype not in supported_out_dtypes: + raise ValueError( + f"Unsupported output dtype: {out_dtype}. " + f"Only {supported_out_dtypes} are supported for FP8 GEMM operations." + ) + out = torch.empty( + (m, n), + device=a.device, + dtype=out_dtype, + ) + else: + if out.dtype not in supported_out_dtypes: + raise ValueError( + f"Unsupported output dtype: {out.dtype}. " + f"Only {supported_out_dtypes} are supported for FP8 GEMM operations." + ) + if out.shape != (a.shape[0], b.shape[1]): + raise ValueError( + f"Output shape mismatch. Expected {a.shape[0], b.shape[1]}, got {out.shape}." + ) + if out.device != a.device: + raise ValueError( + f"Output device mismatch. Expected {a.device}, got {out.device}." + ) + if out_dtype is not None and out.dtype != out_dtype: + raise ValueError( + f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}." + ) + + if backend == "trtllm_low_latency": + trtllm_low_latency_gemm(a, b, alpha, out) + else: + raise ValueError( + f"Unsupported backend: {backend}. " + f"Only {supported_backends} are supported for FP8 GEMM operations." + ) + return out + + + def mm_fp4( a: torch.Tensor, b: torch.Tensor, @@ -1621,9 +1725,9 @@ def mm_fp4( raise ValueError( f"K dimension mismatch in mm_fp4. got a.shape[1] = {a.shape[1]}, b.shape[0] = {b.shape[0]}" ) - if a.dtype not in {torch.uint8, _get_native_fp4_dtype()} or b.dtype not in { + if a.dtype not in {torch.uint8, get_native_fp4_dtype()} or b.dtype not in { torch.uint8, - _get_native_fp4_dtype(), + get_native_fp4_dtype(), }: raise ValueError( f"a and b must have float4_e2m1fn_x2 packed into uint8.
" diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py index e01e5de460..32f6440f7d 100644 --- a/flashinfer/jit/core.py +++ b/flashinfer/jit/core.py @@ -237,12 +237,12 @@ def build(self, verbose: bool, need_lock: bool = True) -> None: self.write_ninja() run_ninja(jit_env.FLASHINFER_JIT_DIR, self.ninja_path, verbose) - def load(self, so_path: Path, class_name: str = None): + def load(self, so_path: Path): return tvm_ffi.load_module(str(so_path)) - def build_and_load(self, class_name: str = None): + def build_and_load(self): if self.is_aot: - return self.load(self.aot_path, class_name) + return self.load(self.aot_path) # Guard both build and load with the same lock to avoid race condition # where another process is building the library and removes the .so file. @@ -250,7 +250,7 @@ def build_and_load(self, class_name: str = None): so_path = self.jit_library_path verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1" self.build(verbose, need_lock=False) - result = self.load(so_path, class_name) + result = self.load(so_path) return result diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py index 1f0e0c1656..f1681d3bf5 100644 --- a/flashinfer/jit/gemm/__init__.py +++ b/flashinfer/jit/gemm/__init__.py @@ -22,6 +22,7 @@ gen_gemm_sm100_module, gen_gemm_sm120_module, gen_trtllm_gen_gemm_module, + gen_trtllm_low_latency_gemm_module, gen_tgv_gemm_sm10x_module, gen_gemm_sm90_module, ) @@ -35,6 +36,7 @@ "gen_gemm_sm100_module", "gen_gemm_sm120_module", "gen_trtllm_gen_gemm_module", + "gen_trtllm_low_latency_gemm_module", "gen_tgv_gemm_sm10x_module", "gen_gemm_sm90_module", "gen_deepgemm_sm100_module", diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 9f55cb0920..a65d1873b5 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -499,3 +499,31 @@ def gen_gemm_sm90_module() -> JitSpec: source_paths, extra_cuda_cflags=sm90a_nvcc_flags, ) + + +def gen_trtllm_low_latency_gemm_module() -> JitSpec: + include_path = f"{ArtifactPath.TRTLLM_GEN_GEMM}/include" + header_name = "flashinferMetaInfo" + + # use `get_cubin` to get "flashinferMetaInfo.h" + metainfo = get_cubin( + f"{include_path}/{header_name}.h", + MetaInfoHash.TRTLLM_GEN_GEMM, + ) + # make sure "flashinferMetaInfo.h" is downloaded or cached + assert metainfo, f"{header_name}.h not found" + return gen_jit_spec( + "trtllm_low_latency_gemm", + [ + jit_env.FLASHINFER_CSRC_DIR / "trtllm_low_latency_gemm_runner.cu", + ], + extra_cuda_cflags=[ + "-DTLLM_GEN_EXPORT_INTERFACE", + "-DTLLM_ENABLE_CUDA", + f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"', + ] + + sm100a_nvcc_flags, + # link "include" sub-directory in cache + extra_include_paths=[jit_env.FLASHINFER_CUBIN_DIR / include_path], + extra_ldflags=["-lcuda"], + ) diff --git a/flashinfer/trtllm_low_latency_gemm.py b/flashinfer/trtllm_low_latency_gemm.py new file mode 100644 index 0000000000..2d69bc1e98 --- /dev/null +++ b/flashinfer/trtllm_low_latency_gemm.py @@ -0,0 +1,224 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from types import SimpleNamespace +from typing import Dict, List + +import functools + +from flashinfer.fused_moe.core import ( + convert_to_block_layout, + get_w2_permute_indices_with_cache, +) +from flashinfer.jit.gemm.core import gen_trtllm_low_latency_gemm_module +import torch + +from flashinfer.autotuner import ( + AutoTuner, + TuningConfig, + DynamicTensorSpec, + ConstraintSpec, + TunableRunner, + OptimizationProfile, +) +from flashinfer.fused_moe.utils import ( + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, +) +from flashinfer.jit import setup_cubin_loader +from flashinfer.utils import _get_cache_buf + + +@functools.cache +def get_trtllm_low_latency_gemm_module(): + mod = gen_trtllm_low_latency_gemm_module() + op = mod.build_and_load() + setup_cubin_loader(str(mod.get_library_path())) + + class TrtllmLowLatencyGemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + a_tensor_index = 0 + b_tensor_index = 1 + + # NOTE : expects A=MxK, B=(K//B)xNxB, out=MxN + a = profile.get_opt_shapes()[a_tensor_index] + b = profile.get_opt_shapes()[b_tensor_index] + m = a[0] + n = b[1] + k = a[1] + ( + a, + b, + global_scale, + out, + ) = inputs + type_e4m3 = 1 + type_bf16 = 2 + valid_tactics = list( + op.trtllm_low_latency_gemm_tactics(m, n, k, type_e4m3, type_bf16) + ) + return valid_tactics + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ) -> torch.Tensor: + ( + a, + b, + global_scale, + out, + ) = inputs + if tactic < 0: + return out + m = a.shape[0] + n = b.shape[1] + k = a.shape[1] + workspace_size = op.get_workspace_size_in_bytes(m, n, k, tactic) + workspace_buffer = _get_cache_buf( + "trllm_low_latency_gemm", workspace_size, a.device + ) + op.trtllm_low_latency_gemm( + workspace_buffer, + a, + b, + global_scale, + out, + tactic, + ) + return out + + def gemm_runner(): + return TrtllmLowLatencyGemmRunner() + + # Register the module + return SimpleNamespace( + gemm_runner=gemm_runner, + ) + + +def trtllm_low_latency_gemm( + A: torch.Tensor, + B: torch.Tensor, + global_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + r"""GEMM optimized for low M dimension. B needs to be shuffled and its layout needs to be adjusted. + Only supported on Blackwell GPUs. + + Parameters + ---------- + A: torch.Tensor + Input tensor, shape (m, k), fp8 e4m3. + + B: torch.Tensor + Mat2 tensor, shape (k // block_size, n, block_size), fp8 e4m3. block_size is 128 for e4m3. + + global_scale: torch.Tensor + Scale tensor for the output, float. + + out: torch.Tensor + Out tensor, shape (m, n), bf16. 
+ + Examples + -------- + >>> import torch + >>> from flashinfer import mm_fp8, prepare_low_latency_gemm_weights + >>> m = 16 + >>> n = 2560 + >>> k = 32768 + >>> a = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) + >>> a_fp8, a_inv_s = to_float8(a, dtype=torch.float8_e4m3fn) + >>> b = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) + >>> b_fp8, b_inv_s = to_float8(b, dtype=torch.float8_e4m3fn) + >>> _cache_permute_indices = {} + >>> prepared_b = prepare_low_latency_gemm_weights(b_fp8, _cache_permute_indices) + >>> prepared_b.shape + torch.Size([256, 2560, 128]) + >>> global_scale = a_inv_s * b_inv_s + >>> out = torch.zeros([m, n], device="cuda", dtype=torch.bfloat16) + >>> mm_fp8(a_fp8, prepared_b, global_scale, out=out) + >>> out.shape + torch.Size([16, 2560]) + """ + + tuner = AutoTuner.get() + a_tensor_index = 0 + out_tensor_index = 3 + tuning_config = TuningConfig( + dynamic_tensor_specs=( + DynamicTensorSpec( + (a_tensor_index,), + (-2,), + get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2, + ), + ), + constraint_specs=( + ConstraintSpec( + out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2] + ), + ), + ) + inputs = [A, B, global_scale, out] + runners: List[TunableRunner] = [] + runners.append(get_trtllm_low_latency_gemm_module().gemm_runner()) + runner, tactic = tuner.choose_one( + "trtllm_low_latency_gemm", + runners, + tuning_config, + inputs, + ) + + runner(inputs=inputs, tactic=tactic) + return out + + +def prepare_low_latency_gemm_weights( + w: torch.Tensor, permutation_indices_cache: Dict[torch.Size, torch.Tensor] +) -> torch.Tensor: + r"""Helper method to prepare the input weight tensor for low-latency TRTLLM GEMM. It includes shuffling and converting to block layout. + + Parameters + ---------- + w: torch.Tensor + The weight tensor to shuffle, shape (n, k), fp8 e4m3. + + permutation_indices_cache: dict + A dict used to cache permutation indices, which are expensive to compute. + + Returns + ------- + block_layout_shuffled_weights: torch.Tensor + The shuffled and block-layout weight tensor, shape (k // 128, n, 128), fp8 e4m3. + """ + + epilogue_tile_m = 128  # NOTE: should be aligned with kernel configuration.
+ + permute_indices = get_w2_permute_indices_with_cache( + permutation_indices_cache, w, epilogue_tile_m + ) + shuffled_weights = w[permute_indices.to(device=w.device)].contiguous() + + block_k = 128 + block_layout_weights = convert_to_block_layout(shuffled_weights, block_k) + return block_layout_weights diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 448f6d116a..28340440c9 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -186,7 +186,7 @@ def get_alibi_slopes(n_heads: int) -> torch.Tensor: def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: key = (name, device) buf = _cache_buf.get(key) - if buf is None: + if buf is None or buf.size(0) < bytes: buf = torch.empty(bytes, dtype=torch.uint8, device=device) _cache_buf[key] = buf return buf @@ -734,3 +734,11 @@ def get_shuffle_matrix_sf_a_row_indices( row_indices = get_shuffle_matrix_a_row_indices(input_tensor, epilogue_tile_m) return row_indices + + +def get_native_fp4_dtype(): + """get native fp4 datatype if supported in Torch, otherwise return uint8.""" + if hasattr(torch, "float4_e2m1fn_x2"): + return torch.float4_e2m1fn_x2 + else: + return torch.uint8 diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h index 8439d82de3..d77b936476 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/Enums.h @@ -97,6 +97,23 @@ enum class TileScheduler { //////////////////////////////////////////////////////////////////////////////////////////////////// +enum class CtaSwizzleType : uint32_t { + // Rasterize CTAs along the M dimension. + RasterizeAlongM = 0, + // Rasterize CTAs along the N dimension. + RasterizeAlongN, + // Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 2. + ZigZagAlongM2, + // Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 2. + ZigZagAlongN2, + // Swizzle CTAs in zig-zag pattern along M dimension, Zig-zag width is 4. + ZigZagAlongM4, + // Swizzle CTAs in zig-zag pattern along N dimension, Zig-zag width is 4. + ZigZagAlongN4, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // Helper functions to check the SplitK type. #define SPLIT_K_FUNCTION(Mode) \ diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h index 187b8e936c..ffea1cf4f4 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmInterface.h @@ -285,6 +285,12 @@ class GemmInterface { template inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const; + // Returns the number of tiles and number of CTAs for Z dimension. + std::tuple getGridSize(int32_t M, int32_t N, int32_t tileM, + int32_t tileN, int32_t clusterDimX, + int32_t clusterDimY, + int32_t numSlicesForSplitK) const; + // Creates GemmOptions from kernel and data. 
GemmOptions getOptionsFromConfigAndData(GemmConfig const& config, GemmData const& data) const; @@ -328,6 +334,20 @@ size_t GemmInterface::getNumGemmConfigs() const { //////////////////////////////////////////////////////////////////////////////////////////////////// +std::tuple GemmInterface::getGridSize(int32_t M, int32_t N, + int32_t tileM, int32_t tileN, + int32_t clusterDimX, + int32_t clusterDimY, + int32_t numSlicesForSplitK) const { + // The number of tiles in the M dimension. + auto numTilesM = gemm::divUpMul(gemm::divUp(M, tileM), clusterDimX); + // The number of tiles in the N dimension. + auto numTilesN = gemm::divUpMul(gemm::divUp(N, tileN), clusterDimY); + return std::make_tuple(numTilesM, numTilesN, numSlicesForSplitK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + GemmOptions GemmInterface::getOptionsFromConfigAndData(GemmConfig const& config, GemmData const& data) const { // Create options from config and data. @@ -363,10 +383,10 @@ std::vector GemmInterface::getWorkspaceSizesInBytes(GemmConfig const& co // Get options from config. auto& options = config.mOptions; - // The number of tiles in the M dimension. - int32_t numTilesM = gemm::divUp(data.mProblemDimensions.mM, options.mTileM); - // The number of tiles in the N dimension. - int32_t numTilesN = gemm::divUp(data.mProblemDimensions.mN, options.mTileN); + // Get the number of tiles and cluster dimension Z. + auto [numTilesM, numTilesN, gridDimZ] = getGridSize( + data.mProblemDimensions.mM, data.mProblemDimensions.mN, options.mTileM, options.mTileN, + options.mClusterDimX, options.mClusterDimY, options.mNumSlicesForSplitK); std::vector workspaceSizes; @@ -439,10 +459,10 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c } } - // The number of tiles in the M dimension. - int numTilesM = gemm::divUp(options.mM, options.mTileM); - // The number of tiles in the N dimension. - int numTilesN = gemm::divUp(options.mN, options.mTileN); + // Get the number of tiles and number of CTAs for Z dimension. + auto [numTilesM, numTilesN, gridDimZ] = + getGridSize(options.mM, options.mN, options.mTileM, options.mTileN, options.mClusterDimX, + options.mClusterDimY, options.mNumSlicesForSplitK); // Create kernel params. auto kernelParams = gemm::KernelParamsSetup::setKernelParams( @@ -455,9 +475,8 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c data.mAllReduceBuffers.mPtrMultiMemCompletionBars, dPtrSplitKCompletionBars, /* dPtrNumNonExitingCtas */ nullptr, data.mProblemDimensions.mRank, data.mProblemDimensions.mWorldSize); - // The size of the grid. - std::vector grid{numTilesM, numTilesN, options.mNumSlicesForSplitK}; + std::vector grid{numTilesM, numTilesN, gridDimZ}; // When split-k is enabled and to guarantee the forward progress, we must ensure that the number // of tiles is less than number of SMs. This way, at least one CTA in the grid can make forward. @@ -482,6 +501,7 @@ int32_t GemmInterface::run(GemmConfig const& config, void* workspace, GemmData c std::string cubin = flashinfer::trtllm_cubin_loader::getCubin(fname_cubin, sha256); cuModuleLoadData(&cuModule, cubin.c_str()); }; + if (moduleCache.has_value()) { ModuleCache& moduleCacheRef = moduleCache.value().get(); @@ -564,10 +584,11 @@ int32_t GemmInterface::runInitBeforeWorldSync(GemmConfig const& config, GemmData return 1; } } - // The number of tiles in the M dimension. 
- int numTilesM = gemm::divUp(options.mM, options.mTileM); - // The number of tiles in the N dimension. - int numTilesN = gemm::divUp(options.mN, options.mTileN); + + // Get the number of tiles and number of CTAs for Z dimension. + auto [numTilesM, numTilesN, gridDimZ] = + getGridSize(options.mM, options.mN, options.mTileM, options.mTileN, options.mClusterDimX, + options.mClusterDimY, options.mNumSlicesForSplitK); // The number of bytes for the tile barriers. int32_t numBytesTileBars = numTilesM * numTilesN * sizeof(uint32_t); // Sanitize system barriers. diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h index 5bfb9f40e4..cfb9e348cc 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/GemmOptions.h @@ -92,16 +92,18 @@ struct GemmOptions { GemmOptions() = default; GemmOptions(AllReduceAlgo allReduceAlgo, BiasType biasType, int blockK, int clusterDimX, - int clusterDimY, int clusterDimZ, tg::Dtype dtypeAcc, tg::Dtype dtypeA, - tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, tg::Dtype dtypeMmaB, - bool enablesEarlyExit, bool enablesDelayedEarlyExit, bool enablesGlobalPtxKnobs, - int epilogueLdtmDps, int epilogueLdtmBits, int epilogueTileM, int epilogueTileN, - bool gridTriggerSecondaryA, bool gridTriggerSecondaryB, - bool gridWaitForPrimaryEarlyExit, bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, - bool hoistLoadTaskInit, bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, - MatrixLayout layoutA, MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, - int mmaM, int mmaN, bool mockAllReduce, int n, int numSlicesForSplitK, - int numSlicesForSliceK, int numStages, int numStagesMma, + int clusterDimY, int clusterDimZ, CtaSwizzleType ctaSwizzleType, tg::Dtype dtypeAcc, + tg::Dtype dtypeA, tg::Dtype dtypeB, tg::Dtype dtypeC, tg::Dtype dtypeMmaA, + tg::Dtype dtypeMmaB, bool enablesEarlyExit, bool enablesDelayedEarlyExit, + bool enablesGlobalPtxKnobs, int epilogueLdtmDps, int epilogueLdtmBits, + int epilogueTileM, int epilogueTileN, bool gridTriggerSecondaryA, + bool gridTriggerSecondaryB, bool gridWaitForPrimaryEarlyExit, + bool gridWaitForPrimaryA, bool gridWaitForPrimaryB, bool hoistLoadTaskInit, + bool hoistMmaTaskTryWaits, int k, KernelTraits kernelTraits, MatrixLayout layoutA, + MatrixLayout layoutB, int m, int mmaK, tg::MmaKind mmaKind, int mmaM, int mmaN, + bool mockAllReduce, int n, int numRegsCastAWarps, int numRegsCopySfLdsSttm, + int numRegsPerThreadEpilogueWarp, int numRegsPerThreadNonEpilogueWarp, + int numSlicesForSplitK, int numSlicesForSliceK, int numStages, int numStagesMma, int numStagesMmaWithinWorkTile, int numStagesMmaAcrossWorkTile, int numStagesWorkId, bool outputDebugTensors, bool patchF2fp, std::optional sfBlockSizeA, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC, @@ -117,6 +119,7 @@ struct GemmOptions { mClusterDimX{clusterDimX}, mClusterDimY{clusterDimY}, mClusterDimZ{clusterDimZ}, + mCtaSwizzleType{ctaSwizzleType}, mDtypeAcc{dtypeAcc}, mDtypeA{dtypeA}, mDtypeB{dtypeB}, @@ -148,6 +151,10 @@ struct GemmOptions { mMmaN{mmaN}, mMockAllReduce{mockAllReduce}, mN{n}, + mNumRegsCastAWarps(numRegsCastAWarps), + mNumRegsCopySfLdsSttm(numRegsCopySfLdsSttm), + mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp), + mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp), mNumSlicesForSplitK{numSlicesForSplitK}, 
mNumSlicesForSliceK{numSlicesForSliceK}, mNumStages{numStages}, @@ -193,6 +200,8 @@ struct GemmOptions { int mClusterDimY{1}; // Cluster size in Z dim. int mClusterDimZ{1}; + // The type of CTA swizzle. + CtaSwizzleType mCtaSwizzleType{CtaSwizzleType::RasterizeAlongM}; // Data type of the accumulators. tg::Dtype mDtypeAcc{tg::Dtype::Fp32}; // Data type of the A matrix. @@ -263,6 +272,14 @@ struct GemmOptions { bool mMockAllReduce{false}; // The N dimension of GEMM. int mN{64 * 4}; + // Number of registers for the cast A warps. + int mNumRegsCastAWarps{0}; + // Number of registers for the LDS+STTM warps. + int mNumRegsCopySfLdsSttm{0}; + // Number of registers per thread for epilogue warps + int mNumRegsPerThreadEpilogueWarp{0}; + // Number of registers per thread for non-epilogue warps + int mNumRegsPerThreadNonEpilogueWarp{0}; // Number of partitions along the K dimension. When mNumSlicesForSplitK > 1, // the problem is distributed across several SMs, where each CTA works on its local K slice. // Partial results are accumulated afterwards using either GMEM or DSMEM (in CGA) @@ -369,6 +386,7 @@ struct GemmConfig { char const* mHash{nullptr}; #else trtllm::gen::CudaRunner* mCudaRunner{nullptr}; + int32_t mInstanceIdx{0}; #endif GemmOptions mOptions{}; @@ -409,6 +427,8 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mClusterDimX=" << options.mClusterDimX << "," << std::endl; ss << "mClusterDimY=" << options.mClusterDimY << "," << std::endl; ss << "mClusterDimZ=" << options.mClusterDimZ << "," << std::endl; + ss << "mCtaSwizzleType=" << "gemm::CtaSwizzleType(" + << static_cast(options.mCtaSwizzleType) << ")" << "," << std::endl; ss << "mDtypeAcc=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeAcc) << ")" << "," << std::endl; ss << "mDtypeA=" << "trtllm::gen::Dtype(" << static_cast(options.mDtypeA) << ")" << "," @@ -449,6 +469,12 @@ inline std::string dumpOptions(GemmOptions const& options) { ss << "mMmaN=" << options.mMmaN << "," << std::endl; ss << "mMockAllReduce=" << options.mMockAllReduce << "," << std::endl; ss << "mN=" << options.mN << "," << std::endl; + ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl; + ss << "mNumRegsCopySfLdsSttm=" << options.mNumRegsCopySfLdsSttm << "," << std::endl; + ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," + << std::endl; + ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," + << std::endl; ss << "mNumSlicesForSplitK=" << options.mNumSlicesForSplitK << "," << std::endl; ss << "mNumSlicesForSliceK=" << options.mNumSlicesForSliceK << "," << std::endl; ss << "mNumStages=" << options.mNumStages << "," << std::endl; @@ -673,18 +699,27 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in if ((options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4 || options.mDtypeC == tg::Dtype::MxE4m3) && options.mMmaM != 128) { - // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. - int newTileM = 128 * divUp(options.mTileM, 128); - TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, - ") for MmaKind=", gemm::toString(options.mMmaKind), - ". Setting MmaM to 128 and TileM to ", newTileM); - if (updateOptions) { - options.mMmaM = 128; - options.mTileM = newTileM; + if (options.mClusterDimX == 1) { + // MMA M must be 128 when the input uses block scaling, or when the output is an Mx format. 
+ int newTileM = 128 * divUp(options.mTileM, 128); + TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, + ") for MmaKind=", gemm::toString(options.mMmaKind), + ". Setting MmaM to 128 and TileM to ", newTileM); + if (updateOptions) { + options.mMmaM = 128; + options.mTileM = newTileM; + } else { + return false; + } } else { - return false; + TLLM_CHECK_ERROR(options.mMmaM == 256 && options.mTileM == 128, + "2CTA UTCxMMA only supports mmaM = 256 and tileM = 128."); } } + if (options.mClusterDimX > 1) { + TLLM_CHECK_ERROR(options.mLayoutB != MatrixLayout::BlockMajorK, + "layoutB == MatrixLayout::BlockMajorK is not supported for now"); + } if (options.mMmaKind == tg::MmaKind::MxFp4NvFp4 || options.mMmaKind == tg::MmaKind::MxFp8Fp6Fp4) { TLLM_CHECK_ERROR(isBlackwell, "Block scaling is only supported on Blackwell"); @@ -869,14 +904,26 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } if (!options.mSliceK) { - TLLM_CHECK_ERROR(options.mMmaM <= options.mEpilogueTileM, + TLLM_CHECK_ERROR(options.mMmaM / options.mClusterDimX <= options.mEpilogueTileM, "EpilogueTileM must be larger or equal than mmaM."); + } else { + // FIXME: this is not necessary limitation. Simply fixing num repeats in TmemSliceKA should be + // enough. + TLLM_CHECK_ERROR((options.mTileN & (options.mTileN - 1)) == 0, + "For Slice-K TileN is required to be a power of 2"); + } + + if (options.mClusterDimX == 2) { + TLLM_CHECK_ERROR(options.mMmaM == 256, "Only mmaM = 256 is supported for 2CTA UTCMMA."); + TLLM_CHECK_ERROR(options.mMmaN % 16 == 0, "mmaN needs to be multiple of 16 for 2CTA UTCMMA."); } + TLLM_CHECK_ERROR( options.mTileM % options.mEpilogueTileM == 0 && options.mTileN % options.mEpilogueTileN == 0, "TileM and TileN must be divisible by EpilogueTileM and EpilogueTileN respectively."); - TLLM_CHECK_ERROR(options.mClusterDimX == 1 && options.mClusterDimY == 1, - "GEMM does not support cluster in X and Y dimensions."); + TLLM_CHECK_ERROR( + (options.mClusterDimX == 1 || options.mClusterDimX == 2) && options.mClusterDimY == 1, + "GEMM does not support cluster in X and Y dimensions."); TLLM_CHECK_ERROR(options.mClusterDimZ == 1 || options.mNumSlicesForSplitK > 1, "Cluster DimZ is only allowed for split-k."); TLLM_CHECK_ERROR(options.mTileM <= 128, "GEMM does not support TileM > 128."); @@ -1003,6 +1050,9 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in "Non-DeepSeekFp8 requires persistent scheduler when using numStagesMma >1"); } } + if (options.mUseDeepSeekFp8) { + TLLM_CHECK_ERROR(options.mClusterDimX == 1, "2CTA Gemm is not supported for DeepSeekFp8"); + } if (options.mUseDeepSeekFp8) { TLLM_CHECK_ERROR(options.mDtypeA == tg::Dtype::E4m3 && options.mDtypeB == tg::Dtype::E4m3, "A and B dtype must be E4m3 for DeepSeek Fp8. Found dtypeA=", @@ -1085,22 +1135,34 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in ")"); } + // Number of iterations in K dimension after padding. + // Note the perCtaK in each CTA in the splitK group are padded to the same number of iterations. + // E.g., K = 512, TileK = 128, numSlicesForSplitK = 3. Then the padded K is + // + // ceil(512 / (128*3)) * (128*3) = 768 + // + int const paddedK = divUpMul(options.mK, options.mTileK * options.mNumSlicesForSplitK); + int const perCtaK = paddedK / options.mNumSlicesForSplitK; + // However, number of iterations is clamped to multiples of tileK within individual CTAs + // E.g., K = 448, TileK = 64, numSlicesForSplitK = 4. 
+ // + // paddedK = 512 + // perCtaK = 128 + // clampedPerCtaK for CTA 0, 1, 2 = 128 + // clampedPerCtaK for CTA 3 = 64 + int const paddingForK = paddedK - options.mK; + int const clampedAndPaddedPerCtaK = divUpMul(perCtaK - paddingForK, options.mTileK); if (options.mUseUnrollLoop2xForMma) { - // Number of iterations in K dimension after padding. - // Note the perCtaK in each CTA in the splitK group are padded to the same number of iterations. - // E.g., K = 512, TileK = 128, numSlicesForSplitK = 3. Then the padded K is - // - // ceil(512 / (128*3)) * (128*3) = 768 - // - int paddedK = divUpMul(options.mK, options.mTileK * options.mNumSlicesForSplitK); - // Check that the padded K (K rounded to next multiple of tileK) is a multiple of 2*TileK when - // UnrollLoop2x is enabled. This is to avoid deadlock when mma runs even-numbered loop while the - // other warps run odd-numbered loop. + // Check that the padded K and clamped padded K (K rounded to next multiple of tileK) is a + // multiple of 2*TileK when UnrollLoop2x is enabled. This is to avoid deadlock when mma runs + // even-numbered loop while the other warps run odd-numbered loop. // - bool notSupported = (paddedK / options.mNumSlicesForSplitK) % (options.mTileK * 2) != 0; + bool notSupported = (perCtaK % (options.mTileK * 2) != 0) || + (clampedAndPaddedPerCtaK % (options.mTileK * 2) != 0); if (notSupported) { TLLM_LOG_WARNING("Size K / splitK must be a multiple of TileK * 2. Found TileK=", options.mTileK, " and K=", options.mK, " (paddedK=", paddedK, + " clampedAndPaddedPerCtaK=", clampedAndPaddedPerCtaK, ") and numSlicesForSplitK=", options.mNumSlicesForSplitK, ". Disabling unrollLoop2xForMma."); if (updateOptions) { @@ -1110,6 +1172,11 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in } } } + if (options.mNumSlicesForSplitK > 1) { + TLLM_CHECK_ERROR( + perCtaK * (options.mNumSlicesForSplitK - 1) < options.mK, + "K must be greater than perCtaK * (numSlicesForSplitK - 1) to ensure each CTA has work"); + } if (!isBlackwell && options.mTileScheduler == TileScheduler::Persistent) { // TODO(anchengc): will be supported in upcoming MRs. @@ -1242,7 +1309,8 @@ inline bool checkAndUpdateGemmOptions(GemmOptions& options, bool isBlackwell, in options.mNumStagesMma, options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, options.mSplitK, options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, options.mTileScheduler == TileScheduler::Persistent, options.mUseDeepSeekFp8, - options.mUsePerTokenSfA, options.mUsePerTokenSfB, options.mBiasType); + options.mUsePerTokenSfA, options.mUsePerTokenSfB, + /* useTwoCtas*/ options.mClusterDimX == 2, options.mBiasType); } return true; diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h index 7e0fdfaa3c..64d065cd21 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelParams.h @@ -58,6 +58,10 @@ static auto makeTmaShapeStrideAb(GemmOptions const& options, MatrixType matrixTy // Assemble the box shape std::vector tileShape = {options.mTileK, tileMn}; + // When using 2CTA MMA, we only need to load half of the tile in each CTA for B. + if (matrixType == MatrixType::MatrixB && options.mClusterDimX == 2) { + tileShape[1] /= 2; + } MatrixLayout layout = (matrixType == MatrixType::MatrixA) ? 
options.mLayoutA : options.mLayoutB; if (layout == MatrixLayout::MajorMn) { @@ -66,6 +70,7 @@ static auto makeTmaShapeStrideAb(GemmOptions const& options, MatrixType matrixTy stride[1] = numTokens; std::swap(tileShape[0], tileShape[1]); } else if (layout == MatrixLayout::BlockMajorK) { + // FIXME: fix for the 2CTA MMA case // Set shapes based on blocking layout shape = {static_cast(options.mBlockK), static_cast(numTokens), static_cast(options.mK / options.mBlockK)}; diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h index f413b4371f..4ca6af8a4c 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/KernelTraits.h @@ -163,7 +163,7 @@ class KernelTraits { int32_t numSlicesForSplitK, int32_t numSlicesForSliceK, SplitK splitK, bool useTmaStore, bool transposeMmaOutput, AllReduceAlgo allReduceAlgo, bool usePersistentScheduler, bool useDeepSeekFp8, bool usePerTokenSfA, - bool usePerTokenSfB, BiasType biasType) + bool usePerTokenSfB, bool useTwoCtas, BiasType biasType) : mMmaKind{mmaKind} { // // SMEM @@ -213,8 +213,8 @@ class KernelTraits { // LoadB { // Number of bytes in load B shared memory. - auto const numSmemBytesLoadB = - numStages * tileN * tileK * getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; + auto const numSmemBytesLoadB = numStages * (useTwoCtas ? tileN / 2 : tileN) * tileK * + getNumSmemBitsPerElt(dtypeB, mMmaKind) / 8 /* bits */; // Number of bytes for load B alignment for TMA load. auto const numBytesAlignmentLoadB = 1024; // No need to reuse the first chunk. diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h index cd42f1c6a8..0722a42a4e 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/TmaDescriptor.h @@ -156,7 +156,7 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + ss << "Error: Failed to initialize the TMA descriptor. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; @@ -251,7 +251,7 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector c char const* errorString; cuGetErrorString(result, &errorString); std::stringstream ss; - ss << "Error: Failed to initialize the TMA descriptor for SF " << errorString << std::endl; + ss << "Error: Failed to initialize the TMA descriptor for SF. " << errorString << std::endl; ss << "tmaFormat: " << static_cast(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl; diff --git a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h index 27c7ab7193..331f6cd285 100644 --- a/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h +++ b/include/flashinfer/trtllm/gemm/trtllmGen_gemm_export/trtllm/gen/CommonUtils.h @@ -23,6 +23,21 @@ namespace gen { //////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMA OOB optimization constants. 
+// +// CUDA Programming Guide states that "globalDim must be non-zero and less than or equal to 2^32". +// In practice, the kernel acts funny with TMA shape of 2^32 so we use 2^31. +constexpr unsigned long TmaDimMax = 1UL << 31; +// Chosen so that LargeN * XLargeN * sizeof(dtype) >= 2^64 which causes overflow and effectively +// becomes 0. As sizeof(dtype) can be as small as 0.5B, we choose LargeN = 2^30 and XLargeN = 2^35 +// so overflow can happen. +constexpr unsigned long LargeN = 1UL << 30; +// Used in TMA stride. Should be less than 2^40. +constexpr unsigned long XLargeN = 1UL << 35; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template inline T ceilDiv(T m, T n) { return (m + n - T(1)) / n; diff --git a/tests/gemm/test_bmm_fp8.py b/tests/gemm/test_bmm_fp8.py index 630e4004c0..f750edddc2 100644 --- a/tests/gemm/test_bmm_fp8.py +++ b/tests/gemm/test_bmm_fp8.py @@ -4,15 +4,7 @@ from flashinfer import autotune, bmm_fp8 from flashinfer.utils import get_compute_capability - - -def to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) - scale = finfo.max / amax - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype), scale.float().reciprocal() +from tests.utils_fp8 import to_float8 @pytest.mark.parametrize("b", [1, 16]) diff --git a/tests/gemm/test_mm_fp8.py b/tests/gemm/test_mm_fp8.py new file mode 100644 index 0000000000..53b6a0f676 --- /dev/null +++ b/tests/gemm/test_mm_fp8.py @@ -0,0 +1,59 @@ +from typing import Dict +from flashinfer.utils import get_compute_capability +import pytest +import torch +import torch.nn.functional as F + +from flashinfer import autotune, mm_fp8 +from tests.utils_fp8 import to_float8 +from flashinfer import prepare_low_latency_gemm_weights + +_cache_permute_indices: Dict[torch.Size, torch.Tensor] = {} + + +@pytest.mark.parametrize("m", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("n", [2560, 5120]) +@pytest.mark.parametrize("k", [8192, 16384, 32768]) +@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16]) +def test_mm_fp8( + m: int, + n: int, + k: int, + input_dtype: torch.dtype, + mat2_dtype: torch.dtype, + res_dtype: torch.dtype, +): + compute_capability = get_compute_capability(torch.device(device="cuda")) + if compute_capability[0] not in [10]: + pytest.skip("mm_fp8 is only supported on Blackwell GPUs.") + + torch.manual_seed(123) + input = torch.randn([m, k], device="cuda", dtype=torch.bfloat16) + input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) + + mat2 = torch.randn([n, k], device="cuda", dtype=torch.bfloat16) + mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) + + res = torch.zeros([m, n], device="cuda", dtype=res_dtype) + global_scale = input_inv_s * mat2_inv_s + + prepared_weights = prepare_low_latency_gemm_weights( + mat2_fp8, _cache_permute_indices + ) + with autotune(): + mm_fp8( + input_fp8, + prepared_weights, + global_scale, + out=res, + ) + + reference = torch.mm(input, mat2.transpose(-2, -1)) + cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) + assert cos_sim > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 880c739259..c0a4ecfcec 100644 --- 
a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -43,7 +43,7 @@ trtllm_fp8_per_tensor_scale_moe, ) from flashinfer.fused_moe.core import ( - _maybe_get_cached_w2_permute_indices, + get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, ) from flashinfer.utils import calculate_tile_tokens_dim, get_compute_capability @@ -468,7 +468,7 @@ def prepare_static_weights_for_kernel( ) ) - permute_indices = _maybe_get_cached_w2_permute_indices( + permute_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, gemm2_weights_fp4[i].view(torch.uint8), epilogue_tile_m, @@ -479,7 +479,7 @@ def prepare_static_weights_for_kernel( .contiguous() ) - permute_sf_indices = _maybe_get_cached_w2_permute_indices( + permute_sf_indices = get_w2_permute_indices_with_cache( self._cache_permute_indices, gemm2_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m, diff --git a/tests/utils_fp8.py b/tests/utils_fp8.py new file mode 100644 index 0000000000..3d6afa8453 --- /dev/null +++ b/tests/utils_fp8.py @@ -0,0 +1,12 @@ +import torch + + +def to_float8( + x: torch.Tensor, dtype=torch.float8_e4m3fn +) -> tuple[torch.Tensor, torch.Tensor]: + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal()
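
Usage sketch: a minimal end-to-end example of the mm_fp8 low-latency path added in this diff, assuming a Blackwell (SM100) GPU; it follows tests/gemm/test_mm_fp8.py and borrows the to_float8 helper from tests/utils_fp8.py.

    import torch
    from flashinfer import autotune, mm_fp8, prepare_low_latency_gemm_weights
    from tests.utils_fp8 import to_float8

    _cache_permute_indices = {}  # permutation indices are expensive to compute, so cache them

    m, n, k = 16, 2560, 32768
    a_fp8, a_inv_s = to_float8(torch.randn([m, k], device="cuda", dtype=torch.bfloat16))
    b_fp8, b_inv_s = to_float8(torch.randn([n, k], device="cuda", dtype=torch.bfloat16))

    # Row-shuffle B and convert it to the (k // 128, n, 128) block layout the kernel expects.
    prepared_b = prepare_low_latency_gemm_weights(b_fp8, _cache_permute_indices)

    out = torch.empty([m, n], device="cuda", dtype=torch.bfloat16)
    with autotune(True):  # let the autotuner pick a tactic instead of the rough heuristic
        mm_fp8(a_fp8, prepared_b, a_inv_s * b_inv_s, out=out)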