diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh new file mode 100644 index 0000000000..3b7ed113e7 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/compiler.cuh @@ -0,0 +1,543 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "jit_utils.cuh" +#include "nvrtc.h" +#include "runtime.cuh" +#include "scheduler.cuh" + +#ifdef _WIN32 +#include +#endif + +namespace deep_gemm::jit { + +// Generate a unique ID for temporary directories to avoid collisions +std::string generateUniqueId() { + // Use current time and random number to generate a unique ID + static std::mt19937 gen(std::random_device{}()); + static std::uniform_int_distribution<> distrib(0, 999999); + + auto now = std::chrono::system_clock::now(); + auto now_ms = std::chrono::time_point_cast(now); + auto value = now_ms.time_since_epoch().count(); + + // Use the static random generator + int random_value = distrib(gen); + + return std::to_string(value) + "_" + std::to_string(random_value); +} + +std::filesystem::path getDefaultUserDir() { + static std::filesystem::path userDir; + if (userDir.empty()) { + char const* cacheDir = getenv("TRTLLM_DG_CACHE_DIR"); + if (cacheDir) { + userDir = cacheDir; + std::filesystem::create_directories(userDir); + } else { +#ifdef _WIN32 + char const* appData = getenv("APPDATA"); + if (appData) { + userDir = std::filesystem::path(appData) / "tensorrt_llm"; + } else { + userDir = std::filesystem::temp_directory_path() / "tensorrt_llm"; + } +#else + char const* homeDir = getenv("HOME"); + if (homeDir) { + userDir = std::filesystem::path(homeDir) / ".tensorrt_llm"; + } else { + userDir = std::filesystem::temp_directory_path() / "tensorrt_llm"; + } +#endif + } + } + return userDir; +} + +inline std::filesystem::path getTmpDir() { return getDefaultUserDir() / "tmp"; } + +inline std::filesystem::path getCacheDir() { return getDefaultUserDir() / "cache"; } + +std::string getNvccCompiler() { + static std::string compiler; + if (compiler.empty()) { + // Check environment variable + char const* envCompiler = getenv("TRTLLM_DG_NVCC_COMPILER"); + if (envCompiler) { + compiler = envCompiler; + } else { + // Check CUDA_HOME + char const* cudaHome = getenv("CUDA_HOME"); + if (cudaHome) { + std::filesystem::path cudaPath(cudaHome); +#ifdef _WIN32 + compiler = (cudaPath / "bin" / "nvcc.exe").string(); +#else + compiler = (cudaPath / "bin" / "nvcc").string(); +#endif + } else { +// Default to system nvcc +#ifdef _WIN32 + compiler = "nvcc.exe"; +#else + compiler = "nvcc"; +#endif + } + } + } + return compiler; +} + +std::vector getJitIncludeDirs() { + static std::vector includeDirs; + if (includeDirs.empty()) { + // 
Command to execute + char const* cmd = "pip show tensorrt_llm 2>/dev/null"; + + // Buffer to store the output + std::array buffer; + std::string result; + +// Open pipe to command +#ifdef _MSC_VER + FILE* pipe = _popen(cmd, "r"); +#else + FILE* pipe = popen(cmd, "r"); +#endif + + if (pipe) { + // Read the output + while (fgets(buffer.data(), buffer.size(), pipe) != nullptr) { + result += buffer.data(); + } + +// Close the pipe +#ifdef _MSC_VER + _pclose(pipe); +#else + pclose(pipe); +#endif + + // Parse the location using regex + // `pip show tensorrt_llm` will output something like: + // Location: /usr/local/lib/python3.12/dist-packages + // Editable project location: /code + std::regex locationRegex("(Location|Editable project location): (.+)"); + + // Find all matches + auto match_begin = std::sregex_iterator(result.begin(), result.end(), locationRegex); + auto match_end = std::sregex_iterator(); + + // Get the number of matches + auto match_count = std::distance(match_begin, match_end); + + if (match_count > 0) { + // Get the last match + auto last_match_iter = match_begin; + std::advance(last_match_iter, match_count - 1); + + // Get the path from the second capture group + std::string location = last_match_iter->str(2); + location.erase(location.find_last_not_of(" \n\r\t") + 1); + + // Set the include directory based on the package location + includeDirs.push_back(std::filesystem::path(location) / "tensorrt_llm" / "include"); + + if (!kJitUseNvcc) { + includeDirs.push_back(std::filesystem::path(location) / "tensorrt_llm" / "include" / + "cuda" / "include"); + } + } + } else { + TLLM_LOG_WARNING("Failed to find TensorRT LLM installation, DeepGEMM will be disabled."); + } + } + return includeDirs; +} + +std::string generateKernel(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m, + uint32_t const block_n, uint32_t const block_k, + uint32_t const num_groups, uint32_t const num_stages, + uint32_t const num_tma_multicast, deep_gemm::GemmType const gemm_type, + bool swapAB = false) { + constexpr uint32_t kNumTMAThreads = 128; + constexpr uint32_t kNumMathThreadsPerGroup = 128; + + std::string input_type; + if (!swapAB) { + switch (gemm_type) { + case deep_gemm::GemmType::Normal: + input_type = "NormalSchedulerInput"; + break; + case deep_gemm::GemmType::GroupedContiguous: + input_type = "GroupedContiguousSchedulerInput"; + break; + case deep_gemm::GemmType::GroupedMasked: + input_type = "GroupedMaskedSchedulerInput"; + break; + case deep_gemm::GemmType::GroupedWithOffset: + input_type = "GroupedWithOffsetSchedulerInput"; + break; + case deep_gemm::GemmType::StridedBatched: + input_type = "StridedBatchedSchedulerInput"; + break; + default: + throw std::runtime_error("Unsupported gemm type"); + } + } else { + switch (gemm_type) { + case deep_gemm::GemmType::Normal: + input_type = "NormalSchedulerInputSwapAB"; + break; + case deep_gemm::GemmType::GroupedWithOffset: + input_type = "GroupedWithOffsetSchedulerInputSwapAB"; + break; + default: + throw std::runtime_error("Unsupported gemm type"); + } + } + + // Modify kernel name based on swapAB to determine which kernel function to use + std::string kernel_name = swapAB ? "fp8_gemm_kernel_swapAB" : "fp8_gemm_kernel"; + std::string scheduler_name = swapAB ? 
"SchedulerSelectorSwapAB" : "SchedulerSelector"; + + // Create the kernel source code using raw string literal + std::string code = R"( +#ifdef __CUDACC_RTC__ +#ifndef NVRTC_JIT_COMPILATION +#define NVRTC_JIT_COMPILATION +#endif + +#include + +#else + +#include +#include + +#endif + +#include +#include +#include +#include + +using namespace deep_gemm; + +using SchedulerType = +typename )" + scheduler_name + + R"(::type; + +__global__ void dummy_kernel() { + void *ptr = (void *)&)" + + kernel_name + R"(<)" + std::to_string(shape_n) + R"(, )" + + std::to_string(shape_k) + R"(, )" + std::to_string(block_m) + R"(, )" + + std::to_string(block_n) + R"(, )" + std::to_string(block_k) + R"(, )" + + std::to_string(num_groups) + R"(, )" + std::to_string(num_stages) + R"(, )" + + std::to_string(kNumTMAThreads) + R"(, )" + + std::to_string(kNumMathThreadsPerGroup) + R"(, )" + + std::to_string(num_tma_multicast) + R"(, SchedulerType, )" + input_type + R"(>; +} +)"; + + return code; +} + +/** + * C++ implementation of the Compiler class + * Compiles CUDA code into CUBINs + */ +class Compiler { + public: + // Get singleton instance + static Compiler& getInstance() { + static Compiler instance; + return instance; + } + + [[nodiscard]] bool isValid() const { return !includeDirs_.empty(); } + + // Build function + Runtime* build(uint32_t const shape_n, uint32_t const shape_k, uint32_t const block_m, + uint32_t const block_n, uint32_t const block_k, uint32_t const num_groups, + uint32_t const num_stages, uint32_t const num_tma_multicast, + deep_gemm::GemmType const gemm_type, bool swapAB = false) { + int sm_version = tensorrt_llm::common::getSMVersion(); + if (sm_version != 90) { + TLLM_THROW( + "DeepGEMM only supports Hopper (SM90) architectures, but current device compute " + "capability is %d.", + sm_version); + } + + // Build signature - simplified, no MD5 calculation + std::string name = std::string(swapAB ? 
"gemm_swapAB_" : "gemm_") + std::to_string(shape_n) + + "_" + std::to_string(shape_k) + "_" + std::to_string(block_m) + "_" + + std::to_string(block_n) + "_" + std::to_string(block_k) + "_" + + std::to_string(num_groups) + "_" + std::to_string(num_stages) + + std::to_string(num_groups) + "_" + std::to_string(num_stages) + "_" + + std::to_string(num_tma_multicast) + "_" + gemm_type_to_string(gemm_type); + std::filesystem::path path = getCacheDir() / name; + + // Check runtime cache or file system hit + auto& runtimeCache = getGlobalRuntimeCache(); + Runtime* cachedRuntime = runtimeCache[path.string()]; + if (cachedRuntime != nullptr) { + if (kJitDebugging) { + TLLM_LOG_INFO("Using cached JIT runtime %s during build", name.c_str()); + } + return cachedRuntime; + } + + // Compiler flags + std::vector flags = {"-std=c++17", + "--gpu-architecture=sm_90a", + "--ptxas-options=-allow-expensive-optimizations=true", + "--ptxas-options=--register-usage-level=10", + "--diag-suppress=161,174,177,940", + "-D__FORCE_INCLUDE_CUDA_FP16_HPP_FROM_FP16_H__=1", + "-D__FORCE_INCLUDE_CUDA_BF16_HPP_FROM_BF16_H__=1"}; + + if (kJitUseNvcc) { + flags.push_back("-O3"); + flags.push_back("-cubin"); + flags.push_back("--expt-relaxed-constexpr"); + flags.push_back("--expt-extended-lambda"); + + std::vector cxxFlags = {"-fPIC", "-O3", "-Wno-deprecated-declarations", + "-Wno-abi"}; + std::string cxxFlagsStr = "--compiler-options="; + for (size_t i = 0; i < cxxFlags.size(); ++i) { + cxxFlagsStr += cxxFlags[i]; + if (i < cxxFlags.size() - 1) { + cxxFlagsStr += ","; + } + } + flags.push_back(cxxFlagsStr); + } else { + flags.push_back("-default-device"); + } + + std::filesystem::path tmpPath = getTmpDir() / (name + "_" + generateUniqueId()); + std::filesystem::path cubinPath = path / kKernelName; + std::filesystem::path tmpCubinPath = tmpPath / kKernelName; + + // Create the target directory if it doesn't exist + if (kJitUseNvcc || kJitDumpCubin) { + std::filesystem::create_directories(tmpPath); + std::filesystem::create_directories(path); + } + + for (auto const& dir : includeDirs_) { + flags.push_back("-I" + dir.string()); + } + + // Print options if debug enabled + if (kJitDebugging) { + TLLM_LOG_INFO("Compiling JIT runtime %s with options: ", name.c_str()); + for (auto const& flag : flags) { + TLLM_LOG_INFO("%s ", flag.c_str()); + } + TLLM_LOG_INFO("\n"); + } + + std::string code = generateKernel(shape_n, shape_k, block_m, block_n, block_k, num_groups, + num_stages, num_tma_multicast, gemm_type, swapAB); + + if (kJitDebugging) { + TLLM_LOG_INFO("Generated kernel code:\n%s", code.c_str()); + } + + if (kJitUseNvcc) { + std::filesystem::path tmpSrcPath = tmpPath / "kernel.cu"; + + // Write files + std::ofstream srcFile(tmpSrcPath); + srcFile << code; + srcFile.close(); + + // Build command + std::vector command = {getNvccCompiler(), tmpSrcPath.string(), "-o", + tmpCubinPath.string()}; + command.insert(command.end(), flags.begin(), flags.end()); + + // Execute command + std::string cmd; + for (auto const& arg : command) { + cmd += arg + " "; + } + + // Buffer to store the output + std::array buffer; + std::string result; + + // Time the compilation + auto start = std::chrono::high_resolution_clock::now(); + + // Open pipe to command +#ifdef _MSC_VER + FILE* pipe = _popen(cmd.c_str(), "r"); +#else + FILE* pipe = popen(cmd.c_str(), "r"); +#endif + + if (pipe) { + // Read the output + while (fgets(buffer.data(), buffer.size(), pipe) != nullptr) { + result += buffer.data(); + } + +// Close the pipe +#ifdef _MSC_VER + 
_pclose(pipe); +#else + pclose(pipe); +#endif + + // Output result if debug enabled + if (kJitDebugging) { + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end - start); + TLLM_LOG_INFO("NVCC compilation took %d ms", duration.count()); + TLLM_LOG_INFO("Compilation log:\n%s", result.c_str()); + } + } + } else { + nvrtcProgram prog; + CHECK_NVRTC(nvrtcCreateProgram(&prog, code.c_str(), "kernel.cu", 0, nullptr, nullptr)); + + std::vector options; + for (auto const& flag : flags) { + options.push_back(flag.c_str()); + } + + // Time the compilation + auto start = std::chrono::high_resolution_clock::now(); + nvrtcResult compileResult = nvrtcCompileProgram(prog, options.size(), options.data()); + + if (kJitDebugging) { + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end - start); + TLLM_LOG_INFO("NVRTC compilation took %d ms", duration.count()); + + size_t logSize; + CHECK_NVRTC(nvrtcGetProgramLogSize(prog, &logSize)); + std::vector log(logSize); + CHECK_NVRTC(nvrtcGetProgramLog(prog, log.data())); + TLLM_LOG_INFO("Compilation log:\n%s", log.data()); + } + + // Check if compilation succeeded + if (compileResult != NVRTC_SUCCESS) { + TLLM_LOG_ERROR("NVRTC compilation failed"); + CHECK_NVRTC(nvrtcDestroyProgram(&prog)); + throw std::runtime_error("NVRTC compilation failed"); + } + + // Save CUBIN to a file + size_t cubinSize; + CHECK_NVRTC(nvrtcGetCUBINSize(prog, &cubinSize)); + std::vector cubin(cubinSize); + CHECK_NVRTC(nvrtcGetCUBIN(prog, cubin.data())); + + // Cache the runtime in memory by default + if (!kJitDumpCubin) { + auto runtime = std::make_unique(path.string(), cubin, gemm_type); + Runtime* result = runtime.get(); + runtimeCache.set(path.string(), std::move(runtime)); + if (kJitDebugging) { + TLLM_LOG_INFO("Successfully cached JIT runtime %s in memory", name.c_str()); + } + return result; + } + + std::ofstream cubinFile(tmpCubinPath.string(), std::ios::binary); + cubinFile.write(cubin.data(), static_cast(cubinSize)); + cubinFile.close(); + CHECK_NVRTC(nvrtcDestroyProgram(&prog)); + } + + // Copy the source and compiled files to the cache directory + try { + // Rename (atomic operation) to final locations + std::filesystem::rename(tmpCubinPath, cubinPath); + if (kJitDebugging) { + TLLM_LOG_INFO("Successfully copied kernel files to cache directory: %s", + path.string().c_str()); + } + } catch (std::exception const& e) { + TLLM_LOG_ERROR("Warning: Failed to copy kernel files to cache: %s", e.what()); + } + + // Clean up temporary directory after successful compilation + try { + std::filesystem::remove_all(tmpPath); + } catch (std::exception const& e) { + TLLM_LOG_ERROR("Warning: Failed to clean up temporary directory: %s", e.what()); + } + + // Create runtime and cache it + auto runtime = std::make_unique(path.string(), std::vector(), gemm_type); + Runtime* result = runtime.get(); + runtimeCache.set(path.string(), std::move(runtime)); + return result; + } + + private: + std::vector includeDirs_; + + // Private constructor for singleton pattern + Compiler() : includeDirs_(getJitIncludeDirs()) { + // Create necessary directories + if (kJitUseNvcc || kJitDumpCubin) { + std::filesystem::create_directories(getTmpDir()); + std::filesystem::create_directories(getCacheDir()); + } + } + + // Delete copy constructor and assignment operator + Compiler(Compiler const&) = delete; + Compiler& operator=(Compiler const&) = delete; +}; + +// Global function to access the Compiler singleton +inline 
Compiler& getGlobalCompiler() { return Compiler::getInstance(); } + +} // namespace deep_gemm::jit diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/fp8_gemm.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/fp8_gemm.cuh new file mode 100644 index 0000000000..d68386242f --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/fp8_gemm.cuh @@ -0,0 +1,414 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 DeepSeek + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/MIT + * + * + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" +#pragma once + +#include +#include + +#include +#include +#include + +#include "compiler.cuh" +#include "fp8_gemm_impl.cuh" +#include "mma_utils.cuh" +#include "scheduler.cuh" +#include "tma_utils.cuh" +#include "utils.cuh" + +namespace deep_gemm { +template +static CUtensorMap make_2d_tma_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k, + uint32_t block_m, uint32_t block_k, uint32_t num_groups, + GemmType gemm_type, uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, Layout::RowMajor, + shape_m * (gemm_type == GemmType::GroupedMasked ? num_groups : 1), + shape_k, block_m, block_k, global_stride_in_bytes); +} + +template +CUtensorMap make_2d_tma_b_desc(T* global_address, uint32_t shape_n, uint32_t shape_k, + uint32_t block_n, uint32_t block_k, uint32_t num_groups, + GemmType gemm_type, uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k, + shape_n * (gemm_type != GemmType::Normal ? num_groups : 1), block_k, + block_n, global_stride_in_bytes); +} + +template +CUtensorMap make_2d_tma_d_desc(T* global_address, uint32_t shape_m, uint32_t shape_n, + uint32_t block_m, uint32_t block_n, uint32_t num_groups, + GemmType gemm_type, uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, Layout::RowMajor, + shape_m * (gemm_type == GemmType::GroupedMasked ? num_groups : 1), + shape_n, min(block_m, shape_m), min(block_n, shape_n), + global_stride_in_bytes, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); +} + +template +CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint32_t shape_m, uint32_t shape_k, + uint32_t block_m, uint32_t block_k, uint32_t num_groups, + GemmType gemm_type, uint64_t global_stride_in_bytes = 0) { + // Make TMA aligned to 16 bytes + constexpr uint32_t kAlignment = 16 / sizeof(T); + shape_m = ceil_div(shape_m, kAlignment) * kAlignment; + + return make_2d_tma_desc( + global_address, Layout::ColMajor, shape_m, + ceil_div(shape_k, block_k) * + ((gemm_type == GemmType::GroupedMasked || gemm_type == GemmType::StridedBatched) + ? 
num_groups + : 1), + block_m, 1, global_stride_in_bytes, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); +} + +template +CUtensorMap make_tma_scales_a_offset_desc(T* global_address, int64_t max_m_padded_total, + uint32_t shape_k, uint32_t block_m, uint32_t block_k, + uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, Layout::ColMajor, max_m_padded_total, + ceil_div(shape_k, block_k), block_m, 1, global_stride_in_bytes, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); +} + +template +CUtensorMap make_2d_tma_a_desc_swapAB(T* global_address, uint32_t shape_m, uint32_t shape_k, + uint32_t block_m, uint32_t block_k, uint32_t num_groups, + GemmType gemm_type, uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, Layout::RowMajor, + shape_m * (gemm_type != GemmType::Normal ? num_groups : 1), shape_k, + block_m, block_k, global_stride_in_bytes); +} + +template +CUtensorMap make_2d_tma_b_desc_swapAB(T* global_address, uint32_t shape_n, uint32_t shape_k, + uint32_t block_n, uint32_t block_k, uint32_t num_groups, + GemmType gemm_type, uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, Layout::ColMajor, shape_k, + shape_n * (gemm_type == GemmType::GroupedMasked ? num_groups : 1), + block_k, block_n, global_stride_in_bytes); +} + +template +CUtensorMap make_2d_tma_d_desc_swapAB(T* global_address, uint32_t shape_m, uint32_t shape_n, + uint32_t block_m, uint32_t block_n, uint32_t num_groups, + GemmType gemm_type, uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, Layout::RowMajor, + shape_n * (gemm_type == GemmType::GroupedMasked ? num_groups : 1), + shape_m, min(block_n, shape_n), min(block_m, shape_m), + global_stride_in_bytes, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); +} + +template +CUtensorMap make_2d_tma_scales_b_desc_swapAB(T* global_address, uint32_t shape_n, uint32_t shape_k, + uint32_t block_n, uint32_t block_k, + uint32_t num_groups, GemmType gemm_type, + uint64_t global_stride_in_bytes = 0) { + // Make TMA aligned to 16 bytes + constexpr uint32_t kAlignment = 16 / sizeof(T); + shape_n = ceil_div(shape_n, kAlignment) * kAlignment; + + return make_2d_tma_desc( + global_address, Layout::RowMajor, + ceil_div(shape_k, block_k) * + ((gemm_type == GemmType::GroupedMasked || gemm_type == GemmType::StridedBatched) + ? num_groups + : 1), + shape_n, 1, block_n, global_stride_in_bytes, CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); +} + +template +CUtensorMap make_tma_scales_b_offset_desc_swapAB(T* global_address, int64_t max_n_padded_total, + uint32_t shape_k, uint32_t block_n, + uint32_t block_k, + uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, Layout::RowMajor, ceil_div(shape_k, block_k), + max_n_padded_total, 1, block_n, global_stride_in_bytes, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); +} + +template +CUtensorMap make_2d_tma_desc( + T* global_address, Layout layout, uint32_t gmem_rows, uint32_t gmem_cols, uint32_t smem_rows, + uint32_t smem_cols, uint64_t global_stride_in_bytes, + CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B) { + if (layout == Layout::RowMajor) { + uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; + uint32_t smem_dim[2] = {smem_cols, smem_rows}; + global_stride_in_bytes = + global_stride_in_bytes ? 
global_stride_in_bytes : gmem_cols * sizeof(T); + return make_2d_tma_copy_desc(global_address, gmem_dim, global_stride_in_bytes, smem_dim, + swizzle_type); + } else { + uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; + uint32_t smem_dim[2] = {smem_rows, smem_cols}; + global_stride_in_bytes = + global_stride_in_bytes ? global_stride_in_bytes : gmem_rows * sizeof(T); + return make_2d_tma_copy_desc(global_address, gmem_dim, gmem_rows * sizeof(T), smem_dim, + swizzle_type); + } +} + +template +void runGemm(cudaKernel_t kernel, void* mat_a, int ld_a, void* mat_b, int ld_b, void* mat_d, + int ld_d, float* scales_a, float* scales_b, uint32_t shape_m, uint32_t shape_n, + uint32_t shape_k, uint32_t block_m, uint32_t block_n, uint32_t block_k, + uint32_t num_groups, uint32_t num_tma_multicast, GemmType gemm_type, + LayoutIndexType* grouped_layout, cudaStream_t stream, int num_sms, + uint32_t smem_size) { + auto tma_a_desc = make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(mat_a), shape_m, shape_k, + block_m, block_k, num_groups, gemm_type, ld_a); + auto tma_b_desc = make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(mat_b), shape_n, shape_k, + block_n, block_k, num_groups, gemm_type, ld_b); + auto tma_scales_a_desc = make_2d_tma_scales_a_desc(scales_a, shape_m, shape_k, block_m, block_k, + num_groups, gemm_type); + auto tma_d_desc = make_2d_tma_d_desc(reinterpret_cast<__nv_bfloat16*>(mat_d), shape_m, shape_n, + block_m, block_n, num_groups, gemm_type, ld_d * 2); + + constexpr uint32_t kNumTMAThreads = 128; + constexpr uint32_t kNumMathThreadsPerGroup = 128; + DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size) == cudaSuccess); + + // Cluster launch + cudaLaunchConfig_t config; + config.gridDim = num_sms; + config.blockDim = get_num_threads_per_sm( + static_cast(block_m)); + config.dynamicSmemBytes = smem_size; + config.stream = stream; + + // Clusters for TMA multicast + // NOTES: `>= 4` cluster size will cause performance degradation + cudaLaunchAttribute attr; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {num_tma_multicast, 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + + NormalSchedulerInput input; + input.shape_m = shape_m; + input.grouped_layout = grouped_layout; + + // Launch + auto status = + cudaLaunchKernelEx(&config, kernel, reinterpret_cast<__nv_bfloat16*>(mat_d), scales_b, input, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc); + DG_HOST_ASSERT(status == cudaSuccess); +} + +template +void runGemmSwapAB(cudaKernel_t kernel, void* mat_a, int ld_a, void* mat_b, int ld_b, void* mat_d, + int ld_d, float* scales_a, float* scales_b, uint32_t shape_m, uint32_t shape_n, + uint32_t shape_k, uint32_t block_m, uint32_t block_n, uint32_t block_k, + uint32_t num_groups, uint32_t num_tma_multicast, GemmType gemm_type, + LayoutIndexType* grouped_layout, cudaStream_t stream, int num_sms, + uint32_t smem_size) { + auto tma_a_desc = + make_2d_tma_a_desc_swapAB(reinterpret_cast<__nv_fp8_e4m3*>(mat_a), shape_m, shape_k, block_m, + block_k, num_groups, gemm_type, ld_a); + auto tma_b_desc = + make_2d_tma_b_desc_swapAB(reinterpret_cast<__nv_fp8_e4m3*>(mat_b), shape_n, shape_k, block_n, + block_k, num_groups, gemm_type, ld_b); + auto tma_scales_b_desc = make_2d_tma_scales_b_desc_swapAB(scales_b, shape_n, shape_k, block_n, + block_k, num_groups, gemm_type); + auto tma_d_desc = + make_2d_tma_d_desc_swapAB(reinterpret_cast<__nv_bfloat16*>(mat_d), shape_m, shape_n, block_m, + block_n, num_groups, gemm_type, ld_d * 2); 
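+
+  // The launch setup below mirrors the non-swapped runGemm: raise the kernel's dynamic
+  // shared-memory limit to `smem_size` via cudaFuncSetAttribute, then launch one thread block
+  // per SM, grouped into clusters of `num_tma_multicast` blocks so TMA loads can be multicast
+  // across the cluster. The block size comes from get_num_threads_per_sm (with
+  // kNumTMAThreads == kNumMathThreadsPerGroup == 128 this works out to 256 threads when
+  // block_m == 64 and 384 otherwise). The swapAB-specific differences are that the scheduler
+  // input is keyed on shape_n instead of shape_m, and the B-side scales go through
+  // `tma_scales_b_desc` while `scales_a` is passed to the kernel as a raw pointer.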
+ + constexpr uint32_t kNumTMAThreads = 128; + constexpr uint32_t kNumMathThreadsPerGroup = 128; + DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size) == cudaSuccess); + + // Cluster launch + cudaLaunchConfig_t config; + config.gridDim = num_sms; + config.blockDim = get_num_threads_per_sm( + static_cast(block_m)); + config.dynamicSmemBytes = smem_size; + config.stream = stream; + + // Clusters for TMA multicast + cudaLaunchAttribute attr; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {num_tma_multicast, 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + + NormalSchedulerInputSwapAB input; + input.shape_n = shape_n; + input.grouped_layout = grouped_layout; + + auto status = + cudaLaunchKernelEx(&config, kernel, reinterpret_cast<__nv_bfloat16*>(mat_d), scales_a, input, + tma_a_desc, tma_b_desc, tma_scales_b_desc, tma_d_desc); + DG_HOST_ASSERT(status == cudaSuccess); +} + +template +void runGemm(cudaKernel_t kernel, void* mat_a, int ld_a, void* mat_b, int ld_b, void* mat_d, + int ld_d, float* scales_a, float* scales_b, uint32_t shape_m, uint32_t shape_n, + uint32_t shape_k, uint32_t block_m, uint32_t block_n, uint32_t block_k, + uint32_t num_groups, uint32_t num_tma_multicast, GemmType gemm_type, + LayoutIndexType* problem_m_offsets, cudaStream_t stream, int num_sms, + uint32_t smem_size, uint32_t max_shape_m_padded) { + auto tma_a_desc = make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(mat_a), shape_m, shape_k, + block_m, block_k, num_groups, gemm_type); + auto tma_b_desc = make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(mat_b), shape_n, shape_k, + block_n, block_k, num_groups, gemm_type); + auto tma_scales_a_desc = + make_tma_scales_a_offset_desc(scales_a, max_shape_m_padded, shape_k, block_m, block_k); + auto tma_d_desc = make_2d_tma_d_desc(reinterpret_cast<__nv_bfloat16*>(mat_d), shape_m, shape_n, + block_m, block_n, num_groups, gemm_type); + constexpr uint32_t kNumTMAThreads = 128; + constexpr uint32_t kNumMathThreadsPerGroup = 128; + DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size) == cudaSuccess); + + // Cluster launch + cudaLaunchConfig_t config; + config.gridDim = num_sms; + config.blockDim = get_num_threads_per_sm( + static_cast(block_m)); + config.dynamicSmemBytes = smem_size; + config.stream = stream; + + // Clusters for TMA multicast + // NOTES: `>= 4` cluster size will cause performance degradation + cudaLaunchAttribute attr; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {num_tma_multicast, 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + + GroupedWithOffsetSchedulerInput input; + input.problem_m_offsets = problem_m_offsets; + + // Launch + auto status = + cudaLaunchKernelEx(&config, kernel, reinterpret_cast<__nv_bfloat16*>(mat_d), scales_b, input, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc); + DG_HOST_ASSERT(status == cudaSuccess); +} + +template +void runGemmSwapAB(cudaKernel_t kernel, void* mat_a /* weight*/, int ld_a, void* mat_b /* act*/, + int ld_b, void* mat_d, int ld_d, float* scales_a /* weight scales*/, + float* scales_b /* act scales*/, uint32_t shape_m, uint32_t shape_n, + uint32_t shape_k, uint32_t block_m, uint32_t block_n, uint32_t block_k, + uint32_t num_groups, uint32_t num_tma_multicast, GemmType gemm_type, + LayoutIndexType* problem_n_offsets, cudaStream_t stream, int num_sms, + uint32_t smem_size, uint32_t max_shape_n_padded) { + // Create tensor mappings using swapAB 
version functions, note the parameter order + auto tma_a_desc = make_2d_tma_a_desc_swapAB(reinterpret_cast<__nv_fp8_e4m3*>(mat_a), shape_m, + shape_k, block_m, block_k, num_groups, gemm_type); + auto tma_b_desc = make_2d_tma_b_desc_swapAB(reinterpret_cast<__nv_fp8_e4m3*>(mat_b), shape_n, + shape_k, block_n, block_k, num_groups, gemm_type); + auto tma_scales_b_desc = + make_tma_scales_b_offset_desc_swapAB(scales_b, max_shape_n_padded, shape_k, block_n, block_k); + auto tma_d_desc = make_2d_tma_d_desc_swapAB(reinterpret_cast<__nv_bfloat16*>(mat_d), shape_m, + shape_n, block_m, block_n, num_groups, gemm_type); + constexpr uint32_t kNumTMAThreads = 128; + constexpr uint32_t kNumMathThreadsPerGroup = 128; + DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size) == cudaSuccess); + + // Cluster launch + cudaLaunchConfig_t config; + config.gridDim = num_sms; + config.blockDim = get_num_threads_per_sm( + static_cast(block_m)); + config.dynamicSmemBytes = smem_size; + config.stream = stream; + + // Clusters for TMA multicast + cudaLaunchAttribute attr; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {num_tma_multicast, 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + + // Update input structure to use N dimension offsets + GroupedWithOffsetSchedulerInputSwapAB input; + input.problem_n_offsets = problem_n_offsets; // Now offsets are for N dimension + + auto status = + cudaLaunchKernelEx(&config, kernel, reinterpret_cast<__nv_bfloat16*>(mat_d), scales_a, input, + tma_a_desc, tma_b_desc, tma_scales_b_desc, tma_d_desc); + DG_HOST_ASSERT(status == cudaSuccess); +} + +void runGemm(cudaKernel_t kernel, void* mat_a, uint64_t ld_a, uint64_t stride_a, void* mat_b, + uint64_t ld_b, uint64_t stride_b, void* mat_d, uint64_t ld_d, uint64_t stride_d, + float* scales_a, float* scales_b, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + uint32_t block_m, uint32_t block_n, uint32_t block_k, uint32_t num_groups, + uint32_t num_tma_multicast, GemmType gemm_type, cudaStream_t stream, int num_sms, + uint32_t smem_size) { + auto tma_a_desc = + make_2d_tma_a_desc(reinterpret_cast<__nv_fp8_e4m3*>(mat_a), shape_m * num_groups, shape_k, + block_m, block_k, num_groups, gemm_type, ld_a); + auto tma_b_desc = make_2d_tma_b_desc(reinterpret_cast<__nv_fp8_e4m3*>(mat_b), shape_n, shape_k, + block_n, block_k, num_groups, gemm_type, ld_b); + auto tma_scales_a_desc = make_2d_tma_scales_a_desc(scales_a, shape_m, shape_k, block_m, block_k, + num_groups, gemm_type); + auto tma_d_desc = make_2d_tma_d_desc(reinterpret_cast<__nv_bfloat16*>(mat_d), shape_m, shape_n, + block_m, block_n, num_groups, gemm_type, ld_d * 2); + constexpr uint32_t kNumTMAThreads = 128; + constexpr uint32_t kNumMathThreadsPerGroup = 128; + DG_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size) == cudaSuccess); + + // Cluster launch + cudaLaunchConfig_t config; + config.gridDim = num_sms; + config.blockDim = get_num_threads_per_sm( + static_cast(block_m)); + config.dynamicSmemBytes = smem_size; + config.stream = stream; + + // Clusters for TMA multicast + // NOTES: `>= 4` cluster size will cause performance degradation + cudaLaunchAttribute attr; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {num_tma_multicast, 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + + StridedBatchedSchedulerInput input{shape_m, ld_a, stride_a, ld_b, stride_b, ld_d, stride_d}; + // Launch + auto status = + 
cudaLaunchKernelEx(&config, kernel, reinterpret_cast<__nv_bfloat16*>(mat_d), scales_b, input, + tma_a_desc, tma_b_desc, tma_scales_a_desc, tma_d_desc); + DG_HOST_ASSERT(status == cudaSuccess); +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/fp8_gemm_impl.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/fp8_gemm_impl.cuh new file mode 100644 index 0000000000..ae03a82216 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/fp8_gemm_impl.cuh @@ -0,0 +1,823 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "mma_utils.cuh" +#include "scheduler.cuh" +#include "tma_utils.cuh" +#include "utils.cuh" + +namespace deep_gemm { + +enum class Layout { RowMajor, ColMajor }; + +template +__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) { + DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group"); + return (block_m == 64 ? 1 : 2) * kNumMathThreadsPerGroup + kNumTMAThreads; +} + +template +static __device__ __forceinline__ void write_result_to_gmem( + __nv_bfloat16* gmem_d_this_block, __nv_bfloat16 const* smem_d, uint32_t const m_offset, + uint32_t const m_boundary, uint32_t const n_offset, uint32_t const shape_n, + uint32_t const ld_output) { + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + constexpr int int4_per_tile_line = BLOCK_N * sizeof(__nv_bfloat16) / sizeof(int4); + int int4_per_global_line = shape_n * sizeof(__nv_bfloat16) / sizeof(int4); + constexpr auto num_lines = BLOCK_M; + constexpr auto num_warps = NUM_WARPS_PER_BLOCK; + int4 const* smem_d_int4 = reinterpret_cast(smem_d); + bool is_last_n_block = n_offset + BLOCK_N > shape_n; + int int4_per_line = + is_last_n_block ? 
int4_per_global_line % int4_per_tile_line : int4_per_tile_line; + + for (int line_idx = warp_idx; line_idx < num_lines; line_idx += num_warps) { + if (m_offset + line_idx >= m_boundary) { + break; + } + for (int elem_idx = lane_idx; elem_idx < int4_per_line; elem_idx += 32) { + uint64_t idx = (uint64_t)line_idx * ld_output + n_offset; + int4* g_data_addr = reinterpret_cast(&gmem_d_this_block[idx]) + elem_idx; + int4 const* s_data_addr = &smem_d_int4[line_idx * (int4_per_tile_line) + elem_idx]; + *g_data_addr = *s_data_addr; + } + __syncwarp(); + } +} + +template +__global__ void __launch_bounds__( + get_num_threads_per_sm(BLOCK_M), 1) + fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, InputType problem_input, + __grid_constant__ const CUtensorMap tensor_map_a, + __grid_constant__ const CUtensorMap tensor_map_b, + __grid_constant__ const CUtensorMap tensor_map_scales_a, + __grid_constant__ const CUtensorMap tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ == 900)) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1, "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Shared memory + static constexpr int kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); + static constexpr uint32_t SMEM_SCALES_B_SIZE = + ceil_div(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 
1 : 2) * sizeof(float), + sizeof(Barrier)) * + sizeof(Barrier); + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + constexpr uint32_t kNumThreads = + get_num_threads_per_sm(BLOCK_M); + constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; + constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); + uint32_t const warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + uint32_t const lane_idx = get_lane_id(); + + // Prefetch TMA descriptors at very beginning + if (threadIdx.x == kNumMathThreads) { + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor( + reinterpret_cast(&tensor_map_scales_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; + __nv_fp8_e4m3* smem_b[kNumStages]; + float* smem_scales_a[kNumStages]; + float* smem_scales_b; + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + +// Fill shared memory pointers +#pragma unroll + for (int i = 0; i < kNumStages; ++i) { + smem_a[i] = + reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>( + smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_a[i] = reinterpret_cast( + smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + i * SMEM_SCALES_A_SIZE_PER_STAGE); + } + smem_scales_b = reinterpret_cast( + smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)); + + // Fill barriers + auto barrier_start_ptr = + reinterpret_cast(reinterpret_cast(smem_scales_b) + SMEM_SCALES_B_SIZE); +#pragma unroll + for (int i = 0; i < kNumStages; ++i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) { +#pragma unroll + for (int i = 0; i < kNumStages; ++i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? 
cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + + struct NotDivisibleK {}; + + auto launch_k_iterations = [](auto const& func) { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) { + for (int k_iter = 0; k_iter < kNumIterations; ++k_iter) func(k_iter, DivisibleK{}); + } else { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++k_iter) func(k_iter, DivisibleK{}); + func(kNumIterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr int kNumTMARegisters = 40; + constexpr int kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = SchedulerType(problem_input); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = + kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + // Issue TMA A with broadcasting + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), + smem_a[s], k_idx, scheduler.get_global_m_idx(m_block_idx)); + + if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset) { + tma_copy( + &tensor_map_scales_a, reinterpret_cast(&full_barrier), + smem_scales_a[s], scheduler.get_global_scales_a_idx(m_block_idx), + k_idx / BLOCK_K); + } else { + tma_copy(&tensor_map_scales_a, + reinterpret_cast(&full_barrier), + smem_scales_a[s], m_block_idx * BLOCK_M, + scheduler.get_global_scales_a_idx(k_idx / BLOCK_K)); + } + + // Issue TMA B without broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), smem_b[s], k_idx, + scheduler.get_global_n_idx(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx)); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + + SMEM_SCALES_A_SIZE_PER_STAGE); + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { +#pragma unroll + for (uint32_t s = 0; s < kNumStages; ++s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + auto const math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + auto const r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N"); + uint32_t num_former_iters = BLOCK_N / 8, 
num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between + // tasks + if (threadIdx.x >= 32) { + auto num_previous_lines = + scheduler.get_global_scales_b_idx(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx); + ; + auto local_scales_b = + scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES; +#pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32) + st_shared(smem_scales_b + i, __ldg(local_scales_b + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = + kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (int s = 0; s < kNumInnerStages; ++s) { + // Read B scales + float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1 = 1.0f; + // NOTES: even some blocks do not need to read the second row, but we still load one to + // align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next + // scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0), + scale_a_1 = ld_shared(smem_scales_a[s] + r_1); + +// Commit WGMMA instructions +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); +#pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++k) { + auto desc_a = + make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + + // Promote with scales + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) { + bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + final_accum[i * 4 + 0] += (predicate ? 
scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); +#pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++i) { + SM90_U32x4_STSM_N::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + i * 16 + 8 * (lane_idx / 16)); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], + final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], + final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + (warp_idx * 16 + lane_idx % 16) * BLOCK_N + WGMMA::kNumAccum / 8 * 16); + } + + if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset) { + auto m_global_idx = scheduler.get_global_m_idx(m_block_idx); + bool cross_boundary = (m_global_idx + BLOCK_M) > scheduler.m_boundary; + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + if (!cross_boundary) { + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + m_global_idx); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } else { + __nv_bfloat16* gmem_d_this_block = gmem_d + m_global_idx * SHAPE_N; + constexpr int NUM_WARPS = + (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem( + gmem_d_this_block, smem_d, m_global_idx, scheduler.m_boundary, n_block_idx * BLOCK_N, + SHAPE_N, SHAPE_N); + } + } else if constexpr (SchedulerType::gemm_type == GemmType::StridedBatched) { + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + __nv_bfloat16* gmem_d_this_block; + auto m_global_idx = scheduler.get_global_m_idx(m_block_idx); + gmem_d_this_block = gmem_d + scheduler.curr_group_idx * problem_input.stride_d + + (m_block_idx * BLOCK_M) * problem_input.ld_d; + constexpr int NUM_WARPS = + (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem( + gmem_d_this_block, smem_d, m_global_idx, scheduler.m_boundary, n_block_idx * BLOCK_N, + SHAPE_N, problem_input.ld_d); + } else { + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, + scheduler.get_global_m_idx(m_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } + + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +template +__global__ void __launch_bounds__( + get_num_threads_per_sm(BLOCK_M), 1) + 
fp8_gemm_kernel_swapAB( + __nv_bfloat16* gmem_d, float* scales_a, InputType problem_input, + const __grid_constant__ CUtensorMap tensor_map_a, // weight (previously act) + const __grid_constant__ CUtensorMap tensor_map_b, // act (previously weight) + const __grid_constant__ CUtensorMap + tensor_map_scales_b, // act scales (previously tensor_map_scales_a) + const __grid_constant__ CUtensorMap tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(ceil_div(BLOCK_M, BLOCK_K) == 1, "Too much A scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Shared memory + DG_STATIC_ASSERT(BLOCK_K % BLOCK_M == 0, "BLOCK_M should be 64 or 128 and BLOCK_K should be 128"); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_N * BLOCK_M * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = + BLOCK_N * sizeof(float); // B matrix (act) scales + static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE_PADDED = + ceil_div(BLOCK_N * sizeof(float), 128) * + 128; // B matrix (act) scales, 128B aligned + static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K); + static constexpr uint32_t SMEM_SCALES_A_SIZE = + ceil_div(SHAPE_K_SCALES * sizeof(float), sizeof(Barrier)) * + sizeof(Barrier); // renamed to A (weight) + + // Configs + constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K; + constexpr uint32_t kNumThreads = + get_num_threads_per_sm(BLOCK_M); + constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads; + constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_id(); + + // Prefetch TMA descriptors at very beginning + if (threadIdx.x == kNumMathThreads) { + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor( + reinterpret_cast(&tensor_map_scales_b)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + __nv_fp8_e4m3* smem_a[kNumStages]; // weight + __nv_fp8_e4m3* smem_b[kNumStages]; // act + float* smem_scales_b[kNumStages]; // act scales + float* smem_scales_a; // weight scales + + // TMA Barrier for both divisible and non-divisible cases + Barrier* full_barriers[kNumStages]; + Barrier* empty_barriers[kNumStages]; + +// Fill shared memory pointers +#pragma unroll + for (int i = 0; i < kNumStages; ++i) { + smem_a[i] = + reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>( + smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_b[i] = reinterpret_cast( + smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + 
SMEM_B_SIZE_PER_STAGE) + + i * SMEM_SCALES_B_SIZE_PER_STAGE_PADDED); + } + smem_scales_a = + reinterpret_cast(smem_buffer + SMEM_D_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + + SMEM_SCALES_B_SIZE_PER_STAGE_PADDED)); + + // Fill barriers + auto barrier_start_ptr = + reinterpret_cast(reinterpret_cast(smem_scales_a) + SMEM_SCALES_A_SIZE); +#pragma unroll + for (int i = 0; i < kNumStages; ++i) { + full_barriers[i] = barrier_start_ptr + i; + empty_barriers[i] = barrier_start_ptr + kNumStages + i; + } + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (threadIdx.x == kNumMathThreads) { +#pragma unroll + for (int i = 0; i < kNumStages; ++i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_view_async_shared(); + (kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // For pipeline unrolling + struct DivisibleK {}; + + struct NotDivisibleK {}; + + auto launch_k_iterations = [](auto const& func) { + if constexpr (SHAPE_K % kFullKOfAllStages == 0) { + for (int k_iter = 0; k_iter < kNumIterations; ++k_iter) func(k_iter, DivisibleK{}); + } else { + for (int k_iter = 0; k_iter < kNumIterations - 1; ++k_iter) func(k_iter, DivisibleK{}); + func(kNumIterations - 1, NotDivisibleK{}); + } + }; + + // Register reconfigurations + constexpr int kNumTMARegisters = 40; + constexpr int kNumMathRegisters = 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = SchedulerType(problem_input); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (threadIdx.x == kNumMathThreads) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = + kHasDivisibleStages ? 
kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (uint32_t s = 0; s < kNumInnerStages; ++s) { + // Wait consumer release + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + + // Issue TMA A (weight) now without broadcasting + auto& full_barrier = *full_barriers[s]; + int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K; + tma_copy(&tensor_map_a, reinterpret_cast(&full_barrier), smem_a[s], k_idx, + scheduler.get_global_m_idx(SHAPE_M, BLOCK_M, m_block_idx, n_block_idx)); + + // Issue TMA B (act) with broadcasting + tma_copy(&tensor_map_b, reinterpret_cast(&full_barrier), + smem_b[s], k_idx, scheduler.get_global_n_idx(n_block_idx)); + + // Issue TMA scales_b (act scales) for B matrix + if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset) { + tma_copy( + &tensor_map_scales_b, reinterpret_cast(&full_barrier), + smem_scales_b[s], scheduler.get_global_scales_b_idx(n_block_idx), + k_idx / BLOCK_K); + } else { + tma_copy(&tensor_map_scales_b, + reinterpret_cast(&full_barrier), + smem_scales_b[s], n_block_idx * BLOCK_N, + scheduler.get_global_scales_b_idx(k_idx / BLOCK_K)); + } + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + + SMEM_SCALES_B_SIZE_PER_STAGE); + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) { + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1); + full_barriers[s]->arrive(); + } + }); + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { +#pragma unroll + for (uint32_t s = 0; s < kNumStages; ++s) + empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + auto const math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0); + + // Each thread loads consecutive 2 scales + const uint32_t scale_offset = (lane_idx % 4) * 2; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Load weight scales (scales_a) - these are associated with tensor_map_a (weight) + // Decide the number of scales A to load + DG_STATIC_ASSERT(SHAPE_M % 8 == 0, "Invalid shape M"); + uint32_t num_scales_a = SHAPE_K_SCALES; + + // Load A scales with math warp-groups (weight scales) + if (threadIdx.x >= 32) { + auto num_previous_lines = + scheduler.get_global_scales_a_idx(ceil_div(SHAPE_M, BLOCK_K), 0, 0, n_block_idx); + auto local_scales_a = + scales_a + (num_previous_lines + ((m_block_idx * BLOCK_M) / BLOCK_K)) * SHAPE_K_SCALES; +#pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_scales_a; i += kNumMathThreads - 32) + st_shared(smem_scales_a + i, __ldg(local_scales_a + i)); + } + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + + // Accumulation for WGMMA or CUDA promotion + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](int s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + lane_idx < kNumTMAMulticast ? 
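// Empty-barrier arrival: without multicast a single lane arrives at the local barrier; with
// multicast, lanes 0..kNumTMAMulticast-1 each arrive at the corresponding CTA of the cluster
// (ClusterTransactionBarrier::arrive(cta_id)), so every producer in the cluster sees the release.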
empty_barriers[s]->arrive(lane_idx) : void(); + } + }; + + // Launch MMAs + launch_k_iterations([&](int k_iter, auto type) { + constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr int kNumInnerStages = + kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K; + DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); + +#pragma unroll + for (int s = 0; s < kNumInnerStages; ++s) { + // Read weight scales (A scales) + float scale_a_0 = ld_shared(smem_scales_a + k_iter * kNumStages + s); + + // Wait TMA arrivals + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next + // scheduled block polluting the results Each thread reads consecutive two b scales, each + // thread needs to read WGMMA::N / 4 * 2 b scales + float scale_0_0[WGMMA::kNumAccum / 4], scale_0_1[WGMMA::kNumAccum / 4]; +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) { + float2 scale_b = + ld_shared(reinterpret_cast(smem_scales_b[s] + i * 8 + scale_offset)); + scale_0_0[i] = scale_a_0 * scale_b.x; + scale_0_1[i] = scale_a_0 * scale_b.y; + } + +// Commit WGMMA instructions +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); +#pragma unroll + for (int k = 0; k < BLOCK_K / WGMMA::K; ++k) { + auto desc_a = + make_smem_desc(smem_a[s] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum; ++i) warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(s); + +// Promote with scales +#pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) { + final_accum[i * 4 + 0] += scale_0_0[i] * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_0_1[i] * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_0_0[i] * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_0_1[i] * accum[i * 4 + 3]; + } + } + +// Wait unaligned cases +#pragma unroll + for (uint32_t s = kNumInnerStages; s < kNumStages; ++s) { + full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1); + empty_barrier_arrive(s); + } + }); + + // Write back to shared memory using STSM + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + int tid = 0; + if (lane_idx < 8) { + tid = lane_idx * BLOCK_M; + } else if (lane_idx < 16) { + tid = (lane_idx - 8) * BLOCK_M + 8; + } else if (lane_idx < 24) { + tid = (lane_idx - 8) * BLOCK_M; + } else { + tid = (lane_idx - 16) * BLOCK_M + 8; + } +#pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 8; ++i) { + SM90_U32x4_STSM_T::copy( + __float22bfloat162_rn({final_accum[i * 8 + 0], final_accum[i * 8 + 1]}), + __float22bfloat162_rn({final_accum[i * 8 + 2], final_accum[i * 8 + 3]}), + __float22bfloat162_rn({final_accum[i * 8 + 4], final_accum[i * 8 + 5]}), + __float22bfloat162_rn({final_accum[i * 8 + 6], final_accum[i * 8 + 7]}), + smem_d + warp_idx * 16 + i * 16 * BLOCK_M + tid); + } + if constexpr (WGMMA::kNumAccum % 8 != 0) { + SM90_U32x2_STSM_T::copy( + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 0], + final_accum[WGMMA::kNumAccum / 8 * 8 + 1]}), + __float22bfloat162_rn({final_accum[WGMMA::kNumAccum / 8 * 8 + 2], + final_accum[WGMMA::kNumAccum / 8 * 8 + 3]}), + smem_d + warp_idx * 16 + 
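// The transposed stmatrix (.trans) stores land the accumulators in smem_d as a
// [BLOCK_N][BLOCK_M] tile with M contiguous, which matches the M-major global layout of D used
// by the TMA store / gmem fallback below (gmem_d + n_global_idx * SHAPE_M).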
WGMMA::kNumAccum / 8 * 16 * BLOCK_M + tid); + } + + if constexpr (SchedulerType::gemm_type == GemmType::GroupedWithOffset) { + auto n_global_idx = scheduler.get_global_n_idx(n_block_idx); + bool cross_boundary = (n_global_idx + BLOCK_N) > scheduler.n_boundary; + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + if (!cross_boundary) { + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, + n_global_idx); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } else { + __nv_bfloat16* gmem_d_this_block = gmem_d + n_global_idx * SHAPE_M; + constexpr int NUM_WARPS = + (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem( + gmem_d_this_block, smem_d, n_global_idx, scheduler.n_boundary, m_block_idx * BLOCK_M, + SHAPE_M, SHAPE_M); + } + } else if constexpr (SchedulerType::gemm_type == GemmType::StridedBatched) { + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + __nv_bfloat16* gmem_d_this_block; + auto n_global_idx = scheduler.get_global_n_idx(n_block_idx); + gmem_d_this_block = gmem_d + scheduler.curr_group_idx * problem_input.stride_d + + (n_block_idx * BLOCK_N) * problem_input.ld_d; + constexpr int NUM_WARPS = + (get_num_threads_per_sm(BLOCK_M) - 128) / 32; + write_result_to_gmem( + gmem_d_this_block, smem_d, n_global_idx, scheduler.n_boundary, m_block_idx * BLOCK_M, + SHAPE_M, problem_input.ld_d); + } else { + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumMathThreads).sync(); + // Use TMA store to write back to global memory + if (threadIdx.x == 0) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, m_block_idx * BLOCK_M, + scheduler.get_global_n_idx(n_block_idx)); + cute::tma_store_arrive(); + cute::tma_store_wait<0>(); + } + } + + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} +} // namespace deep_gemm diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh new file mode 100644 index 0000000000..25c47eb8f6 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/jit_utils.cuh @@ -0,0 +1,231 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
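For orientation, the computation carried out by the math warp-groups above can be summarized by a plain reference routine. This is a minimal host-side sketch under assumed layouts (the routine name, the per-128x128 weight-block scales for A, the per-row-per-128-channel scales for B, and the M-contiguous layout of D are illustrative; the grouped/offset bookkeeping handled by the schedulers is omitted):

#include <algorithm>

// Reference semantics only, not part of the patch: A and B hold the raw FP8 values (widened to
// float here for simplicity) and the two scale tensors dequantize each 128-wide K block.
void reference_fp8_gemm_swapAB(int M, int N, int K,
                               float const* A,        // weight, row-major [M][K]
                               float const* B,        // activation, row-major [N][K]
                               float const* scale_a,  // [ceil(M/128)][ceil(K/128)]
                               float const* scale_b,  // [N][ceil(K/128)]
                               float* D) {            // M-contiguous [N][M], like gmem_d above
  int const KB = (K + 127) / 128;
  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      float acc = 0.f;
      for (int kb = 0; kb < KB; ++kb) {
        float partial = 0.f;
        for (int k = kb * 128; k < std::min(K, (kb + 1) * 128); ++k)
          partial += A[m * K + k] * B[n * K + k];
        acc += scale_a[(m / 128) * KB + kb] * scale_b[n * KB + kb] * partial;
      }
      D[n * M + m] = acc;
    }
  }
}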
+ */ + +#pragma once +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "scheduler.cuh" + +// Helper function to check NVRTC errors +#define CHECK_NVRTC(call) \ + do { \ + nvrtcResult result = call; \ + if (result != NVRTC_SUCCESS) { \ + std::cerr << "NVRTC error: " << nvrtcGetErrorString(result) << std::endl; \ + exit(1); \ + } \ + } while (0) + +// Helper function to check CUDA driver errors +#define CHECK_CUDA(call) \ + do { \ + CUresult result = call; \ + if (result != CUDA_SUCCESS) { \ + const char* error_string; \ + cuGetErrorString(result, &error_string); \ + std::cerr << "CUDA error: " << error_string << std::endl; \ + exit(1); \ + } \ + } while (0) + +namespace deep_gemm::jit { + +using GemmConfig = std::tuple; // block_m, block_n, num_stages, + // num_tma_multicast, best_smem_size + +std::string gemm_type_to_string(deep_gemm::GemmType gemm_type); + +int div_up(int a, int b); +int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k, bool swap_ab); +bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms); +GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + int num_groups, int num_device_sms, bool is_grouped_contiguous, + bool swap_ab); +} // namespace deep_gemm::jit + +namespace deep_gemm::jit { + +std::string gemm_type_to_string(deep_gemm::GemmType gemm_type) { + switch (gemm_type) { + case deep_gemm::GemmType::Normal: + return std::string("Normal"); + case deep_gemm::GemmType::GroupedContiguous: + return std::string("GroupedContiguous"); + case deep_gemm::GemmType::GroupedMasked: + return std::string("GroupedMasked"); + case deep_gemm::GemmType::GroupedWithOffset: + return std::string("GroupedWithOffset"); + case deep_gemm::GemmType::StridedBatched: + return std::string("StridedBatched"); + // Add other GEMM types as needed + default: + return std::string("Unknown"); + } +} + +int div_up(int a, int b) { return (a + b - 1) / b; } + +int get_smem_size(int num_stages, int k, int block_m, int block_n, int block_k = 128, + bool swap_ab = false) { + if (!swap_ab) { + int smem_d = block_m * block_n * 2; + int smem_a_per_stage = block_m * block_k; + int smem_scales_a_per_stage = block_m * 4; + int smem_b_per_stage = block_n * block_k; + int smem_scales_b = div_up(k, block_k) * 4; + int smem_barrier = num_stages * 8 * 2; + + int smem_size = 0; + smem_size += smem_d; + smem_size += num_stages * smem_a_per_stage; + smem_size += num_stages * smem_scales_a_per_stage; + smem_size += num_stages * smem_b_per_stage; + smem_size += div_up(smem_scales_b * (block_k % block_n == 0 ? 
1 : 2), 8) * 8; + smem_size += smem_barrier; + + return smem_size; + } else { + int smem_d = block_n * block_m * 2; + int smem_a_per_stage = block_m * block_k; // weight + int smem_scales_a_per_stage = div_up(k, block_k) * 4; // weight scales + int smem_b_per_stage = block_n * block_k; // act + int smem_scales_b = div_up(block_n * 4, 128) * 128; // act scales,tma 128B alignment + int smem_barrier = num_stages * 8 * 2; + + int smem_size = 0; + smem_size += smem_d; + smem_size += num_stages * smem_a_per_stage; + smem_size += num_stages * smem_scales_b; + smem_size += num_stages * smem_b_per_stage; + smem_size += div_up(smem_scales_a_per_stage, 8) * 8; + smem_size += smem_barrier; + + return smem_size; + } +} + +bool is_tma_multicast_legal(int n, int block_n, int num_tma_multicast, int num_sms) { + if (num_tma_multicast == 1) { + return true; + } + return (n % (block_n * num_tma_multicast) == 0) && num_sms % num_tma_multicast == 0; +} + +GemmConfig get_best_gemm_config(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + int num_groups, int num_device_sms, + bool is_grouped_contiguous = false, bool swap_ab = false) { + // Choose candidate block sizes + std::vector block_ms; + block_ms.push_back((!is_grouped_contiguous && shape_m <= 64) ? 64 : 128); + + // Candidate block sizes for N dimension + std::vector block_ns; + for (int i = 16; i <= 128; i += 8) { + block_ns.push_back(i); + } + + // Lambda functions for calculating waves and utilization + auto fix_wave_saturate = [num_device_sms](int x) -> int { return x == 0 ? num_device_sms : x; }; + + auto get_num_waves = [shape_m, shape_n, num_groups, num_device_sms](int block_m, + int block_n) -> int { + return div_up(div_up(shape_m, block_m) * div_up(shape_n, block_n) * num_groups, num_device_sms); + }; + + auto get_last_wave_util = [shape_m, shape_n, num_groups, num_device_sms, &fix_wave_saturate]( + int block_m, int block_n) -> int { + return fix_wave_saturate((div_up(shape_m, block_m) * div_up(shape_n, block_n) * num_groups) % + num_device_sms); + }; + + // Find best block sizes + int best_block_m = 0; + int best_block_n = 0; + for (int block_m : block_ms) { + for (int block_n : block_ns) { + bool success = false; + int num_waves = get_num_waves(block_m, block_n); + int best_num_waves = best_block_m == 0 ? 
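// Selection criteria, in priority order: (1) fewer waves across the device's SMs; (2) on a tie,
// higher utilization of the last wave; (3) on a further tie, prefer the larger block_m and then
// the smaller block_n.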
INT_MAX : get_num_waves(best_block_m, best_block_n); + + if (best_block_m == 0 || best_block_n == 0) { + success = true; + } else if (num_waves < best_num_waves) { + success = true; + } else if (num_waves == best_num_waves) { + // Check last wave utilization + int util = get_last_wave_util(block_m, block_n); + int best_util = get_last_wave_util(best_block_m, best_block_n); + success = util > best_util || + (util == best_util && + (block_m > best_block_m || (block_m == best_block_m && block_n < best_block_n))); + } + + if (success) { + best_block_m = block_m; + best_block_n = block_n; + } + } + } + + // Find best number of stages + int best_num_stages = 0; + int best_smem_size = 0; + constexpr int sm90_capacity = 232448; + + std::vector stage_candidates; + if (128 % best_block_n != 0) { + stage_candidates = {6, 5, 4}; + } else { + stage_candidates = {8, 7, 6, 5, 4}; + } + + for (int num_stages : stage_candidates) { + int smem_size = get_smem_size(num_stages, shape_k, best_block_m, best_block_n, 128, swap_ab); + if (smem_size <= sm90_capacity) { + best_num_stages = num_stages; + best_smem_size = smem_size; + break; + } + } + + // Determine TMA multicast settings + int best_num_tma_multicast = 1; + + if (!swap_ab) { + if (shape_m >= 1024 && is_tma_multicast_legal(shape_n, best_block_n, 2, num_device_sms) && + num_groups == 1) { + best_num_tma_multicast = 2; + } + } else { + if (shape_n >= 1024 && is_tma_multicast_legal(shape_m, best_block_m, 2, num_device_sms) && + num_groups == 1) { + best_num_tma_multicast = 2; + } + } + + return std::make_tuple(best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, + best_smem_size); +} +} // namespace deep_gemm::jit diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/mma_utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/mma_utils.cuh new file mode 100644 index 0000000000..e225855d38 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/mma_utils.cuh @@ -0,0 +1,943 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 DeepSeek + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/MIT + * + * + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
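To make the swapAB branch of get_smem_size concrete, here is the arithmetic for one illustrative configuration (block_m = block_n = block_k = 128, k = 7168, num_stages = 5). The numbers are not mandated by the code; they only show how the budget fits under the sm90_capacity check in get_best_gemm_config:

// get_smem_size(/*num_stages=*/5, /*k=*/7168, /*block_m=*/128, /*block_n=*/128,
//               /*block_k=*/128, /*swap_ab=*/true)
//   smem_d                       = 128 * 128 * 2                        = 32768
//   smem_a_per_stage    (weight) = 128 * 128                            = 16384
//   smem_b_per_stage    (act)    = 128 * 128                            = 16384
//   smem_scales_b per stage      = div_up(128 * 4, 128) * 128           =   512  (TMA 128 B alignment)
//   smem_scales_a       (weight) = div_up(div_up(7168, 128) * 4, 8) * 8 =   224
//   smem_barrier                 = 5 * 8 * 2                            =    80
//   total = 32768 + 5 * (16384 + 16384 + 512) + 224 + 80                = 199472 bytes
// 199472 <= 232448 (sm90_capacity), whereas 6 stages would need 232768 bytes, so the stage search
// in get_best_gemm_config settles on 5 stages for this block shape.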
+ */ + +#pragma once + +#ifndef NVRTC_JIT_COMPILATION +#include +#endif + +#include "utils.cuh" + +namespace deep_gemm { + +struct SM90_64x16x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 16; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x24x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 24; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x32x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 32; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x40x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& 
d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 40; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x48x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 48; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x56x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}, " + " %28," + " %29," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 56; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x64x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}, " + " %32," + " %33," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 64; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x72x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, + bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " 
%32, %33, %34, %35}, " + " %36," + " %37," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 72; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x80x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, + float& d36, float& d37, float& d38, float& d39, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}, " + " %40," + " %41," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 80; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x88x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& 
d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, + float& d36, float& d37, float& d38, float& d39, float& d40, + float& d41, float& d42, float& d43, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}, " + " %44," + " %45," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 88; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x96x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, + float& d36, float& d37, float& d38, float& d39, float& d40, + float& d41, float& d42, float& d43, float& d44, float& d45, + float& d46, float& d47, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}, " + " %48," + " %49," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), 
"+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 96; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x104x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, + float& d36, float& d37, float& d38, float& d39, float& d40, + float& d41, float& d42, float& d43, float& d44, float& d45, + float& d46, float& d47, float& d48, float& d49, float& d50, + float& d51, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}, " + " %52," + " %53," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], d[48], d[49], d[50], + d[51], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 104; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x112x32_F32E4M3E4M3_SS { + __device__ static void wgmma( + uint64_t const& desc_a, uint64_t const& desc_b, float& d00, float& d01, float& d02, + float& d03, float& d04, float& d05, float& d06, float& d07, float& d08, float& d09, + float& d10, float& d11, 
float& d12, float& d13, float& d14, float& d15, float& d16, + float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, + float& d38, float& d39, float& d40, float& d41, float& d42, float& d43, float& d44, + float& d45, float& d46, float& d47, float& d48, float& d49, float& d50, float& d51, + float& d52, float& d53, float& d54, float& d55, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}, " + " %56," + " %57," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], d[48], d[49], d[50], + d[51], d[52], d[53], d[54], d[55], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 112; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x120x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, + float& d36, float& d37, float& d38, float& d39, float& d40, + float& d41, float& d42, float& d43, float& d44, float& d45, + float& d46, float& d47, float& d48, float& d49, float& d50, + float& d51, float& d52, float& d53, float& d54, float& d55, + float& d56, float& d57, float& d58, float& d59, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}, " + " %60," + " %61," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], d[48], d[49], d[50], + d[51], d[52], d[53], d[54], d[55], d[56], d[57], d[58], d[59], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 120; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x128x32_F32E4M3E4M3_SS { + __device__ static void wgmma( + uint64_t const& desc_a, uint64_t const& desc_b, float& d00, float& d01, float& d02, + float& d03, float& d04, float& d05, float& d06, float& d07, float& d08, float& d09, + float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, float& d16, + float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, + float& d38, float& d39, float& d40, float& d41, float& d42, float& d43, float& d44, + float& d45, float& d46, float& d47, float& d48, float& d49, float& d50, float& d51, + float& d52, float& d53, float& d54, float& d55, float& d56, float& d57, float& d58, + float& d59, float& d60, float& d61, float& d62, float& d63, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + " %64," + " %65," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), 
"+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], d[48], d[49], d[50], + d[51], d[52], d[53], d[54], d[55], d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 128; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +struct SM90_64x192x32_F32E4M3E4M3_SS { + __device__ static void wgmma( + uint64_t const& desc_a, uint64_t const& desc_b, float& d00, float& d01, float& d02, + float& d03, float& d04, float& d05, float& d06, float& d07, float& d08, float& d09, + float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, float& d16, + float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, + float& d38, float& d39, float& d40, float& d41, float& d42, float& d43, float& d44, + float& d45, float& d46, float& d47, float& d48, float& d49, float& d50, float& d51, + float& d52, float& d53, float& d54, float& d55, float& d56, float& d57, float& d58, + float& d59, float& d60, float& d61, float& d62, float& d63, float& d64, float& d65, + float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, float& d72, + float& d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, + float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, + float& d87, float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, + float& d94, float& d95, bool scale_d) { + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}, " + " %96," + " %97," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], d[48], d[49], d[50], + d[51], d[52], d[53], d[54], d[55], d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], d[72], d[73], d[74], d[75], d[76], + d[77], d[78], d[79], d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87], d[88], d[89], + d[90], d[91], d[92], d[93], d[94], d[95], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 192; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct SM90_U32x2_STSM_N { + __device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + uint32_t const src[2] = {*reinterpret_cast(&src_0), + *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" ::"l"(smem_dst), + "r"(src[0]), "r"(src[1])); + } +}; + +template +struct SM90_U32x4_STSM_N { + __device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, + dtype_t src_3, void* smem_dst) { + uint32_t const src[4] = { + *reinterpret_cast(&src_0), *reinterpret_cast(&src_1), + *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; + asm volatile( + "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"l"(smem_dst), + "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); + } +}; + +template +struct SM90_U32x2_STSM_T { + __device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + const uint32_t src[2] = {*reinterpret_cast(&src_0), + *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n" ::"l"(smem_dst), + "r"(src[0]), "r"(src[1])); + } +}; + +template +struct SM90_U32x4_STSM_T { + __device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, + dtype_t src_3, void* smem_dst) { + const uint32_t src[4] = { + *reinterpret_cast(&src_0), *reinterpret_cast(&src_1), + *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; + asm volatile( + "stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n" ::"l"(smem_dst), + "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); + } +}; + +__device__ void warpgroup_arrive() { asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); } + +__device__ void warpgroup_commit_batch() { + asm 
volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +__device__ void warpgroup_fence_operand(float& reg) { asm volatile("" : "+f"(reg)::"memory"); } + +__forceinline__ __device__ uint32_t get_lane_id() { + uint32_t lane_id; + asm("mov.u32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ uint32_t ld_shared(uint32_t const* __restrict__ ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ int4 ld_shared(int4 const* __restrict__ ptr) { + int4 ret; + asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" + : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) + : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ float ld_shared(float const* __restrict__ ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ float2 ld_shared(float2 const* __restrict__ ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr)); + return ret; +} + +__device__ __forceinline__ void st_shared(float const* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" ::"l"(ptr), "f"(val)); +} + +__device__ __forceinline__ void st_shared(uint32_t const* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" ::"l"(ptr), "r"(val)); +} + +template +__device__ void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); +} + +union GmmaDescriptor { + __host__ __device__ constexpr GmmaDescriptor() noexcept : desc_(0) {} + + __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept : desc_(desc) {} + + __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const& t) noexcept : desc_(t.desc_) {} + + __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor&& t) noexcept : desc_(t.desc_) {} + + __host__ __device__ constexpr GmmaDescriptor& operator=(GmmaDescriptor const& t) noexcept { + desc_ = t.desc_; + return *this; + } + + __host__ __device__ constexpr GmmaDescriptor& operator=(GmmaDescriptor&& t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + struct { + uint16_t start_address_ : 14, : 2; + uint16_t leading_byte_offset_ : 14, : 2; + uint16_t stride_byte_offset_ : 14, : 2; + uint8_t : 1, base_offset_ : 3, : 4; + uint8_t : 6, layout_type_ : 2; + } bitfield; + + // Decay to an `uint64_t` + __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } +}; + +template +__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, + int leading_byte_offset = 0, + int stride_byte_offset = 1024) { + GmmaDescriptor desc; + auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +struct FP8MMASelector { + static constexpr auto select_type() { + if constexpr (N == 16) return SM90_64x16x32_F32E4M3E4M3_SS(); + if constexpr (N == 24) return SM90_64x24x32_F32E4M3E4M3_SS(); + if constexpr (N == 32) return SM90_64x32x32_F32E4M3E4M3_SS(); + if constexpr (N == 40) return SM90_64x40x32_F32E4M3E4M3_SS(); + if constexpr (N == 48) return 
SM90_64x48x32_F32E4M3E4M3_SS(); + if constexpr (N == 56) return SM90_64x56x32_F32E4M3E4M3_SS(); + if constexpr (N == 64) return SM90_64x64x32_F32E4M3E4M3_SS(); + if constexpr (N == 72) return SM90_64x72x32_F32E4M3E4M3_SS(); + if constexpr (N == 80) return SM90_64x80x32_F32E4M3E4M3_SS(); + if constexpr (N == 88) return SM90_64x88x32_F32E4M3E4M3_SS(); + if constexpr (N == 96) return SM90_64x96x32_F32E4M3E4M3_SS(); + if constexpr (N == 104) return SM90_64x104x32_F32E4M3E4M3_SS(); + if constexpr (N == 112) return SM90_64x112x32_F32E4M3E4M3_SS(); + if constexpr (N == 120) return SM90_64x120x32_F32E4M3E4M3_SS(); + if constexpr (N == 128) return SM90_64x128x32_F32E4M3E4M3_SS(); + if constexpr (N == 192) return SM90_64x192x32_F32E4M3E4M3_SS(); + } + + using type = decltype(select_type()); +}; + +} // namespace deep_gemm diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/nvrtc_cutlass.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/nvrtc_cutlass.cuh new file mode 100644 index 0000000000..85a5551e5b --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/nvrtc_cutlass.cuh @@ -0,0 +1,2451 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
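A minimal sketch of how the mma_utils helpers above are meant to be combined. The template parameter lists are elided in this diff rendering, so FP8MMASelector</*N=*/64> and the helper name below are assumptions for illustration, not code from the patch:

using WGMMA = typename deep_gemm::FP8MMASelector</*N=*/64>::type;  // SM90_64x64x32_F32E4M3E4M3_SS

// Issue one m64n64k32 FP8 MMA and drain it. `accumulate == false` maps to scale-d = 0, which
// overwrites the accumulators; the kernel above passes the sub-K index `k` as this flag, so the
// first of the BLOCK_K / WGMMA::K sub-steps overwrites and the remaining ones accumulate.
__device__ void one_wgmma_step(uint64_t desc_a, uint64_t desc_b,
                               float (&accum)[WGMMA::kNumAccum],  // 64 * 64 / 128 = 32 floats/thread
                               bool accumulate) {
  for (int i = 0; i < WGMMA::kNumAccum; ++i) deep_gemm::warpgroup_fence_operand(accum[i]);
  deep_gemm::warpgroup_arrive();
  WGMMA::wgmma(desc_a, desc_b, accum, /*scale_d=*/accumulate);
  deep_gemm::warpgroup_commit_batch();
  for (int i = 0; i < WGMMA::kNumAccum; ++i) deep_gemm::warpgroup_fence_operand(accum[i]);
  deep_gemm::warpgroup_wait<0>();
}

The shared-memory descriptors desc_a and desc_b would come from make_smem_desc above, with layout type 1, matching the swizzle-128B alignment noted in the kernel.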
+ */ + +#pragma once + +// SM90 +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 0)) +#define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1 +#if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900) +#define CUTLASS_ARCH_MMA_SM90_ENABLED 1 + +#if (!defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +#define CUTLASS_ARCH_MMA_SM90A_ENABLED 1 +#endif +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2) +#define CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 Modifiable +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 3)) +#define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED 1 +#if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ == 900) +#define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED 1 + +#if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +#define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED 1 +#endif +#endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 F64 +#if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) +#define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED 1 +#if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) && defined(__CUDA_ARCH__) && \ + __CUDA_ARCH__ >= 900) +#define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED 1 +#endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// TMA instructions +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +#define CUTE_ARCH_TMA_SM90_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) +#define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +#endif + +// STSM +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +#define CUTE_ARCH_STSM_SM90_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + ((__CUDACC_VER_MAJOR__ >= 12) || \ + ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))) +#define CUTE_ARCH_CLUSTER_SM90_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +#define CUTE_ARCH_ELECT_ONE_SM90_ENABLED +#endif + +#ifndef CUDA_CTA_RECONFIG_ACTIVATED +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && \ + defined(__CUDA_ARCH_FEAT_SM90_ALL)) +#define CUDA_CTA_RECONFIG_ACTIVATED 1 +#endif +#endif + +#ifndef CU_TENSOR_MAP_NUM_QWORDS +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +struct CUtensorMap_st { +#if defined(__cplusplus) && (__cplusplus >= 201103L) + alignas(64) +#elif __STDC_VERSION__ >= 201112L + _Alignas(64) +#endif + cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +}; + +using CUtensorMap = CUtensorMap_st; +#endif + +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#else +#define CUTLASS_HOST_DEVICE inline +#define CUTLASS_DEVICE inline +#endif + +#define CUTLASS_UNUSED(expr) \ + do { \ + ; \ + } while (&expr != &expr) +#define CUTLASS_ASSERT(x) assert(x) + +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) 
+#define CUTE_HOST_DEVICE __forceinline__ __host__ __device__
+#define CUTE_DEVICE __forceinline__ __device__
+#define CUTE_HOST __forceinline__ __host__
+#else
+#define CUTE_HOST_DEVICE inline
+#define CUTE_DEVICE inline
+#define CUTE_HOST inline
+#endif // CUTE_HOST_DEVICE, CUTE_DEVICE
+
+#if defined(__CUDA_ARCH__)
+#define CUTE_INVALID_CONTROL_PATH(x) \
+  assert(0 && x);                    \
+  printf(x);                         \
+  __brkpt()
+#else
+#define CUTE_INVALID_CONTROL_PATH(x) \
+  assert(0 && x);                    \
+  printf(x)
+#endif
+
+#define CUTLASS_HOST __host__
+#define CUTLASS_GLOBAL __global__ static
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12)
+#define CUDA_BARRIER_ENABLED 1
+#else
+#define CUDA_BARRIER_ENABLED 0
+#endif
+
+namespace cutlass {
+namespace arch {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ENABLE_SYNCLOG)
+
+constexpr uint32_t synclog_cap = 1 << 26;
+
+inline std::mutex synclog_mutex;
+inline std::vector<uint32_t*> synclog_buf_list;
+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
+inline __device__ uint32_t* synclog_buf;
+#endif
+
+CUTLASS_DEVICE
+uint32_t* synclog_alloc(uint32_t n) {
+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
+  uint32_t* buf = synclog_buf;
+  if (buf == nullptr) return nullptr;
+  uint32_t last = atomicAdd(&buf[0], n);
+  if (last + n < synclog_cap) return buf + last + 1;
+  if (last >= synclog_cap) atomicAdd(&buf[0], -n);
+#endif
+  return nullptr;
+}
+
+CUTLASS_DEVICE
+void synclog_emit_prefix(uint32_t* to, uint32_t header, uint32_t line) {
+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
+  uint64_t time64;
+  asm volatile("mov.u64 %0, %%globaltimer;\n" : "=l"(time64) :);
+  to[0] = header;
+  to[1] = line;
+  to[2] = time64;
+  to[3] = time64 >> 32;
+  to[4] = threadIdx.x;
+  to[5] = threadIdx.y;
+  to[6] = threadIdx.z;
+  to[7] = blockIdx.x;
+  to[8] = blockIdx.y;
+  to[9] = blockIdx.z;
+#endif
+}
+
+constexpr uint32_t synclog_header_none = 0;
+constexpr uint32_t synclog_length_prefix = 1 + 1 + 2 + 3 + 3;
+
+constexpr bool synclog_enable_syncthreads = true;
+constexpr uint32_t synclog_header_syncthreads = 1;
+constexpr uint32_t synclog_length_syncthreads = synclog_length_prefix + 0;
+
+constexpr bool synclog_enable_syncwarp = true;
+constexpr uint32_t synclog_header_syncwarp = 2;
+constexpr uint32_t synclog_length_syncwarp = synclog_length_prefix + 0;
+
+constexpr bool synclog_enable_named_barrier_arrive_and_wait = true;
+constexpr uint32_t synclog_header_named_barrier_arrive_and_wait = 3;
+constexpr uint32_t synclog_length_named_barrier_arrive_and_wait = synclog_length_prefix + 2;
+
+constexpr bool synclog_enable_named_barrier_arrive = true;
+constexpr uint32_t synclog_header_named_barrier_arrive = 4;
+constexpr uint32_t synclog_length_named_barrier_arrive = synclog_length_prefix + 2;
+
+constexpr bool synclog_enable_cluster_barrier_init = true;
+constexpr uint32_t synclog_header_cluster_barrier_init = 5;
+constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + 2;
+
+constexpr bool synclog_enable_cluster_barrier_wait = true;
+constexpr uint32_t synclog_header_cluster_barrier_wait = 6;
+constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 4;
+
+constexpr bool synclog_enable_cluster_barrier_test_wait = true;
+constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7;
+constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 5;
+
+constexpr bool
synclog_enable_cluster_barrier_try_wait = true; +constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8; +constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true; +constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9; +constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cluster_barrier_arrive = true; +constexpr uint32_t synclog_header_cluster_barrier_arrive = 10; +constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 3; + +constexpr bool synclog_enable_cluster_barrier_invalidate = true; +constexpr uint32_t synclog_header_cluster_barrier_invalidate = 11; +constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 3; + +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = + synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = + synclog_length_prefix + 6; + +constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14; +constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = + synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15; +constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = + synclog_length_prefix + 6; + +constexpr bool synclog_enable_fence_barrier_init = true; +constexpr uint32_t synclog_header_fence_barrier_init = 16; +constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0; + +constexpr bool synclog_enable_fence_view_async_shared = true; +constexpr uint32_t synclog_header_fence_view_async_shared = 17; +constexpr uint32_t synclog_length_fence_view_async_shared = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_wait = true; +constexpr uint32_t synclog_header_cp_async_wait = 18; +constexpr uint32_t synclog_length_cp_async_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_cp_async_wait_all = true; +constexpr uint32_t synclog_header_cp_async_wait_all = 19; +constexpr uint32_t synclog_length_cp_async_wait_all = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_fence = true; +constexpr uint32_t synclog_header_cp_async_fence = 20; +constexpr uint32_t synclog_length_cp_async_fence = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_nan = true; +constexpr uint32_t synclog_header_cp_async_nan = 21; +constexpr uint32_t synclog_length_cp_async_nan = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cp_async_zfill = true; +constexpr uint32_t synclog_header_cp_async_zfill = 22; +constexpr uint32_t synclog_length_cp_async_zfill = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cp_async = true; +constexpr uint32_t synclog_header_cp_async = 
23; +constexpr uint32_t synclog_length_cp_async = synclog_length_prefix + 5; + +constexpr bool synclog_enable_tma_load = true; +constexpr uint32_t synclog_header_tma_load = 24; +constexpr uint32_t synclog_length_tma_load = synclog_length_prefix + 4; + +constexpr bool synclog_enable_tma_store = true; +constexpr uint32_t synclog_header_tma_store = 25; +constexpr uint32_t synclog_length_tma_store = synclog_length_prefix + 3; + +constexpr bool synclog_enable_tma_store_arrive = true; +constexpr uint32_t synclog_header_tma_store_arrive = 26; +constexpr uint32_t synclog_length_tma_store_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_tma_store_wait = true; +constexpr uint32_t synclog_header_tma_store_wait = 27; +constexpr uint32_t synclog_length_tma_store_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_arrive = true; +constexpr uint32_t synclog_header_warpgroup_arrive = 28; +constexpr uint32_t synclog_length_warpgroup_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_warpgroup_wait = true; +constexpr uint32_t synclog_header_warpgroup_wait = 29; +constexpr uint32_t synclog_length_warpgroup_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_commit_batch = true; +constexpr uint32_t synclog_header_warpgroup_commit_batch = 30; +constexpr uint32_t synclog_length_warpgroup_commit_batch = synclog_length_prefix + 0; + +constexpr bool synclog_enable_wgmma_reg_smem = true; +constexpr uint32_t synclog_header_wgmma_reg_smem = 31; +constexpr uint32_t synclog_length_wgmma_reg_smem = synclog_length_prefix + 2; + +constexpr bool synclog_enable_wgmma_smem_smem = true; +constexpr uint32_t synclog_header_wgmma_smem_smem = 32; +constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cpasync_barrier_arrive = true; +constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33; +constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 3; + +CUTLASS_DEVICE +bool synclog_condition_emit() { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x % NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; +#else + return 0; +#endif +} + +CUTLASS_DEVICE +bool synclog_condition_print() { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && + blockIdx.y == 0 && blockIdx.z == 0; +#else + return false; +#endif +} + +CUTLASS_DEVICE +void synclog_print_prefix(char const* header, uint32_t at) { +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint32_t line = synclog_buf[at + 1]; + uint32_t timeLo = synclog_buf[at + 2]; + uint32_t timeHi = synclog_buf[at + 3]; + uint32_t threadIdxX = synclog_buf[at + 4]; + uint32_t threadIdxY = synclog_buf[at + 5]; + uint32_t threadIdxZ = synclog_buf[at + 6]; + uint32_t blockIdxX = synclog_buf[at + 7]; + uint32_t blockIdxY = synclog_buf[at + 8]; + uint32_t blockIdxZ = synclog_buf[at + 9]; + printf("%s line=%u time=%lu thread=%u,%u,%u block=%u,%u,%u ", header, line, + (uint64_t)timeHi << 32 | timeLo, threadIdxX, threadIdxY, threadIdxZ, blockIdxX, blockIdxY, + blockIdxZ); +#endif +} + +CUTLASS_DEVICE +uint64_t synclog_mbarrier_bits(uint32_t smem_addr) { + uint64_t bits = 0; + asm volatile( + "mbarrier.inval.shared::cta.b64 [%1];\n" + "ld.shared::cta.b64 %0, [%1];\n" + : "=l"(bits) + : "r"(smem_addr)); + 
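+  // `bits` now holds the raw 64-bit mbarrier state word; the synclog_emit_* callers
+  // below record it as two 32-bit entries (low word first).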
return bits; +} + +CUTLASS_DEVICE +void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) { + CUTLASS_UNUSED(hi); + uint32_t smem_int_ptr = (lo & ((1 << 14) - 1)) << 4; + printf("%s_smem_int_ptr=%u%s", str, smem_int_ptr, sep); +} + +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void synclog_setup() { +#if defined(CUTLASS_ENABLE_SYNCLOG) +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + std::scoped_lock lock(synclog_mutex); + auto fail = []() { + fprintf(stderr, "synclog_setup() failed\n"); + std::terminate(); + }; + int orig_device = 0; + if (cudaGetDevice(&orig_device) != cudaSuccess) { + fail(); + } + int device_count = 0; + if (cudaGetDeviceCount(&device_count) != cudaSuccess) { + fail(); + } + if (synclog_buf_list.size() == 0) { + for (int device = 0; device < device_count; device++) { + uint32_t* buf = 0; + if (cudaSetDevice(device) != cudaSuccess || + cudaMalloc(&buf, synclog_cap * sizeof(uint32_t)) != cudaSuccess) { + fail(); + } + synclog_buf_list.push_back(buf); + } + } + for (int device = 0; device < device_count; device++) { + uint32_t* buf = synclog_buf_list.at(device); + if (cudaSetDevice(device) != cudaSuccess || + cudaMemset(buf, 0, synclog_cap * sizeof(uint32_t)) != cudaSuccess || + cudaMemcpyToSymbol(synclog_buf, &buf, sizeof(buf)) != cudaSuccess) { + fail(); + } + } + if (cudaSetDevice(orig_device) != cudaSuccess) { + fail(); + } +#endif +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncthreads(uint32_t line) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncthreads) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_syncthreads); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncthreads, line); +#else + CUTLASS_UNUSED(line); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncwarp(uint32_t line) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncwarp) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_syncwarp); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncwarp, line); +#else + CUTLASS_UNUSED(line); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive_and_wait(uint32_t line, uint32_t num_threads, + uint32_t barrier_id) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive_and_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive_and_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive_and_wait, line); + to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive(uint32_t line, uint32_t num_threads, uint32_t barrier_id) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive, line); + 
to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_init(uint32_t line, uint32_t smem_addr, uint32_t arrive_count) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_init) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_init); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_init, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = arrive_count; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(arrive_count); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_wait(uint32_t line, uint32_t smem_addr, uint32_t phase) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_test_wait(uint32_t line, uint32_t smem_addr, uint32_t phase, + uint32_t pred) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_test_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_test_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_test_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = pred; + to[synclog_length_prefix + 3] = bits; + to[synclog_length_prefix + 4] = bits >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + CUTLASS_UNUSED(pred); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_try_wait(uint32_t line, uint32_t smem_addr, uint32_t phase) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_try_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_try_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_try_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_arrive_cluster(uint32_t line, uint32_t smem_addr, uint32_t cta_id, + uint32_t pred) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr 
(!synclog_enable_cluster_barrier_arrive_cluster) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive_cluster, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = cta_id; + to[synclog_length_prefix + 2] = pred; + to[synclog_length_prefix + 3] = bits; + to[synclog_length_prefix + 4] = bits >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(cta_id); + CUTLASS_UNUSED(pred); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_arrive(uint32_t line, uint32_t smem_addr) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_invalidate(uint32_t line, uint32_t smem_addr) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_invalidate) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_invalidate); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_invalidate, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx(uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster( + uint32_t line, uint32_t smem_addr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = + 
synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster);
+  if (to == nullptr) return;
+  synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster,
+                      line);
+  to[synclog_length_prefix + 0] = smem_addr;
+  to[synclog_length_prefix + 1] = transaction_bytes;
+  to[synclog_length_prefix + 2] = cta_id;
+  to[synclog_length_prefix + 3] = pred;
+  to[synclog_length_prefix + 4] = bits;
+  to[synclog_length_prefix + 5] = bits >> 32;
+#else
+  CUTLASS_UNUSED(line);
+  CUTLASS_UNUSED(smem_addr);
+  CUTLASS_UNUSED(transaction_bytes);
+  CUTLASS_UNUSED(cta_id);
+  CUTLASS_UNUSED(pred);
+#endif // defined(CUTLASS_ENABLE_SYNCLOG)
+}
+
+CUTLASS_DEVICE
+void synclog_emit_cluster_transaction_barrier_expect_transaction(uint32_t line, uint32_t smem_addr,
+                                                                 uint32_t transaction_bytes) {
+#if defined(CUTLASS_ENABLE_SYNCLOG)
+  if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return;
+  if (!synclog_condition_emit()) return;
+  uint64_t bits = synclog_mbarrier_bits(smem_addr);
+  uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction);
+  if (to == nullptr) return;
+  synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_expect_transaction, line);
+  to[synclog_length_prefix + 0] = smem_addr;
+  to[synclog_length_prefix + 1] = transaction_bytes;
+  to[synclog_length_prefix + 2] = bits;
+  to[synclog_length_prefix + 3] = bits >> 32;
+#else
+  CUTLASS_UNUSED(line);
+  CUTLASS_UNUSED(smem_addr);
+  CUTLASS_UNUSED(transaction_bytes);
+#endif // defined(CUTLASS_ENABLE_SYNCLOG)
+}
+
+CUTLASS_DEVICE
+void synclog_emit_cluster_transaction_barrier_complete_transaction(uint32_t line,
+                                                                   uint32_t smem_addr,
+                                                                   uint32_t dst_cta_id,
+                                                                   uint32_t transaction_bytes,
+                                                                   uint32_t pred) {
+#if defined(CUTLASS_ENABLE_SYNCLOG)
+  if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return;
+  if (!synclog_condition_emit()) return;
+  uint64_t bits = synclog_mbarrier_bits(smem_addr);
+  uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction);
+  if (to == nullptr) return;
+  synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_complete_transaction, line);
+  to[synclog_length_prefix + 0] = smem_addr;
+  to[synclog_length_prefix + 1] = dst_cta_id;
+  to[synclog_length_prefix + 2] = transaction_bytes;
+  to[synclog_length_prefix + 3] = pred;
+  to[synclog_length_prefix + 4] = bits;
+  to[synclog_length_prefix + 5] = bits >> 32;
+#else
+  CUTLASS_UNUSED(line);
+  CUTLASS_UNUSED(smem_addr);
+  CUTLASS_UNUSED(dst_cta_id);
+  CUTLASS_UNUSED(transaction_bytes);
+  CUTLASS_UNUSED(pred);
+#endif // defined(CUTLASS_ENABLE_SYNCLOG)
+}
+
+CUTLASS_DEVICE
+void synclog_emit_fence_barrier_init(uint32_t line) {
+#if defined(CUTLASS_ENABLE_SYNCLOG)
+  if constexpr (!synclog_enable_fence_barrier_init) return;
+  if (!synclog_condition_emit()) return;
+  uint32_t* to = synclog_alloc(synclog_length_fence_barrier_init);
+  if (to == nullptr) return;
+  synclog_emit_prefix(to, synclog_header_fence_barrier_init, line);
+#else
+  CUTLASS_UNUSED(line);
+#endif // defined(CUTLASS_ENABLE_SYNCLOG)
+}
+
+CUTLASS_DEVICE
+void synclog_emit_fence_view_async_shared(uint32_t line) {
+#if defined(CUTLASS_ENABLE_SYNCLOG)
+  if constexpr (!synclog_enable_fence_view_async_shared) return;
+  if (!synclog_condition_emit()) return;
+  uint32_t* to = synclog_alloc(synclog_length_fence_view_async_shared);
+  if (to == nullptr) return;
+  synclog_emit_prefix(to, synclog_header_fence_view_async_shared, line);
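+  // Note: fence_view_async_shared (and fence_barrier_init above) log only the common
+  // prefix; their record lengths are synclog_length_prefix + 0, so there is no payload.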
+#else + CUTLASS_UNUSED(line); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_wait(uint32_t line, uint32_t n) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_wait, line); + to[synclog_length_prefix + 0] = n; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(n); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_wait_all(uint32_t line) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_wait_all) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_wait_all); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_wait_all, line); +#else + CUTLASS_UNUSED(line); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_fence(uint32_t line) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_fence) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_fence); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_fence, line); +#else + CUTLASS_UNUSED(line); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_nan(uint32_t line, uint32_t smem_addr, void const* gmem_ptr, + uint32_t pred) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_nan) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_nan); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_nan, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_zfill(uint32_t line, uint32_t smem_addr, void const* gmem_ptr, + uint32_t pred, uint32_t size) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_zfill) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_zfill); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_zfill, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = size; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + CUTLASS_UNUSED(size); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async(uint32_t line, uint32_t smem_addr, void const* gmem_ptr, uint32_t pred, + uint32_t size) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async, line); + 
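+  // Payload layout after the common prefix: [0] smem_addr, [1..2] gmem_ptr (lo, hi),
+  // [3] pred, [4] size -- matching synclog_length_cp_async = synclog_length_prefix + 5.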
to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = size; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + CUTLASS_UNUSED(size); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_load(uint32_t line, uint64_t gmem_int_desc, uint32_t smem_int_mbar, + uint32_t smem_int_ptr) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_load) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_load); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_load, line); + to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); + to[synclog_length_prefix + 2] = smem_int_mbar; + to[synclog_length_prefix + 3] = smem_int_ptr; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(gmem_int_desc); + CUTLASS_UNUSED(smem_int_mbar); + CUTLASS_UNUSED(smem_int_ptr); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store(uint32_t line, uint64_t gmem_int_desc, uint32_t smem_int_ptr) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store, line); + to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); + to[synclog_length_prefix + 2] = smem_int_ptr; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(gmem_int_desc); + CUTLASS_UNUSED(smem_int_ptr); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store_arrive(uint32_t line) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store_arrive, line); +#else + CUTLASS_UNUSED(line); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store_wait(uint32_t line, uint32_t count) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store_wait, line); + to[synclog_length_prefix + 0] = count; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(count); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_arrive(uint32_t line) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_arrive, line); +#else + CUTLASS_UNUSED(line); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_wait(uint32_t line, uint32_t n) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr 
(!synclog_enable_warpgroup_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_wait, line); + to[synclog_length_prefix + 0] = n; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(n); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_commit_batch(uint32_t line) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_commit_batch) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_commit_batch); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_commit_batch, line); +#else + CUTLASS_UNUSED(line); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_wgmma_reg_smem(uint32_t line, uint64_t desc_b) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_wgmma_reg_smem) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_wgmma_reg_smem); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_wgmma_reg_smem, line); + to[synclog_length_prefix + 0] = desc_b; + to[synclog_length_prefix + 1] = desc_b >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(desc_b); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_wgmma_smem_smem(uint32_t line, uint64_t desc_a, uint64_t desc_b) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_wgmma_smem_smem) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_wgmma_smem_smem); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_wgmma_smem_smem, line); + to[synclog_length_prefix + 0] = desc_a; + to[synclog_length_prefix + 1] = desc_a >> 32; + to[synclog_length_prefix + 2] = desc_b; + to[synclog_length_prefix + 3] = desc_b >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(desc_a); + CUTLASS_UNUSED(desc_b); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cpasync_barrier_arrive(uint32_t line, uint32_t smem_addr) { +#if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cpasync_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cpasync_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cpasync_barrier_arrive, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; +#else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +#if !defined(CUTLASS_ENABLE_SYNCLOG) +CUTLASS_DEVICE +#elif defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +static __attribute__((__noinline__)) __device__ +#else +static __attribute__((__noinline__)) +#endif +void synclog_print() { +#if defined(CUTLASS_ENABLE_SYNCLOG) +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + if (synclog_buf == nullptr || !synclog_condition_print()) { + return; + } + printf("synclog start\n"); + for (uint32_t at = 1; at < synclog_cap;) { + uint32_t header = synclog_buf[at]; + if (header == synclog_header_none) { + break; + } + printf("synclog at %u: ", at); + if constexpr (synclog_enable_syncthreads) { + if (header == synclog_header_syncthreads) { + 
synclog_print_prefix("syncthreads", at); + at += synclog_length_syncthreads; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_syncwarp) { + if (header == synclog_header_syncwarp) { + synclog_print_prefix("syncwarp", at); + at += synclog_length_syncwarp; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_named_barrier_arrive_and_wait) { + if (header == synclog_header_named_barrier_arrive_and_wait) { + synclog_print_prefix("named_barrier_arrive_and_wait", at); + at += synclog_length_named_barrier_arrive_and_wait; + printf("num_threads=%u barrier_id=%u\n", synclog_buf[at - 2], synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_named_barrier_arrive) { + if (header == synclog_header_named_barrier_arrive) { + synclog_print_prefix("named_barrier_arrive", at); + at += synclog_length_named_barrier_arrive; + printf("num_threads=%u barrier_id=%u\n", synclog_buf[at - 2], synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_init) { + if (header == synclog_header_cluster_barrier_init) { + synclog_print_prefix("cluster_barrier_init", at); + at += synclog_length_cluster_barrier_init; + printf("smem_addr=%u arrive_count=%u\n", synclog_buf[at - 2], synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_wait) { + if (header == synclog_header_cluster_barrier_wait) { + synclog_print_prefix("cluster_barrier_wait", at); + at += synclog_length_cluster_barrier_wait; + printf("smem_addr=%u phase=%u", synclog_buf[at - 4], synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_test_wait) { + if (header == synclog_header_cluster_barrier_test_wait) { + synclog_print_prefix("cluster_barrier_test_wait", at); + at += synclog_length_cluster_barrier_test_wait; + printf("smem_addr=%u phase=%u pred=%u", synclog_buf[at - 5], synclog_buf[at - 4], + synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_try_wait) { + if (header == synclog_header_cluster_barrier_try_wait) { + synclog_print_prefix("cluster_barrier_try_wait", at); + at += synclog_length_cluster_barrier_try_wait; + printf("smem_addr=%u phase=%u", synclog_buf[at - 4], synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_arrive_cluster) { + if (header == synclog_header_cluster_barrier_arrive_cluster) { + synclog_print_prefix("cluster_barrier_arrive_cluster", at); + at += synclog_length_cluster_barrier_arrive_cluster; + printf("smem_addr=%u cta_id=%u pred=%u", synclog_buf[at - 5], synclog_buf[at - 4], + synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_arrive) { + if (header == synclog_header_cluster_barrier_arrive) { + synclog_print_prefix("cluster_barrier_arrive", at); + at += synclog_length_cluster_barrier_arrive; + printf("smem_addr=%u", synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_invalidate) { + if (header == synclog_header_cluster_barrier_invalidate) { + synclog_print_prefix("cluster_barrier_invalidate", at); + at += synclog_length_cluster_barrier_invalidate; + printf("smem_addr=%u", synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) { + if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) { + synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at); + at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx; + 
printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at - 4], synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { + if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { + synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", at); + at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster; + printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u", synclog_buf[at - 6], + synclog_buf[at - 5], synclog_buf[at - 4], synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_expect_transaction) { + if (header == synclog_header_cluster_transaction_barrier_expect_transaction) { + synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at); + at += synclog_length_cluster_transaction_barrier_expect_transaction; + printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at - 4], synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_complete_transaction) { + if (header == synclog_header_cluster_transaction_barrier_complete_transaction) { + synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at); + at += synclog_length_cluster_transaction_barrier_complete_transaction; + printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u", synclog_buf[at - 6], + synclog_buf[at - 5], synclog_buf[at - 4], synclog_buf[at - 3]); + continue; + } + } + if constexpr (synclog_enable_fence_barrier_init) { + if (header == synclog_header_fence_barrier_init) { + synclog_print_prefix("fence_barrier_init", at); + at += synclog_length_fence_barrier_init; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_fence_view_async_shared) { + if (header == synclog_header_fence_view_async_shared) { + synclog_print_prefix("fence_view_async_shared", at); + at += synclog_length_fence_view_async_shared; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_wait) { + if (header == synclog_header_cp_async_wait) { + synclog_print_prefix("cp_async_wait", at); + at += synclog_length_cp_async_wait; + printf("n=%u\n", synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_cp_async_wait_all) { + if (header == synclog_header_cp_async_wait_all) { + synclog_print_prefix("cp_async_wait_all", at); + at += synclog_length_cp_async_wait_all; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_fence) { + if (header == synclog_header_cp_async_fence) { + synclog_print_prefix("cp_async_fence", at); + at += synclog_length_cp_async_fence; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_nan) { + if (header == synclog_header_cp_async_nan) { + synclog_print_prefix("cp_async_nan", at); + at += synclog_length_cp_async_nan; + uint64_t gmem_addr = synclog_buf[at - 3]; + gmem_addr += (uint64_t)synclog_buf[at - 2] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u\n", synclog_buf[at - 4], gmem_addr, + synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_cp_async_zfill) { + if (header == synclog_header_cp_async_zfill) { + synclog_print_prefix("cp_async_zfill", at); + at += synclog_length_cp_async_zfill; + uint64_t gmem_addr = synclog_buf[at - 4]; + gmem_addr += (uint64_t)synclog_buf[at - 3] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at - 5], gmem_addr, + synclog_buf[at - 2], synclog_buf[at - 1]); + continue; + } + } 
+ if constexpr (synclog_enable_cp_async) { + if (header == synclog_header_cp_async) { + synclog_print_prefix("cp_async", at); + at += synclog_length_cp_async; + uint64_t gmem_addr = synclog_buf[at - 4]; + gmem_addr += (uint64_t)synclog_buf[at - 3] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at - 5], gmem_addr, + synclog_buf[at - 2], synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_tma_load) { + if (header == synclog_header_tma_load) { + synclog_print_prefix("tma_load", at); + at += synclog_length_tma_load; + uint64_t gmem_int_desc = synclog_buf[at - 4]; + gmem_int_desc += (uint64_t)synclog_buf[at - 3] << 32; + printf("gmem_int_desc=%llu smem_int_mbar=%u smem_int_ptr=%u\n", gmem_int_desc, + synclog_buf[at - 2], synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_tma_store) { + if (header == synclog_header_tma_store) { + synclog_print_prefix("tma_store", at); + at += synclog_length_tma_store; + uint64_t gmem_int_desc = synclog_buf[at - 3]; + gmem_int_desc += (uint64_t)synclog_buf[at - 2] << 32; + printf("gmem_int_desc=%llu smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_tma_store_arrive) { + if (header == synclog_header_tma_store_arrive) { + synclog_print_prefix("tma_store_arrive", at); + at += synclog_length_tma_store_arrive; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_tma_store_wait) { + if (header == synclog_header_tma_store_wait) { + synclog_print_prefix("tma_store_wait", at); + at += synclog_length_tma_store_wait; + printf("count=%u\n", synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_warpgroup_arrive) { + if (header == synclog_header_warpgroup_arrive) { + synclog_print_prefix("warpgroup_arrive", at); + at += synclog_length_warpgroup_arrive; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_warpgroup_wait) { + if (header == synclog_header_warpgroup_wait) { + synclog_print_prefix("warpgroup_wait", at); + at += synclog_length_warpgroup_wait; + printf("n=%u\n", synclog_buf[at - 1]); + continue; + } + } + if constexpr (synclog_enable_warpgroup_commit_batch) { + if (header == synclog_header_warpgroup_commit_batch) { + synclog_print_prefix("warpgroup_commit_batch", at); + at += synclog_length_warpgroup_commit_batch; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_wgmma_reg_smem) { + if (header == synclog_header_wgmma_reg_smem) { + synclog_print_prefix("wgmma_reg_smem", at); + at += synclog_length_wgmma_reg_smem; + synclog_print_wgmma_desc("desc_b", synclog_buf[at - 2], synclog_buf[at - 1], ""); + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_wgmma_smem_smem) { + if (header == synclog_header_wgmma_smem_smem) { + synclog_print_prefix("wgmma_smem_smem", at); + at += synclog_length_wgmma_smem_smem; + synclog_print_wgmma_desc("desc_a", synclog_buf[at - 4], synclog_buf[at - 3], " "); + synclog_print_wgmma_desc("desc_b", synclog_buf[at - 2], synclog_buf[at - 1], ""); + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cpasync_barrier_arrive) { + if (header == synclog_header_cpasync_barrier_arrive) { + synclog_print_prefix("cpasync_barrier_arrive", at); + at += synclog_length_cpasync_barrier_arrive; + printf("smem_addr=%u", synclog_buf[at - 3]); + continue; + } + } + asm volatile("brkpt;\n" ::); + } + if (synclog_buf[0] >= synclog_cap) { + printf("synclog was truncated (exceeded capacity of %lu bytes)\n", + (synclog_cap - 1) * sizeof(uint32_t)); + } 
+ printf("synclog end\n"); +#endif +#endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ENABLE_SYNCLOG) +#undef __syncthreads +#define __syncthreads() \ + do { \ + cutlass::arch::synclog_emit_syncthreads(__LINE__); \ + __syncthreads(); \ + } while (0) +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +#if defined(CUTLASS_ENABLE_SYNCLOG) +#undef __syncwarp +#define __syncwarp(...) \ + do { \ + cutlass::arch::synclog_emit_syncwarp(__LINE__); \ + __syncwarp(__VA_ARGS__); \ + } while (0) +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +#if defined(__clang__) && defined(__CUDA__) + // __cvta_generic_to_shared was added in Clang 14: + // https://reviews.llvm.org/D111665 +#if __clang_major__ >= 14 +#define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 +#endif + +// __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665 +// ... but will not work on Windows until Clang 15: +// https://reviews.llvm.org/D122897 +#if (!defined(_WIN32) && __clang_major__ >= 14) || __clang_major__ >= 15 +#define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER 1 +#endif +#endif + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) + // __cvta_generic_to_shared added in CUDA 11+ +#if __CUDACC_VER_MAJOR__ >= 11 +#define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 +#endif + +// __nvvm_get_smem_pointer added in CUDA 10.2 +#if __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2 +#define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER 1 +#endif +#endif + +#if CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED +#define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1 +#endif + +#if !defined(CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED) && CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED && \ + defined(__CUDA_ARCH__) +#define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED 1 +#endif + +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER +#define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1 +#endif + +#if !defined(CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED) && CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED && \ + defined(__CUDA_ARCH__) +#define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED 1 +#endif + +// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need +// to provide one for NVCC +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER +extern "C" { +// This NVVM intrinsic is subject to change in future versions of CUDA. +// Clients should not call it directly. +CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*); +} +#endif + +namespace cute { + +/// CUTE helper to cast SMEM pointer to unsigned +CUTE_DEVICE +uint32_t cast_smem_ptr_to_uint(void const* const ptr) { +// We prefer to use the new CVTA intrinsics if they are available, otherwise we +// will fall back to the previous internal intrinsics if they are available. +#if CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED + // + // This NVVM intrinsic converts an address in shared memory to a plain + // unsigned integer. This is necessary to pass to shared memory instructions + // in inline PTX. + // + // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only + // available in 10.2]. 
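+  // Illustrative usage (not from this header): given `__shared__ uint64_t mbar;`,
+  // `cast_smem_ptr_to_uint(&mbar)` yields the 32-bit shared-space address that the
+  // mbarrier / cluster / TMA inline-PTX helpers in this file take as `smem_addr`.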
+  //
+  //__device__ size_t __cvta_generic_to_shared(void* ptr);
+
+  /// CUTE helper to get SMEM pointer
+  return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
+
+#elif CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED
+
+  return __nvvm_get_smem_pointer(ptr);
+
+#elif defined(__CUDA_ARCH__)
+
+  uint32_t smem_ptr;
+
+  asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, "
+      "smem_ptr; }\n"
+      : "=r"(smem_ptr)
+      : "l"(ptr));
+
+  return smem_ptr;
+
+#else
+
+  (void)ptr;
+  printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n");
+  return 0;
+
+#endif
+}
+
+} // namespace cute
+
+namespace cute {
+
+CUTE_DEVICE void cluster_arrive_relaxed() {
+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
+  asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : :);
+#else
+  CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
+#endif
+}
+
+CUTE_DEVICE void cluster_arrive() {
+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
+  asm volatile("barrier.cluster.arrive.aligned;\n" : :);
+#else
+  CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
+#endif
+}
+
+CUTE_DEVICE void cluster_wait() {
+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
+  asm volatile("barrier.cluster.wait.aligned;\n" : :);
+#else
+  CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
+#endif
+}
+
+CUTE_DEVICE void cluster_sync() {
+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
+  cluster_arrive();
+  cluster_wait();
+#else
+  CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined");
+#endif
+}
+
+// Returns the dim3 grid size in terms of number of clusters.
+CUTE_DEVICE dim3 cluster_grid_dims() {
+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
+  uint32_t x, y, z;
+  asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) :);
+  asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) :);
+  asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) :);
+  return {x, y, z};
+#elif defined(__CUDA_ARCH__)
+  // MSVC requires protecting use of gridDim with __CUDA_ARCH__.
+  return gridDim;
+#elif defined(_MSC_VER)
+  CUTE_INVALID_CONTROL_PATH("cluster_grid_dims() can only be called on device");
+  return {0, 0, 0};
+#else
+  return {0, 0, 0};
+#endif
+}
+
+// Returns the dim3 cluster rank in the grid.
+CUTE_DEVICE dim3 cluster_id_in_grid() {
+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
+  uint32_t x, y, z;
+  asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) :);
+  asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) :);
+  asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) :);
+  return {x, y, z};
+#elif defined(__CUDA_ARCH__)
+  // MSVC requires protecting use of blockIdx with __CUDA_ARCH__.
+  return blockIdx;
+#elif defined(_MSC_VER)
+  CUTE_INVALID_CONTROL_PATH("cluster_id_in_grid() can only be called on device");
+  return {0, 0, 0};
+#else
+  return {0, 0, 0};
+#endif
+}
+
+// Returns the relative dim3 block rank local to the cluster.
+CUTE_DEVICE dim3 block_id_in_cluster() {
+#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED)
+  uint32_t x, y, z;
+  asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) :);
+  asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) :);
+  asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) :);
+  return {x, y, z};
+#else
+  return {0, 0, 0};
+#endif
+}
+
+// Returns the dim3 cluster shape.
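+// On SM90 (CUTE_ARCH_CLUSTER_SM90_ENABLED) this reads the %cluster_nctaid special
+// registers; on other targets it falls back to a 1x1x1 cluster.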
+CUTE_DEVICE dim3 cluster_shape() { +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) :); + asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) :); + asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) :); + return {x, y, z}; +#else + return {1, 1, 1}; +#endif +} + +// Get 1D ctaid in a cluster. +CUTE_DEVICE uint32_t block_rank_in_cluster() { +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t rank; + asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :); + return rank; +#else + return 0; +#endif +} + +// Set the destination block-ID in cluster for a given SMEM Address +CUTE_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) { +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t result; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" : "=r"(result) : "r"(smemAddr), "r"(rank)); + return result; +#else + return smemAddr; +#endif +} + +// Elect one thread in the warp. The elected thread gets its predicate set to +// true, all others obtain false. +CUTE_HOST_DEVICE uint32_t elect_one_sync() { +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +#elif defined(__CUDA_ARCH__) + return (threadIdx.x % 32) == 0; +#else + return true; +#endif +} + +struct ElectOneLaneIdReturnType { + uint32_t is_leader; + uint32_t leader_lane_id; +}; + +CUTE_HOST_DEVICE +ElectOneLaneIdReturnType elect_one_leader_sync() { +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return {pred, laneid}; +#elif defined(__CUDA_ARCH__) + return {(threadIdx.x % 32) == 0, 0}; +#else + return {true, 0}; +#endif +} + +// Store value to remote shared memory in the cluster +CUTE_DEVICE +void store_shared_remote(uint32_t value, uint32_t smem_addr, uint32_t mbarrier_addr, + uint32_t dst_cta_rank) { +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t dsmem_addr = set_block_rank(smem_addr, dst_cta_rank); + uint32_t remote_barrier_addr = set_block_rank(mbarrier_addr, dst_cta_rank); + asm volatile( + "st.async.shared::cluster.mbarrier::complete_tx::bytes.u32 " + "[%0], %1, [%2];" + : + : "r"(dsmem_addr), "r"(value), "r"(remote_barrier_addr)); +#endif +} + +// Fence for smem stores for subsequent TMA_STORE +CUTE_HOST_DEVICE static void tma_store_fence() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); + asm volatile("fence.proxy.async.shared::cta;"); +#elif defined(__CUDA_ARCH__) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +// Indicate arrival of warp issuing TMA_STORE +CUTE_HOST_DEVICE static void tma_store_arrive() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_tma_store_arrive(__LINE__); + asm volatile("cp.async.bulk.commit_group;"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +// Wait until at most Count committed TMA_STOREs are pending and all prior +// commits are complete +template 
+CUTE_HOST_DEVICE static void tma_store_wait() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(Count) : "memory"); + cutlass::arch::synclog_emit_tma_store_wait(__LINE__, Count); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +} // end namespace cute + +namespace cutlass { +/// @brief +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts +// This enum class specifies the NamedBarriers reserved by CUTLASS. +enum class ReservedNamedBarriers { + EpilogueBarrier = 1, + TransposeBarrier = 2, + TransformBarrier = 3, + StreamkBarrier0 = 4, + StreamkBarrier1 = 5, + FirstUserBarrier = StreamkBarrier1 + 1 +}; + +class NamedBarrier { + // Data Members: + + // Range = [1 , NUM_THREADS_PER_CTA] + // Range % warp-size (i.e 32) == 0 + uint32_t const num_threads_; + + // Range : [0, 15] + // Note that should be set to the final barrier ID, including + // ReserveNamedBarrierCount should be considered + uint32_t const id_; + + public: + // Constructor for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + NamedBarrier(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) + : num_threads_(num_threads), id_(static_cast(reserved_named_barriers)) {} + + // Constructor for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + NamedBarrier(uint32_t num_threads, uint32_t id = 0) + : num_threads_(num_threads), id_(id + ReservedNamedBarrierCount) { + CUTLASS_ASSERT(id + ReservedNamedBarrierCount <= HardwareMaxNumNamedBarriers && + "Effective barrier_id should not exceed 16."); + } + + CUTLASS_DEVICE + void arrive_and_wait() const { + // Note: The value of id_ is already the final barrier id (set correctly in + // the constructor). + NamedBarrier::arrive_and_wait_internal(num_threads_, id_); + } + + CUTLASS_DEVICE + void arrive_and_wait_unaligned() const { + // Note: The value of id_ is already the final barrier id (set correctly in + // the constructor). + NamedBarrier::arrive_and_wait_internal_unaligned(num_threads_, id_); + } + + CUTLASS_DEVICE + void arrive() const { + // Note: The value of id_ is already the final barrier id (set correctly in + // the constructor). + NamedBarrier::arrive_internal(num_threads_, id_); + } + + CUTLASS_DEVICE + void arrive_unaligned() const { + // Note: The value of id_ is already the final barrier id (set correctly in + // the constructor). 
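+    // The *_unaligned variants below lower to PTX "barrier.*" rather than
+    // "bar.*" ("bar.sync" is defined as the aligned form, "barrier.sync.aligned"),
+    // so they do not carry the convergence requirement of the aligned form.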
+ NamedBarrier::arrive_internal_unaligned(num_threads_, id_); + } + + CUTLASS_DEVICE + void sync() const { NamedBarrier::arrive_and_wait(); } + + // Static variants + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { + arrive_and_wait_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void arrive_and_wait(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + arrive_and_wait_internal(num_threads, static_cast(reserved_named_barriers)); + } + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void arrive(uint32_t num_threads, uint32_t barrier_id) { + arrive_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void arrive(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + arrive_internal(num_threads, static_cast(reserved_named_barriers)); + } + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void sync(uint32_t num_threads, uint32_t barrier_id) { + sync_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void sync(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + sync_internal(num_threads, static_cast(reserved_named_barriers)); + } + + private: + CUTLASS_DEVICE + static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive_and_wait_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("barrier.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("barrier.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void sync_internal(uint32_t num_threads, uint32_t barrier_id) { + NamedBarrier::arrive_and_wait_internal(num_threads, barrier_id); + } + + public: + // Currently we reserve 8 NamedBarriers for CUTLASS' own use cases, + // 
while leaving the renaming for general users. + static const uint32_t ReservedNamedBarrierCount = + static_cast(ReservedNamedBarriers::FirstUserBarrier); + static const uint32_t HardwareMaxNumNamedBarriers = 16; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Hopper introduces a new cluster-wide barrier which handle with Cluster-wide +// arrive-wait behaviour. This is an extension to the Ampere arrive-wait +// barriers Note : Ampere arrive-wait Barriers have a larger max-arrive count +// (2^30) than Hopper arrive-wait Barriers (2^20). +struct ClusterBarrier { + using ValueType = uint64_t; + + protected: + // Can never be initialized - can only be aliased to smem + ValueType barrier_; + + public: + CUTLASS_DEVICE + ClusterBarrier() = delete; + + CUTLASS_DEVICE + void init(uint32_t arrive_count) const { ClusterBarrier::init(&this->barrier_, arrive_count); } + + CUTLASS_DEVICE + bool test_wait(uint32_t phase, uint32_t pred = true) const { + return ClusterBarrier::test_wait(&this->barrier_, phase, pred); + } + + CUTLASS_DEVICE + bool try_wait(uint32_t phase) const { return ClusterBarrier::try_wait(&this->barrier_, phase); } + + CUTLASS_DEVICE + void wait(uint32_t phase) const { ClusterBarrier::wait(&this->barrier_, phase); } + + // Barrier arrive on local smem + CUTLASS_DEVICE + void arrive() const { ClusterBarrier::arrive(&this->barrier_); } + + // Remote SMEM arrive with a perdicate (usually done to pick the thread doing + // the arrive) + CUTLASS_DEVICE + void arrive(uint32_t cta_id, uint32_t pred = true) const { + ClusterBarrier::arrive(&this->barrier_, cta_id, pred); + } + + // + // Static Versions + // + CUTLASS_DEVICE + static void init(ValueType const* smem_ptr, uint32_t arrive_count) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.init.shared::cta.b64 [%1], %0; \n" + "}" + : + : "r"(arrive_count), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_barrier_init(__LINE__, smem_addr, arrive_count); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + // Static version of wait - in case we don't want to burn a register + CUTLASS_DEVICE + static void wait(ValueType const* smem_ptr, uint32_t phase) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_wait(__LINE__, smem_addr, phase); + // Arbitrarily large timer value after which try-wait expires and re-tries. 
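+    // 0x989680 is decimal 10,000,000; it is passed below as the suspend-time
+    // hint operand of mbarrier.try_wait (expressed in nanoseconds per the PTX
+    // ISA, i.e. roughly 10 ms per attempt before the loop branches back).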
+ uint32_t ticks = 0x989680; + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_addr), "r"(phase), "r"(ticks)); + +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static bool test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_test_wait(__LINE__, smem_addr, phase, pred); + uint32_t waitComplete; + + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + ".reg .pred P2; \n\t" + "setp.eq.u32 P2, %3, 1;\n\t" + "@P2 mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_addr), "r"(phase), "r"(pred)); + + return static_cast(waitComplete); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + return 0; + } + + CUTLASS_DEVICE + static bool try_wait(ValueType const* smem_ptr, uint32_t phase) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_try_wait(__LINE__, smem_addr, phase); + uint32_t waitComplete; + + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_addr), "r"(phase)); + + return static_cast(waitComplete); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + return 0; + } + + // Static Predicated version of the above - in case we know the address. + CUTLASS_DEVICE + static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + if (pred) { + asm volatile( + "{\n\t" + ".reg .b32 remAddr32;\n\t" + "mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" + "}" + : + : "r"(smem_addr), "r"(cta_id)); + } + + cutlass::arch::synclog_emit_cluster_barrier_arrive_cluster(__LINE__, smem_addr, cta_id, pred); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + // Barrier arrive on local smem + CUTLASS_DEVICE + static void arrive(ValueType const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.arrive.shared::cta.b64 _, [%0];\n\t" + "}" + : + : "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_barrier_arrive(__LINE__, smem_addr); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void invalidate(ValueType const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.inval.shared::cta.b64 [%0]; \n\t" + "}" + : + : "r"(smem_addr)); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 also introduces a new type of cluster-barrier which supports sync. 
+// not just based on Arrive Count, but also transaction count (in bytes) +struct ClusterTransactionBarrier : public ClusterBarrier { + CUTLASS_DEVICE + ClusterTransactionBarrier() = delete; + + // Performs an arrive operation + expected transaction bytes increment + CUTLASS_DEVICE + void arrive_and_expect_tx(uint32_t transaction_bytes) const { + ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes); + } + + // Performs an arrive operation + expected transaction bytes increment + CUTLASS_DEVICE + void arrive_and_expect_tx(uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred = 1u) const { + ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes, cta_id, + pred); + } + + // Performs an expected transaction bytes increment without doing an arrive + // operation + CUTLASS_DEVICE + void expect_transaction(uint32_t transaction_bytes) const { + ClusterTransactionBarrier::expect_transaction(&this->barrier_, transaction_bytes); + } + + // Performs an expected transaction bytes decrement without doing an arrive + // operation + CUTLASS_DEVICE + void complete_transaction(uint32_t transaction_bytes, uint32_t pred = 1) const { + uint32_t cta_rank = cute::block_rank_in_cluster(); + ClusterTransactionBarrier::complete_transaction(&this->barrier_, cta_rank, transaction_bytes, + pred); + } + + // Performs an expected transaction bytes decrement without doing an arrive + // operation + CUTLASS_DEVICE + void complete_transaction(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { + ClusterTransactionBarrier::complete_transaction(&this->barrier_, dst_cta_id, transaction_bytes, + pred); + } + + // + // Static Versions + // + + // Performs an arrive operation + expected transaction bytes increment + CUTLASS_DEVICE + static void arrive_and_expect_tx(ValueType const* smem_ptr, uint32_t transaction_bytes) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx( + __LINE__, smem_addr, transaction_bytes); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + // Performs an arrive operation + expected transaction bytes increment for a + // remote cta_id in a Cluster + CUTLASS_DEVICE + static void arrive_and_expect_tx(ValueType const* smem_ptr, uint32_t transaction_bytes, + uint32_t cta_id, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b32 remAddr32;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, " + "[remAddr32], %3;\n\t" + "}" + : + : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes)); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + // Performs an expected transaction bytes increment without doing an arrive + // operation + CUTLASS_DEVICE + static void expect_transaction(ValueType const* smem_ptr, uint32_t transaction_bytes) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.expect_tx.shared::cta.b64 [%1], %0; \n\t" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr)); + 
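+    // mbarrier.expect_tx only raises the expected transaction count of the
+    // current phase; the phase completes once the pending arrive count reaches
+    // zero *and* the expected bytes have been delivered (e.g. by TMA copies
+    // targeting this barrier). Pair this with a separate arrive(), or use
+    // arrive_and_expect_tx() above to do both in a single instruction.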
cutlass::arch::synclog_emit_cluster_transaction_barrier_expect_transaction(__LINE__, smem_addr, + transaction_bytes); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + // Performs an expected transaction bytes decrement without doing an arrive + // operation + CUTLASS_DEVICE + static void complete_transaction(ValueType const* smem_ptr, uint32_t dst_cta_id, + uint32_t transaction_bytes, uint32_t pred = 1) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + smem_addr = cute::set_block_rank(smem_addr, dst_cta_id); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 " + " [%1], %0;" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_complete_transaction( + __LINE__, smem_addr, dst_cta_id, transaction_bytes, pred); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif + } + + // + // DEPRECATED APIs + // + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE void arrive_and_reset_bytes( + uint32_t transaction_bytes) const { + arrive_and_expect_tx(transaction_bytes); + } + + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE void arrive_and_reset_bytes( + uint32_t transaction_bytes, uint32_t cta_id) const { + arrive_and_expect_tx(transaction_bytes, cta_id); + } + + [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE void reset_bytes( + uint32_t transaction_bytes) const { + expect_transaction(transaction_bytes); + } + + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE void commit( + uint32_t transaction_bytes, uint32_t pred = 1) const { + complete_transaction(transaction_bytes, pred); + } + + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE void commit( + uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { + complete_transaction(dst_cta_id, transaction_bytes, pred); + } + + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE static void + arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { + arrive_and_expect_tx(smem_ptr, transaction_bytes); + } + + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE static void + arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, + uint32_t pred) { + arrive_and_expect_tx(smem_ptr, transaction_bytes, cta_id, pred); + } + + [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE static void reset_bytes( + ValueType const* smem_ptr, uint32_t transaction_bytes) { + expect_transaction(smem_ptr, transaction_bytes); + } + + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE static void commit( + ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, + uint32_t pred = 1) { + complete_transaction(smem_ptr, dst_cta_id, transaction_bytes, pred); + } +}; + +// Helps with visibility of barrier init operations across warps / cta / cluster +// Available as a separate function so as to batch inits across barriers and +// fence once Note : It must be composed with an appropriate sync instruction +// with the right scope to ensure visibility eg. 
__syncthreads() or a +// cluster_arrive() + cluster_wait() +CUTLASS_DEVICE +void fence_barrier_init() { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_fence_barrier_init(__LINE__); + asm volatile( + "{\n\t" + "fence.mbarrier_init.release.cluster; \n" + "}" ::); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif +} + +// Issue a shared memory fence for async operations +CUTLASS_DEVICE +void fence_view_async_shared() { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); + asm volatile( + "{\n\t" + "fence.proxy.async.shared::cta; \n" + "}" ::); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif +} + +// Arrive on completion of in-flight cp.async operations issued by the calling +// thread +CUTLASS_DEVICE +void cpasync_barrier_arrive(uint64_t const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "cp.async.mbarrier.arrive.shared::cta.b64 [%0];\n\t" + "}" + : + : "r"(smem_addr)); + cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_addr); +#elif defined(__CUDA_ARCH__) + asm volatile("brkpt;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // end namespace arch +} // end namespace cutlass + +namespace cutlass { +namespace arch { + +template +CUTLASS_DEVICE void warpgroup_reg_alloc() { +#if CUDA_CTA_RECONFIG_ACTIVATED + asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount)); +#endif +} + +template +CUTLASS_DEVICE void warpgroup_reg_dealloc() { +#if CUDA_CTA_RECONFIG_ACTIVATED + asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount)); +#endif +} + +} // namespace arch +} // namespace cutlass + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + ((__CUDACC_VER_MAJOR__ >= 12) || \ + ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))) +#define CUTE_ARCH_CLUSTER_SM90_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +#define CUTE_ARCH_ELECT_ONE_SM90_ENABLED +#endif + +namespace cute { + +#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) +using TmaDescriptor = CUtensorMap; +using Im2ColTmaDescriptor = CUtensorMap; +#else +using TmaDescriptor = struct alignas(64) { + char bytes[128]; +}; + +using Im2ColTmaDescriptor = struct alignas(64) { + char bytes[128]; +}; +#endif + +CUTE_HOST_DEVICE +void prefetch_tma_descriptor(TmaDescriptor const* desc_ptr) { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Prefetch TMA Descriptor using generic addressing (i.e. 
no specific state + // space: const or param) + asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory"); +#else + CUTE_INVALID_CONTROL_PATH( + "Trying to use TMA Descriptor Prefetch without " + "CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +namespace TMA { +enum class CacheHintSm90 : uint64_t { + EVICT_NORMAL = 0x1000000000000000, + EVICT_FIRST = 0x12F0000000000000, + EVICT_LAST = 0x14F0000000000000, +}; +} + +struct SM90_TMA_LOAD_2D { + CUTE_HOST_DEVICE static void copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void* smem_ptr, int32_t const& crd0, int32_t const& crd1) { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "r"(crd0), "r"(crd1), + "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH { + CUTE_HOST_DEVICE static void copy(void const* desc_ptr, int32_t const& crd0, + int32_t const& crd1) { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile( + "cp.async.bulk.prefetch.tensor.2d.L2.global" + " [%0, {%1, %2}];" + : + : "l"(gmem_int_desc), "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + }; +}; + +struct SM90_TMA_LOAD_MULTICAST_2D { + CUTE_HOST_DEVICE static void copy(void const* desc_ptr, uint64_t* mbar_ptr, + uint16_t multicast_mask, uint64_t cache_hint, void* smem_ptr, + int32_t const& crd0, int32_t const& crd1) { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5}], [%2], %3, %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), "r"(crd0), + "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_2D { + CUTE_HOST_DEVICE static void copy(void const* desc_ptr, void const* smem_ptr, int32_t const& crd0, + int32_t const& crd1) { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile( + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, " + "{%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +} // namespace cute diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/nvrtc_std.cuh 
b/csrc/nv_internal/tensorrt_llm/deep_gemm/nvrtc_std.cuh new file mode 100644 index 0000000000..4fad508e67 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/nvrtc_std.cuh @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef NVRTC_JIT_COMPILATION + +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = signed short; +using uint16_t = unsigned short; +using int32_t = signed int; +using uint32_t = unsigned int; +using int64_t = signed long long; +using uint64_t = unsigned long long; +using cuuint64_t = unsigned long long; + +namespace std { +template +struct integral_constant { + static constexpr T value = v; + using value_type = T; + using type = integral_constant; // using injected-class-name + + __device__ constexpr operator value_type() const noexcept { return value; } + + __device__ constexpr value_type operator()() const noexcept { return value; } // since c++14 +}; + +using false_type = integral_constant; +using true_type = integral_constant; + +template +struct is_same : false_type {}; + +template +struct is_same : true_type {}; + +template +inline constexpr bool is_same_v = is_same::value; +} // namespace std + +#endif diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh new file mode 100644 index 0000000000..35af1fcd23 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/runtime.cuh @@ -0,0 +1,181 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +#include "jit_utils.cuh" +#include "scheduler.cuh" + +namespace deep_gemm::jit { + +static bool kJitDebugging = []() { + char const* env_var = getenv("TRTLLM_DG_JIT_DEBUG"); + return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true"); +}(); + +static bool kJitUseNvcc = []() { + char const* env_var = getenv("TRTLLM_DG_JIT_USE_NVCC"); + return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true"); +}(); + +static bool kJitDumpCubin = []() { + char const* env_var = getenv("TRTLLM_DG_JIT_DUMP_CUBIN"); + return env_var && (std::string(env_var) == "1" || std::string(env_var) == "true"); +}(); + +static std::string kKernelName = kJitUseNvcc ? "nvcc_kernel.cubin" : "nvrtc_kernel.cubin"; + +/** + * C++ implementation of the Runtime class from runtime.py + * Loads and executes JIT-compiled kernels + */ +class Runtime { + public: + Runtime(std::string const& path, std::vector const& cubin, deep_gemm::GemmType gemm_type) + : path_(path), cubin_(cubin), gemm_type_(gemm_type), lib_(nullptr), kernel_(nullptr) { + DG_HOST_ASSERT(!cubin.empty() || isPathValid(path_)); + } + + ~Runtime() { + if (lib_ != nullptr) { + CHECK_CUDA(cuLibraryUnload(lib_)); + } + } + + static bool isPathValid(std::string const& path) { + // Check if path exists and is a directory + if (!std::filesystem::exists(path) || !std::filesystem::is_directory(path)) { + return false; + } + + // Check if all necessary files exist + return std::filesystem::exists(std::filesystem::path(path) / kKernelName); + } + + CUkernel getKernel() { + // Load shared object if not already loaded + if (kernel_ == nullptr) { + if (cubin_.empty()) { + std::filesystem::path cubinPath = std::filesystem::path(path_); + cubinPath /= kKernelName; + std::ifstream cubinFile(cubinPath.string(), std::ios::binary); + cubin_ = std::vector(std::istreambuf_iterator(cubinFile), {}); + } + + CHECK_CUDA(cuLibraryLoadData(&lib_, cubin_.data(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + + unsigned int numKernels = 0; + CHECK_CUDA(cuLibraryGetKernelCount(&numKernels, lib_)); + + std::vector kernels(numKernels); + CHECK_CUDA(cuLibraryEnumerateKernels(kernels.data(), numKernels, lib_)); + + for (auto kernel : kernels) { + char const* kernelName; + CHECK_CUDA(cuKernelGetName(&kernelName, kernel)); + std::string kernelNameStr(kernelName); + if (kernelNameStr.find("fp8_gemm_kernel") != std::string::npos) { + kernel_ = kernel; + break; + } + } + + if (!kernel_) { + throw std::runtime_error("Failed to find fp8_gemm_kernel"); + } + } + + return kernel_; + } + + private: + std::string path_; + std::vector cubin_; + CUlibrary lib_; + CUkernel kernel_; + deep_gemm::GemmType gemm_type_; +}; + +/** + * C++ implementation of the RuntimeCache class from runtime.py + * Caches Runtime instances by path + */ +class RuntimeCache { + public: + static RuntimeCache& getInstance() { + static RuntimeCache instance; + return instance; + } + + Runtime* operator[](std::string const& path) { + // Check if already in cache + auto it = cache_.find(path); + if (it != cache_.end()) { + return it->second.get(); + } + + // Check if already compiled + if (Runtime::isPathValid(path)) { + // Parse path to get gemm type + std::string gemm_type_str = path.substr(path.find_last_of('_') + 1); + deep_gemm::GemmType gemm_type; + if (gemm_type_str == "Normal") { + gemm_type = deep_gemm::GemmType::Normal; + } else if (gemm_type_str == "GroupedWithOffset") { + gemm_type = 
deep_gemm::GemmType::GroupedWithOffset; + } else if (gemm_type_str == "StridedBatched") { + gemm_type = deep_gemm::GemmType::StridedBatched; + } else { + throw std::runtime_error("Unsupported gemm type: " + gemm_type_str); + } + + auto runtime = std::make_unique(path, std::vector(), gemm_type); + Runtime* result = runtime.get(); + cache_[path] = std::move(runtime); + return result; + } + + return nullptr; + } + + void set(std::string const& path, std::unique_ptr&& runtime) { + cache_[path] = std::move(runtime); + } + + private: + // Private constructor for singleton pattern + RuntimeCache() = default; + + // Delete copy constructor and assignment operator + RuntimeCache(RuntimeCache const&) = delete; + RuntimeCache& operator=(RuntimeCache const&) = delete; + + std::unordered_map> cache_; +}; + +// Global function to access the singleton +RuntimeCache& getGlobalRuntimeCache() { return RuntimeCache::getInstance(); } + +} // namespace deep_gemm::jit diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/scheduler.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/scheduler.cuh new file mode 100644 index 0000000000..15a624af42 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/scheduler.cuh @@ -0,0 +1,708 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 DeepSeek + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/MIT + * + * + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#ifndef NVRTC_JIT_COMPILATION +#include +#endif + +#include "utils.cuh" + +namespace deep_gemm { + +enum class GemmType { Normal, GroupedContiguous, GroupedMasked, GroupedWithOffset, StridedBatched }; + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" + +template +__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, + uint32_t& m_block_idx, + uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup; + auto group_idx = block_idx / num_blocks_per_group; + auto first_n_block_idx = group_idx * kNumNBlocksPerGroup; + auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx); + auto in_group_idx = block_idx % num_blocks_per_group; + m_block_idx = in_group_idx / num_n_blocks_in_group; + n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group; +} + +struct NormalSchedulerInput { + uint32_t shape_m; + int* grouped_layout; // no use +}; + +struct NormalSchedulerInputSwapAB { + uint32_t shape_n; + int* grouped_layout; // no use +}; + +struct GroupedContiguousSchedulerInput { + uint32_t shape_m; + int* grouped_layout; +}; + +struct GroupedMaskedSchedulerInput { + uint32_t shape_m; + int* grouped_layout; +}; + +struct GroupedWithOffsetSchedulerInput { + uint32_t shape_m; + int64_t* problem_m_offsets; +}; + +struct GroupedWithOffsetSchedulerInputSwapAB { + uint32_t shape_m; + int64_t* problem_n_offsets; +}; + +struct StridedBatchedSchedulerInput { + uint32_t shape_m; + uint64_t ld_a; + uint64_t stride_a; + uint64_t ld_b; + uint64_t stride_b; + uint64_t ld_d; + uint64_t stride_d; +}; + +struct StridedBatchedSchedulerInputSwapAB { + uint32_t shape_n; + uint64_t ld_a; + uint64_t stride_a; + uint64_t ld_b; + uint64_t stride_b; + uint64_t ld_d; + uint64_t stride_d; +}; + +template +struct NormalScheduler { + static constexpr GemmType gemm_type = GemmType::Normal; + + int current_iter = -1; + uint32_t num_aligned_m_blocks; + uint32_t num_blocks; + + using Input = NormalSchedulerInput; + Input input; + + NormalScheduler() {} + + __device__ __forceinline__ NormalScheduler(Input& input) { + num_aligned_m_blocks = ceil_div(input.shape_m, BLOCK_M); + num_blocks = num_aligned_m_blocks * kNumNBlocks; + } + + __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx) { + return block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + return block_idx * block_size; + } + + __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx) { + return block_idx; + } + + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + return block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + if (next_block_idx >= num_blocks) { + return false; + } + get_swizzled_block_idx( + num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); + return true; + } +}; + +template +struct NormalSchedulerSwapAB { + static constexpr GemmType gemm_type 
= GemmType::Normal; + + int current_iter = -1; + uint32_t num_aligned_n_blocks; + uint32_t num_blocks; + + using Input = NormalSchedulerInputSwapAB; + Input input; + + NormalSchedulerSwapAB() {} + + __device__ __forceinline__ NormalSchedulerSwapAB(Input& input) { + num_aligned_n_blocks = ceil_div(input.shape_n, BLOCK_N); + num_blocks = num_aligned_n_blocks * kNumMBlocks; + } + + // weight + __device__ __forceinline__ uint32_t get_global_m_idx(const uint32_t shape_dim, + const uint32_t block_size, + uint32_t const& block_idx, + uint32_t const& n_block_idx = 0) { + return block_idx * block_size; + } + + // act + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx) { + return block_idx * BLOCK_N; + } + + // act scales + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx) { + return block_idx; + } + + // weight scales + __device__ __forceinline__ uint32_t get_global_scales_a_idx(const uint32_t shape_dim, + const uint32_t block_size, + uint32_t const& block_idx, + uint32_t const& n_block_idx = 0) { + return block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + if (next_block_idx >= num_blocks) { + return false; + } + + get_swizzled_block_idx( + num_aligned_n_blocks, next_block_idx, n_block_idx, m_block_idx); + return true; + } +}; + +template +struct GroupedContiguousScheduler { + static constexpr GemmType gemm_type = GemmType::GroupedContiguous; + + int current_iter = -1; + uint32_t num_aligned_m_blocks; + int* grouped_layout; + uint32_t num_blocks; + uint32_t shape_m; + + using Input = GroupedContiguousSchedulerInput; + Input input; + + GroupedContiguousScheduler() {} + + __device__ __forceinline__ GroupedContiguousScheduler(Input& input) { + num_aligned_m_blocks = ceil_div(input.shape_m, BLOCK_M); + num_blocks = num_aligned_m_blocks * kNumNBlocks; + this->shape_m = input.shape_m; + this->grouped_layout = input.grouped_layout; + } + + __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx) { + return block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + return __ldg(grouped_layout + m_block_idx * BLOCK_M) * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx) { + return block_idx; + } + + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + return __ldg(grouped_layout + m_block_idx * BLOCK_M) * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + if (next_block_idx >= num_blocks) { + return false; + } + get_swizzled_block_idx( + num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx); + return true; + } +}; + +template +struct GroupedMaskedScheduler { + static constexpr GemmType gemm_type = GemmType::GroupedMasked; + + int current_iter = -1; + uint32_t num_blocks; + uint32_t num_aligned_m_blocks; + uint32_t curr_group_idx; + uint32_t curr_cumsum; + uint32_t shape_m; + int* grouped_layout; + + using Input = 
GroupedMaskedSchedulerInput; + Input input; + + GroupedMaskedScheduler() {} + + __device__ __forceinline__ GroupedMaskedScheduler(Input& input) { + num_aligned_m_blocks = ceil_div(input.shape_m, BLOCK_M); + num_blocks = num_aligned_m_blocks * kNumNBlocks; + this->shape_m = input.shape_m; + this->grouped_layout = input.grouped_layout; + curr_group_idx = 0; + curr_cumsum = 0; + } + + __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx) { + return curr_group_idx * shape_m + block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx) { + return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx; + } + + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + uint32_t num_m_blocks; + while (true) { + // End of the task + if (curr_group_idx == kNumGroups) return false; + + // Within current group + num_m_blocks = + ceil_div(static_cast(__ldg(grouped_layout + curr_group_idx)), BLOCK_M); + auto current_m_block_cumsum = curr_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * kNumNBlocks) break; + + // Move to check the next group + curr_group_idx++; + curr_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx( + num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); + return true; + } +}; + +// Need to keep the same as the one in tests/unittest/_torch/thop/deep_gemm_tests.py +template +__host__ __device__ __forceinline__ T_offset compute_padded_offset(T_offset offset, + T_index problem_idx) { + // This formulation ensures that padded_offset[i + 1] - padded_offset[i] >= offset[i + 1] - + // offset[i]. 
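+  // Illustrative example with alignment = 32: for offsets {0, 40, 100} of
+  // problems 0, 1 and 2, the padded offsets are
+  //   problem 0: (  0 + 0 * 31) / 32 * 32 =   0
+  //   problem 1: ( 40 + 1 * 31) / 32 * 32 =  64
+  //   problem 2: (100 + 2 * 31) / 32 * 32 = 160
+  // Each padded offset is a multiple of 32, and the padded gaps (64, 96) are at
+  // least as large as the original gaps (40, 60), which is the invariant above.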
+ constexpr T_offset alignment = 32; + return (offset + problem_idx * (alignment - 1)) / alignment * alignment; +} + +template +struct GroupedWithOffsetScheduler { + static constexpr GemmType gemm_type = GemmType::GroupedWithOffset; + + int current_iter = -1; + uint32_t curr_group_idx; + uint32_t curr_cumsum; + int64_t m_offset; + int64_t m_padded_4_offset; + int64_t m_boundary; + int64_t* problem_m_offsets; + + using Input = GroupedWithOffsetSchedulerInput; + Input input; + + GroupedWithOffsetScheduler() {} + + __device__ __forceinline__ GroupedWithOffsetScheduler(Input& input) { + this->problem_m_offsets = input.problem_m_offsets; + curr_group_idx = 0; + curr_cumsum = 0; + } + + __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx) { + return m_offset + block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx) { + return m_padded_4_offset + block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + uint32_t num_m_blocks; + while (true) { + // End of the task + if (curr_group_idx == kNumGroups) return false; + m_offset = __ldg(problem_m_offsets + curr_group_idx); + m_boundary = __ldg(problem_m_offsets + curr_group_idx + 1); + m_padded_4_offset = compute_padded_offset(m_offset, curr_group_idx); + auto m = m_boundary - m_offset; + // Within current group + num_m_blocks = ceil_div(m, static_cast(BLOCK_M)); + auto current_m_block_cumsum = curr_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * kNumNBlocks) break; + + // Move to check the next group + curr_group_idx++; + curr_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx( + num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); + return true; + } +}; + +template +struct GroupedWithOffsetSchedulerSwapAB { + static constexpr GemmType gemm_type = GemmType::GroupedWithOffset; + + int current_iter = -1; + uint32_t curr_group_idx; + uint32_t curr_cumsum; + int64_t n_offset; + int64_t n_padded_4_offset; + int64_t n_boundary; + int64_t* problem_n_offsets; + + using Input = GroupedWithOffsetSchedulerInputSwapAB; + Input input; + + GroupedWithOffsetSchedulerSwapAB() {} + + __device__ __forceinline__ GroupedWithOffsetSchedulerSwapAB(Input& input) { + this->problem_n_offsets = input.problem_n_offsets; + curr_group_idx = 0; + curr_cumsum = 0; + } + + // weight + __device__ __forceinline__ uint32_t get_global_m_idx(const uint32_t shape_dim, + const uint32_t block_size, + uint32_t const& block_idx, + uint32_t const& n_block_idx = 0) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + // act + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx) { + return n_offset + block_idx * BLOCK_N; + } + + // act scales + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx) { + return n_padded_4_offset + 
block_idx * BLOCK_N; + } + + // weight scales + __device__ __forceinline__ uint32_t get_global_scales_a_idx(const uint32_t shape_dim, + const uint32_t block_size, + uint32_t const& block_idx, + uint32_t const& n_block_idx = 0) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + uint32_t num_n_blocks; + while (true) { + // End of the task + if (curr_group_idx == kNumGroups) return false; + n_offset = __ldg(problem_n_offsets + curr_group_idx); + n_boundary = __ldg(problem_n_offsets + curr_group_idx + 1); + n_padded_4_offset = compute_padded_offset(n_offset, curr_group_idx); + auto n = n_boundary - n_offset; + // Within current group + num_n_blocks = ceil_div(n, static_cast(BLOCK_N)); + auto current_n_block_cumsum = curr_cumsum + num_n_blocks; + if (next_block_idx < current_n_block_cumsum * kNumMBlocks) break; + + // Move to check the next group + curr_group_idx++; + curr_cumsum = current_n_block_cumsum; + } + + get_swizzled_block_idx( + num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx); + return true; + } +}; + +template +struct StridedBatchedScheduler { + static constexpr GemmType gemm_type = GemmType::StridedBatched; + + int current_iter = -1; + uint32_t curr_group_idx; + uint32_t curr_cumsum; + int64_t m_offset; + int64_t m_boundary; + + using Input = StridedBatchedSchedulerInput; + Input input; + + StridedBatchedScheduler() {} + + __device__ __forceinline__ StridedBatchedScheduler(Input& input) { + this->input = input; + curr_group_idx = 0; + curr_cumsum = 0; + } + + __device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx) { + // Assuming stride_a % ld_a == 0 && stride_a >= ld_a + return input.stride_a / input.ld_a * curr_group_idx + block_idx * BLOCK_M; + } + + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + // Assuming stride_b % ld_b == 0 && stride_b >= ld_b + return input.stride_b / input.ld_b * curr_group_idx + block_idx * block_size; + } + + __device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx) { + return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx; + } + + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const shape_dim, + uint32_t const block_size, + uint32_t const& block_idx, + uint32_t const& m_block_idx = 0) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + uint32_t num_m_blocks; + while (true) { + // End of the task + if (curr_group_idx == kNumGroups) return false; + m_offset = curr_group_idx * input.shape_m; + m_boundary = (curr_group_idx + 1) * input.shape_m; + // Within current group + num_m_blocks = ceil_div(input.shape_m, BLOCK_M); + auto current_m_block_cumsum = curr_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * kNumNBlocks) break; + + // Move to check the next group + curr_group_idx++; + curr_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx( + num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx); + return true; + } +}; + +template +struct 
StridedBatchedSchedulerSwapAB { + static constexpr GemmType gemm_type = GemmType::StridedBatched; + + int current_iter = -1; + uint32_t curr_group_idx; + uint32_t curr_cumsum; + int64_t n_offset; + int64_t n_boundary; + + using Input = StridedBatchedSchedulerInputSwapAB; + Input input; + + StridedBatchedSchedulerSwapAB() {} + + __device__ __forceinline__ StridedBatchedSchedulerSwapAB(Input& input) { + this->input = input; + curr_group_idx = 0; + curr_cumsum = 0; + } + + // weight + __device__ __forceinline__ uint32_t get_global_m_idx(const uint32_t shape_dim, + const uint32_t block_size, + uint32_t const& block_idx, + uint32_t const& n_block_idx = 0) { + // Assuming stride_a % ld_a == 0 && stride_a >= ld_a + return input.stride_a / input.ld_a * curr_group_idx + block_idx * block_size; + } + + // act + __device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx) { + // Assuming stride_b % ld_b == 0 && stride_b >= ld_b + return input.stride_b / input.ld_b * curr_group_idx + block_idx * BLOCK_N; + } + + // act scales + __device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx) { + return curr_group_idx * ceil_div(SHAPE_K, BLOCK_K) + block_idx; + } + + // weight scales + __device__ __forceinline__ uint32_t get_global_scales_a_idx(const uint32_t shape_dim, + const uint32_t block_size, + uint32_t const& block_idx, + uint32_t const& n_block_idx = 0) { + return curr_group_idx * shape_dim + block_idx * block_size; + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + ++current_iter; + auto const next_block_idx = current_iter * gridDim.x + blockIdx.x; + uint32_t num_n_blocks; + while (true) { + // End of the task + if (curr_group_idx == kNumGroups) return false; + n_offset = curr_group_idx * input.shape_n; + n_boundary = (curr_group_idx + 1) * input.shape_n; + // Within current group + num_n_blocks = ceil_div(input.shape_n, BLOCK_N); + auto current_n_block_cumsum = curr_cumsum + num_n_blocks; + if (next_block_idx < current_n_block_cumsum * kNumMBlocks) break; + + // Move to check the next group + curr_group_idx++; + curr_cumsum = current_n_block_cumsum; + } + + // Note: Here, m and n roles are swapped + get_swizzled_block_idx( + num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx); + return true; + } +}; + +template +struct SchedulerSelector { + static constexpr auto select_type() { + if constexpr (GT == GemmType::Normal) + return NormalScheduler(); + if constexpr (GT == GemmType::GroupedContiguous) + return GroupedContiguousScheduler(); + if constexpr (GT == GemmType::GroupedMasked) + return GroupedMaskedScheduler(); + if constexpr (GT == GemmType::GroupedWithOffset) + return GroupedWithOffsetScheduler(); + if constexpr (GT == GemmType::StridedBatched) + return StridedBatchedScheduler(); + } + + using type = decltype(select_type()); +}; + +template +struct SchedulerSelectorSwapAB { + static constexpr auto select_type() { + static_assert(GT == GemmType::GroupedWithOffset || GT == GemmType::Normal, + "Only GroupedWithOffset and Normal are supported for SwapAB"); + if constexpr (GT == GemmType::Normal) + return NormalSchedulerSwapAB(); + if constexpr (GT == GemmType::GroupedWithOffset) + return GroupedWithOffsetSchedulerSwapAB(); + } + + using type = decltype(select_type()); +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/tma_utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/tma_utils.cuh new file mode 
100644 index 0000000000..173998b089 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/tma_utils.cuh @@ -0,0 +1,128 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025 DeepSeek + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/MIT + * + * + * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifndef NVRTC_JIT_COMPILATION +#include +#include +#include + +#include +#include +#include +#endif + +#include + +#include "utils.cuh" + +namespace deep_gemm { + +#ifndef NVRTC_JIT_COMPILATION +template +constexpr CUtensorMapDataType get_CUtensorMapDataType() { + if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT16; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT32; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT64; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_INT32; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_INT64; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; + } +} + +PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled() { + // Get pointer to `cuTensorMapEncodeTiled` + cudaDriverEntryPointQueryResult driver_status; + void* cuTensorMapEncodeTiled_ptr = nullptr; + +#if CUDA_VERSION >= 12050 + cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, + cudaEnableDefault, &driver_status); +#else + cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, cudaEnableDefault, + &driver_status); +#endif + + if (driver_status != cudaDriverEntryPointSuccess) + throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess"); + return reinterpret_cast(cuTensorMapEncodeTiled_ptr); +} + +template +CUtensorMap make_2d_tma_copy_desc(T* global_address, uint64_t gmem_dim[2], uint64_t stride_in_bytes, + uint32_t smem_dim[2], CUtensorMapSwizzle swizzle_type, + PFN_cuTensorMapEncodeTiled_v12000 encode_func = nullptr) { + CUtensorMap tensor_map{}; + constexpr uint32_t rank = 2; + uint64_t global_stride[rank - 1] = {stride_in_bytes}; + uint32_t elem_strides[rank] = {1, 1}; + + if (encode_func == nullptr) encode_func = get_cuTensorMapEncodeTiled(); + 
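+  // The call below encodes a rank-2 tiled TMA descriptor: gmem_dim is the
+  // global extent, global_stride holds the single inter-row stride in bytes
+  // (rank - 1 entries), smem_dim is the box copied per TMA instruction, and
+  // the element strides are all 1 (dense), with 256B L2 promotion and no
+  // interleave or OOB fill.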
+  auto result =
+      encode_func(&tensor_map, get_CUtensorMapDataType<typename std::remove_cv<T>::type>(), rank,
+                  global_address, gmem_dim, global_stride, smem_dim, elem_strides,
+                  CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type,
+                  CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
+                  CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+  DG_HOST_ASSERT(result == CUDA_SUCCESS);
+  return tensor_map;
+}
+#endif
+
+template <uint32_t kNumTMAMulticast>
+__device__ __forceinline__ void tma_copy(void const* desc_ptr, uint64_t* barrier_ptr,
+                                         void* smem_ptr, int32_t const& crd_0,
+                                         int32_t const& crd_1) {
+  constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
+  if constexpr (kNumTMAMulticast == 1) {
+    cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
+  } else if (cute::block_rank_in_cluster() == 0) {
+    cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << kNumTMAMulticast) - 1,
+                                           cache_hint, smem_ptr, crd_0, crd_1);
+  }
+}
+
+}  // namespace deep_gemm
diff --git a/csrc/nv_internal/tensorrt_llm/deep_gemm/utils.cuh b/csrc/nv_internal/tensorrt_llm/deep_gemm/utils.cuh
new file mode 100644
index 0000000000..a017db18ff
--- /dev/null
+++ b/csrc/nv_internal/tensorrt_llm/deep_gemm/utils.cuh
@@ -0,0 +1,74 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2025 DeepSeek
+ * SPDX-License-Identifier: MIT
+ *
+ * Licensed under the MIT License.
+ * You may obtain a copy of the License at
+ *
+ * https://opensource.org/licenses/MIT
+ *
+ *
+ * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#pragma once + +#ifndef NVRTC_JIT_COMPILATION +#include + +class AssertionException : public std::exception { + private: + std::string message{}; + + public: + explicit AssertionException(std::string const& message) : message(message) {} + + char const* what() const noexcept override { return message.c_str(); } +}; +#endif + +#ifndef DG_HOST_ASSERT +#ifdef NVRTC_JIT_COMPILATION +#define DG_HOST_ASSERT(cond) ((void)0) +#else +#define DG_HOST_ASSERT(cond) \ + do { \ + if (not(cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + throw AssertionException("Assertion failed: " #cond); \ + } \ + } while (0) +#endif +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ + do { \ + if (not(cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ + } while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason) +#endif + +template +__device__ __host__ constexpr T ceil_div(T a, T b) { + return (a + b - 1) / b; +} diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh new file mode 100644 index 0000000000..08661a385e --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh @@ -0,0 +1,443 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "sm89_utils.cuh" + +namespace ada_blockwise_gemm { + +template +CUTLASS_GLOBAL void sm89_fp8_gemm_1d1d_impl(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + void const* A, void const* B, void* D, + float const* scales_a, float const* scales_b) { + GemmKernel op; + op.invoke(shape_m, shape_n, shape_k, A, B, D, scales_a, scales_b); +} + +template +CUTLASS_GLOBAL void sm89_fp8_bmm_1d1d_impl(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + __nv_fp8_e4m3* A, __nv_fp8_e4m3* B, __nv_bfloat16* D, + float* scales_a, float* scales_b, uint64_t stride_a, + uint64_t stride_b, uint64_t stride_d, + uint64_t stride_scales_a, uint64_t stride_scales_b) { + GemmKernel op; + + auto ptr_a = + reinterpret_cast(A + blockIdx.z * stride_a); + auto ptr_b = + reinterpret_cast(B + blockIdx.z * stride_b); + auto ptr_scale_a = reinterpret_cast( + scales_a + blockIdx.z * stride_scales_a); + auto ptr_scale_b = reinterpret_cast( + scales_b + blockIdx.z * stride_scales_b); + auto ptr_output = + reinterpret_cast(D + blockIdx.z * stride_d); + + op(ptr_a, ptr_b, ptr_scale_a, ptr_scale_b, ptr_output, shape_m, shape_n, shape_k); +} + +template +struct AdaBlockwiseGemmKernel { + using SharedStorage = typename KT::SharedStorage; + using ElementInput = typename KT::ElementInput; + using ElementOutput = typename KT::ElementOutput; + using ElementBlockScale = typename KT::ElementBlockScale; + + // Factory invocation + CUTLASS_DEVICE + void invoke(uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, void const* A, void const* B, + void* D, float const* scales_a, float const* scales_b) { + auto ptr_a = reinterpret_cast(A); + auto ptr_b = reinterpret_cast(B); + auto ptr_scale_a = reinterpret_cast(scales_a); + auto ptr_scale_b = reinterpret_cast(scales_b); + auto ptr_output = reinterpret_cast(D); + + (*this)(ptr_a, ptr_b, ptr_scale_a, ptr_scale_b, ptr_output, shape_m, shape_n, shape_k); + } + + CUTE_DEVICE auto gmem_tensor_init(typename KT::ElementInput const* ptr_a, + typename KT::ElementInput const* ptr_b, + typename KT::ElementBlockScale const* ptr_scale_a, + typename KT::ElementBlockScale const* ptr_scale_b, uint32_t M, + uint32_t N, uint32_t K, int* SharedStorageBase) { + using X = cute::Underscore; + + uint32_t const ScaleM = (((M + 3) >> 2) << 2); // align 4 + uint32_t const ScaleN = (N + KT::ScaleGranularityN - 1) / KT::ScaleGranularityN; + uint32_t const ScaleK = (K + KT::ScaleGranularityK - 1) / KT::ScaleGranularityK; + + auto mA_mk = cute::make_tensor(cute::make_gmem_ptr(ptr_a), cute::make_shape(M, K), + cute::make_stride(K, cute::_1{})); + + auto mB_nk = cute::make_tensor(cute::make_gmem_ptr(ptr_b), cute::make_shape(N, K), + cute::make_stride(K, cute::_1{})); + + auto mSFA_mk = + cute::make_tensor(cute::make_gmem_ptr(ptr_scale_a), cute::make_shape(ScaleM, ScaleK), + cute::make_stride(cute::_1{}, ScaleM)); + + auto mSFB_nk = + cute::make_tensor(cute::make_gmem_ptr(ptr_scale_b), cute::make_shape(ScaleN, ScaleK), + cute::make_stride(ScaleK, cute::_1{})); + + auto cta_coord = cute::make_coord(blockIdx.x, blockIdx.y, cute::_); // (m,n,k) + auto gA = cute::local_tile(mA_mk, typename KT::TileShape{}, cta_coord, + cute::Step<_1, X, _1>{}); // (BLK_M,BLK_K,k) + auto gB = cute::local_tile(mB_nk, typename KT::TileShape{}, cta_coord, + cute::Step{}); // (BLK_N,BLK_K,k) + auto gSFA = cute::local_tile(mSFA_mk, typename KT::ScalePerTileShape{}, cta_coord, + cute::Step<_1, X, _1>{}); // (BLK_M,BLK_K) + auto gSFB = cute::local_tile(mSFB_nk, typename KT::ScalePerTileShape{}, cta_coord, + 
cute::Step{}); // (BLK_N,BLK_K) + + typename KT::SharedStorageLoad* load_storage = + reinterpret_cast(SharedStorageBase); + auto sA = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_a.data()), + typename KT::SmemLayoutA{}); + auto sB = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_b.data()), + typename KT::SmemLayoutB{}); + auto sSFA = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_sfa.data()), + typename KT::SmemLayoutSFA{}); + auto sSFB = cute::make_tensor(cute::make_smem_ptr(load_storage->smem_sfb.data()), + typename KT::SmemLayoutSFB{}); + + return cute::make_tuple(gA, gB, gSFA, gSFB, sA, sB, sSFA, sSFB); + } + + template + CUTE_DEVICE void epilogue_with_smem(Accumulator& accum, SharedStorage& shared_storage, + ElementOutput* o, int M, int N) { + // convert type + auto epi = cute::make_fragment_like(accum); + cute::for_each(cute::make_int_sequence{}, + [&](auto i) { epi(i) = ElementOutput(accum(i)); }); + + auto sO = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_o.data()), + typename KT::SmemLayoutO{}); + // copy rf -> smem + typename KT::TiledMma mma; + auto tiled_copy_R2S = cute::make_tiled_copy_C(typename KT::SmemCopyAtomR2S{}, mma); + auto thr_copy_R2S = tiled_copy_R2S.get_slice(threadIdx.x); + auto tRS_rO = thr_copy_R2S.retile_S(epi); + auto tRS_sO = thr_copy_R2S.partition_D(sO); + + cute::copy(tiled_copy_R2S, tRS_rO, tRS_sO); + __syncthreads(); + + // copy smem -> rf + typename KT::TiledCopyS2G tiled_copy_S2G; + auto thr_copy_S2G = tiled_copy_S2G.get_slice(threadIdx.x); + auto tSR_sO = thr_copy_S2G.partition_S(sO); + auto tSR_rO = cute::make_tensor(cute::shape(tSR_sO)); + + cute::copy(tiled_copy_S2G, tSR_sO, tSR_rO); + __syncthreads(); + + // copy rf -> gmem + auto mO = cute::make_tensor(cute::make_gmem_ptr(o), cute::make_shape(M, N), + cute::make_stride(N, cute::_1{})); + auto cta_coord = cute::make_coord(blockIdx.x, blockIdx.y, cute::_); + auto gO = cute::local_tile(mO, typename KT::TileShape{}, cta_coord, + cute::Step{}); + auto cO = cute::make_identity_tensor( + cute::make_shape(cute::Int{}, cute::Int{})); + auto tRG_rO = thr_copy_S2G.retile_S(tSR_rO); + auto tRG_gO = thr_copy_S2G.partition_D(gO); + auto tRG_cO = thr_copy_S2G.partition_D(cO); + + int residue_m = M - KT::kTileM * blockIdx.x; + int residue_n = N - KT::kTileN * blockIdx.y; + CUTE_UNROLL + for (int m = 0; m < cute::size<1>(tRG_gO); ++m) { + CUTE_UNROLL + for (int n = 0; n < cute::size<2>(tRG_gO); ++n) { + if (cute::get<0>(tRG_cO(0, m, n)) < residue_m && + cute::get<1>(tRG_cO(0, m, n)) < residue_n) { + cute::copy(typename KT::GmemCopyAtomR2G{}, tRG_rO(cute::_, m, n), tRG_gO(cute::_, m, n)); + } + } + } + } + + template + CUTE_DEVICE void promote(TensorD& accum, TensorC const& temp_accum, TensorScale const& scale, + Index n_block) { + using AccumType = typename TensorD::value_type; + for (int mma_m = 0; mma_m < cute::get<1>(cute::shape<0>(accum)); ++mma_m) { + CUTE_UNROLL + for (int mma_n = 0; mma_n < cute::get<0>(cute::shape<0>(accum)); ++mma_n) { + CUTE_UNROLL + for (int mma_iter_m = 0; mma_iter_m < cute::size<1>(accum); ++mma_iter_m) { + CUTE_UNROLL + for (int mma_iter_n = 0; mma_iter_n < cute::size<2>(accum); ++mma_iter_n) { + auto coord_d = + cute::make_coord(cute::make_coord(mma_n, mma_m), mma_iter_m, mma_iter_n, n_block); + auto coord_c = cute::make_coord(cute::make_coord(mma_n, mma_m), mma_iter_m, mma_iter_n); + accum(coord_d) += temp_accum(coord_c) * scale(mma_m, mma_iter_m, cute::_0{}); + } + } + } + } + } + + /// Executes one GEMM + CUTE_DEVICE + void 
operator()(typename KT::ElementInput const* ptr_a, typename KT::ElementInput const* ptr_b, + typename KT::ElementBlockScale const* ptr_scale_a, + typename KT::ElementBlockScale const* ptr_scale_b, + typename KT::ElementOutput* ptr_output, uint32_t M, uint32_t N, uint32_t K) { + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + + auto [gA, gB, gSFA, gSFB, sA, sB, sSFA, sSFB] = + gmem_tensor_init(ptr_a, ptr_b, ptr_scale_a, ptr_scale_b, M, N, K, SharedStorageBase); + typename KT::GmemTiledCopyA g2s_copy_A; + typename KT::GmemTiledCopyB g2s_copy_B; + auto g2s_thr_copy_A = g2s_copy_A.get_slice(threadIdx.x); + auto g2s_thr_copy_B = g2s_copy_B.get_slice(threadIdx.x); + + auto tAgA = g2s_thr_copy_A.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + auto tAsA = g2s_thr_copy_A.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,Stage) + auto tBgB = g2s_thr_copy_B.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + auto tBsB = g2s_thr_copy_B.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,Stage) + + typename KT::GmemTiledCopySFA g2s_copy_SFA; + typename KT::GmemTiledCopySFB g2s_copy_SFB; + auto g2s_thr_copy_SFA = g2s_copy_SFA.get_slice(threadIdx.x); + auto g2s_thr_copy_SFB = g2s_copy_SFB.get_slice(threadIdx.x); + + auto tAgSFA = g2s_thr_copy_SFA.partition_S(gSFA); // (ACPY,ACPY_M,ACPY_K,Stage) + auto tAsSFA = g2s_thr_copy_SFA.partition_D(sSFA); // (ACPY,ACPY_M,ACPY_K,Stage) + auto tBgSFB = g2s_thr_copy_SFB.partition_S(gSFB); // (BCPY,BCPY_N,BCPY_K,Stage) + auto tBsSFB = g2s_thr_copy_SFB.partition_D(sSFB); // (BCPY,BCPY_N,BCPY_K,Stage) + + auto cA = make_identity_tensor(cute::make_shape(cute::size<0>(sA), cute::size<1>(sA))); + auto tAcA = g2s_thr_copy_A.partition_S(cA); + + auto cB = make_identity_tensor(cute::make_shape(cute::size<0>(sB), cute::size<1>(sB))); + auto tBcB = g2s_thr_copy_B.partition_S(cB); + + auto cSFA = cute::make_identity_tensor(typename KT::GmemTiledCopySFA::Tiler_MN{}); + auto tAcSFA = g2s_thr_copy_SFA.partition_S(cSFA); + + int residue_m = M - KT::kTileM * blockIdx.x; + int residue_n = N - KT::kTileN * blockIdx.y; + residue_m = residue_m > KT::kTileM ? KT::kTileM : residue_m; + residue_n = residue_n > KT::kTileN ? 
KT::kTileN : residue_n; + + auto tApA = cute::make_tensor(cute::make_shape(cute::size<1>(tAsA), cute::size<2>(tAsA)), + cute::Stride{}); + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tApA); ++m) { + tApA(m, 0) = cute::get<0>(tAcA(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + auto tBpB = cute::make_tensor(cute::make_shape(cute::size<1>(tBsB), cute::size<2>(tBsB)), + cute::Stride{}); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < cute::size<0>(tBpB); ++n) { + tBpB(n, 0) = cute::get<0>(tBcB(0, n, 0)) < residue_n; // blk_n coord < residue_n + } + + auto tApSFA = + cute::make_tensor(cute::make_shape(cute::size<1>(tAsSFA), cute::size<2>(tAsSFA)), + cute::Stride{}); + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tApSFA); ++m) { + tApSFA(m, 0) = cute::get<0>(tAcSFA(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + // prefetch gmem A/B + cute::clear(tAsA); + cute::clear(tBsB); + cute::clear(tAsSFA); + cute::clear(tBsSFB); + + int k_tile_count = cute::size<2>(gA); + CUTLASS_PRAGMA_NO_UNROLL + for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) { + if (k_pipe >= k_tile_count) { + cute::clear(tApA); + cute::clear(tBpB); + cute::clear(tApSFA); + } + auto k_tile_iter = std::min(k_pipe, k_tile_count - 1); + cute::copy_if(g2s_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, k_tile_iter), + tAsA(cute::_, cute::_, cute::_, k_pipe)); + cute::copy_if(g2s_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, k_tile_iter), + tBsB(cute::_, cute::_, cute::_, k_pipe)); + cute::copy_if(g2s_copy_SFA, tApSFA, tAgSFA(cute::_, cute::_, cute::_, k_tile_iter), + tAsSFA(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(g2s_copy_SFB, tBgSFB(cute::_, cute::_, cute::_, k_tile_iter), + tBsSFB(cute::_, cute::_, cute::_, k_pipe)); + + cute::cp_async_fence(); + } + + typename KT::TiledMma mma; + auto thr_mma = mma.get_slice(threadIdx.x); + auto accum = cute::partition_fragment_C( + mma, cute::make_shape(cute::Int{}, cute::Int{}, + cute::Int{})); // (MMA,MMA_M,MMA_N) + auto temp = cute::partition_fragment_C( + mma, cute::make_shape(cute::Int{}, + cute::Int{})); // (MMA,MMA_M,MMA_N) + + auto mma_shape_A = cute::partition_shape_A( + mma, cute::make_shape(cute::Int{}, cute::Int{})); + auto tCrA = cute::make_tensor(mma_shape_A); + + auto mma_shape_B = cute::partition_shape_B( + mma, cute::make_shape(cute::Int{}, cute::Int{}, + cute::Int{})); + auto tCrB = cute::make_tensor(mma_shape_B); + + auto s2r_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, mma); + auto s2r_thr_copy_A = s2r_copy_A.get_slice(threadIdx.x); + auto tXsA = s2r_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K,Stage) + auto tXrA = s2r_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + static_assert(is_static::value, "tXrA layout must be static"); + + auto s2r_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, mma); + auto s2r_thr_copy_B = s2r_copy_B.get_slice(threadIdx.x); + auto tXsB = s2r_thr_copy_B.partition_S(sB); // (CPY,CPY_N,CPY_K,Stage) + auto tXrB = s2r_thr_copy_B.retile_D(tCrB)(cute::_, cute::Int<0>{}, cute::_, cute::_); + + typename KT::SmemTiledCopySFA s2r_copy_SFA; + typename KT::SmemTiledCopySFB s2r_copy_SFB; + auto s2r_thr_copy_SFA = s2r_copy_SFA.get_slice(threadIdx.x); + auto s2r_thr_copy_SFB = s2r_copy_SFB.get_slice(threadIdx.x); + + auto tXsSFA = s2r_thr_copy_SFA.partition_S(sSFA); + auto tXrSFA = cute::make_fragment_like(tXsSFA(cute::_, cute::_, cute::_, 0)); + auto tXsSFB = s2r_thr_copy_SFB.partition_S(sSFB); + auto tXrSFB = cute::make_fragment_like(tXsSFB(cute::_, cute::_, cute::_, 
0)); + auto scale = cute::make_fragment_like(tXrSFA); + + int smem_pipe_write = KT::Stages - 1; + int smem_pipe_read = 0; + + auto tXsA_read = tXsA(cute::_, cute::_, cute::_, smem_pipe_read); + auto tXsB_read = tXsB(cute::_, cute::_, cute::_, smem_pipe_read); + auto tXsSFA_read = tXsSFA(cute::_, cute::_, cute::_, smem_pipe_read); + auto tXsSFB_read = tXsSFB(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + // prefetch smem -> rf + cute::copy(s2r_copy_SFA, tXsSFA_read, tXrSFA); + cute::copy(s2r_copy_SFB, tXsSFB_read, tXrSFB); + cute::copy(s2r_copy_A, tXsA_read, tXrA); + cute::copy(s2r_copy_B, tXsB_read(cute::_, cute::Int<0>{}, cute::_), + tXrB(cute::_, cute::_, cute::Int<0>{})); + + cute::clear(accum); + int k_tile_iter = KT::Stages - 1; + while (k_tile_iter < k_tile_count) { + cute::for_each(cute::make_int_sequence{}, [&](auto n_block) { + if constexpr (n_block == KT::NUM_GROUP_N - 1) { + tXsA_read = tXsA(cute::_, cute::_, cute::_, smem_pipe_read); + tXsB_read = tXsB(cute::_, cute::_, cute::_, smem_pipe_read); + tXsSFA_read = tXsSFA(cute::_, cute::_, cute::_, smem_pipe_read); + tXsSFB_read = tXsSFB(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + cute::copy(s2r_copy_SFA, tXsSFA_read, tXrSFA); + cute::copy(s2r_copy_SFB, tXsSFB_read, tXrSFB); + } + auto n_block_next = (n_block + cute::_1{}) % KT::NUM_GROUP_N; + cute::copy(s2r_copy_B, tXsB_read(cute::_, n_block_next, cute::_), + tXrB(cute::_, cute::_, n_block_next)); + if constexpr (n_block == 0) { + // gmem -> smem + cute::copy_if(g2s_copy_A, tApA, tAgA(cute::_, cute::_, cute::_, k_tile_iter), + tAsA(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(g2s_copy_B, tBpB, tBgB(cute::_, cute::_, cute::_, k_tile_iter), + tBsB(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(g2s_copy_SFA, tApSFA, tAgSFA(cute::_, cute::_, cute::_, k_tile_iter), + tAsSFA(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(g2s_copy_SFB, tBgSFB(cute::_, cute::_, cute::_, k_tile_iter), + tBsSFB(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::cp_async_fence(); + k_tile_iter++; + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = smem_pipe_read == KT::Stages ? 0 : smem_pipe_read; + cute::for_each(cute::make_int_sequence{}, + [&](auto i) { scale(i) = tXrSFA(i) * tXrSFB(0); }); + } + cute::clear(temp); + cute::gemm(mma, tCrA, tCrB(cute::_, cute::_, cute::_, n_block), temp); + if constexpr (n_block == KT::NUM_GROUP_N - 1) { + cute::copy(s2r_copy_A, tXsA_read, tXrA); + } + promote(accum, temp, scale, n_block); + }); + } + // load tail + cute::for_each(cute::make_int_sequence{}, [&](auto WaitIndex) { + using WaitIndex_t = decltype(WaitIndex); + cute::for_each(cute::make_int_sequence{}, [&](auto n_block) { + if constexpr (n_block == KT::NUM_GROUP_N - 1) { + tXsA_read = tXsA(cute::_, cute::_, cute::_, smem_pipe_read); + tXsB_read = tXsB(cute::_, cute::_, cute::_, smem_pipe_read); + tXsSFA_read = tXsSFA(cute::_, cute::_, cute::_, smem_pipe_read); + tXsSFB_read = tXsSFB(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + cute::copy(s2r_copy_SFA, tXsSFA_read, tXrSFA); + cute::copy(s2r_copy_SFB, tXsSFB_read, tXrSFB); + } + auto n_block_next = (n_block + cute::_1{}) % KT::NUM_GROUP_N; + cute::copy(s2r_copy_B, tXsB_read(cute::_, n_block_next, cute::_), + tXrB(cute::_, cute::_, n_block_next)); + if constexpr (n_block == 0) { + ++smem_pipe_read; + smem_pipe_read = smem_pipe_read == KT::Stages ? 
0 : smem_pipe_read; + cute::for_each(cute::make_int_sequence{}, + [&](auto i) { scale(i) = tXrSFA(i) * tXrSFB(0); }); + } + cute::clear(temp); + cute::gemm(mma, tCrA, tCrB(cute::_, cute::_, cute::_, n_block), temp); + if constexpr (n_block == KT::NUM_GROUP_N - 1) { + cute::copy(s2r_copy_A, tXsA_read, tXrA); + } + promote(accum, temp, scale, n_block); + }); + }); + // mma tail + cute::for_each(cute::make_int_sequence{}, [&](auto n_block) { + auto n_block_next = (n_block + cute::_1{}) % KT::NUM_GROUP_N; + cute::copy(s2r_copy_B, tXsB_read(cute::_, n_block_next, cute::_), + tXrB(cute::_, cute::_, n_block_next)); + cute::clear(temp); + if constexpr (n_block == 0) { + cute::for_each(cute::make_int_sequence{}, + [&](auto i) { scale(i) = tXrSFA(i) * tXrSFB(0); }); + } + cute::gemm(mma, tCrA, tCrB(cute::_, cute::_, cute::_, n_block), temp); + promote(accum, temp, scale, n_block); + }); + // epilogue + __syncthreads(); // sync before using store smem + typename KT::SharedStorageStore* store_storage = + reinterpret_cast(SharedStorageBase); + epilogue_with_smem(accum, *store_storage, ptr_output, M, N); + } +}; + +} // namespace ada_blockwise_gemm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_utils.cuh b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_utils.cuh new file mode 100644 index 0000000000..64496d2cc4 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/ada_blockwise_gemm/sm89_utils.cuh @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include +#include + +#include +#include +#include +#include + +#include "cute/atom/mma_atom.hpp" + +#define CUTLASS_HOST_TRACE(x) \ + { \ + std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; \ + } + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)) +#define CUTE_ARCH_MMA_F32_SM89_ENABLED +#endif + +namespace cute { + +// MMA 16x8x32 TN +struct SM89_16x8x32_F32F8F8F32_TN { + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void fma(float& d0, float& d1, float& d2, float& d3, uint32_t const& a0, + uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, float const& c0, + float const& c1, float const& c2, float const& c3) { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH( + "Attempting to use SM89_16x8x32_F32F8F8F32_TN without " + "CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +template <> +struct MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16, _8, _32>; + using ThrID = Layout<_32>; + using ALayout = Layout, Shape<_4, _2, _2>>, + Stride, Stride<_16, _8, _256>>>; + using BLayout = + Layout, Shape<_4, _2>>, Stride, Stride<_8, _128>>>; + using CLayout = SM80_16x8_Row; +}; + +} // namespace cute + +using namespace cute; +using namespace cutlass; +using namespace cutlass::gemm; + +namespace ada_blockwise_gemm { + +template +struct DefaultGemm_TensorOp_MMA; + +template <> +struct DefaultGemm_TensorOp_MMA { + using ArchTag = cutlass::arch::Sm80; + using MMA_Atom_Arch = cute::MMA_Atom; + using ThreadLayoutMNK = cute::Layout>; + using ValLayoutMNK = cute::Tile; + using TiledMma = cute::TiledMMA; +}; + +template <> +struct DefaultGemm_TensorOp_MMA { + using ArchTag = cutlass::arch::Sm89; + using MMA_Atom_Arch = cute::MMA_Atom; + using ThreadLayoutMNK = cute::Layout>; + using ValLayoutMNK = cute::Tile; + using TiledMma = cute::TiledMMA; +}; + +template +struct AdaBlockwiseGemmTraits { + using ElementInput = ElementType; + using ElementOutput = OutElementType; + using ElementAccumulator = float; + using ElementBlockScale = float; + + using index_t = uint32_t; + static_assert(TileM_ % 16 == 0); + static_assert(TileN_ % 32 == 0); + static_assert(TileK_ % 32 == 0); + static constexpr int Stages = Stages_; + static constexpr int kTileM = TileM_; + static constexpr int kTileN = TileN_; + static constexpr int kTileK = TileK_; + + using TileShape = Shape, Int, Int>; + static constexpr int kWarpsCount = 4; + static constexpr int kThreadCount = kWarpsCount * 32; + + static constexpr int ScaleGranularityM = 1; + static constexpr int ScaleGranularityN = 128; + static constexpr int ScaleGranularityK = 128; + + static constexpr int ScaleMsPerTile = (kTileM + ScaleGranularityM - 1) / ScaleGranularityM; + static constexpr int ScaleNsPerTile = (kTileN + ScaleGranularityN - 1) / ScaleGranularityN; + static constexpr int ScaleKsPerTile = (kTileK + ScaleGranularityK - 1) / ScaleGranularityK; + + using ScaleGranularity = + Shape, Int, Int>; + using ScalePerTileShape 
= Shape, Int, Int>; + + // ====== mma ====== + static constexpr int kMmaPermM = 32; + static constexpr int kMmaPermN = 32; + static constexpr int kMmaPermK = 32; + constexpr static int NUM_GROUP_M = kTileM / kMmaPermM; + constexpr static int NUM_GROUP_N = kTileN / kMmaPermN; + constexpr static int NUM_GROUP_K = kTileK / kMmaPermK; + using MMA_Atom = MMA_Atom; + using AtomLayoutMNK = Layout>; + using PermutationMNK = Tile, Int, Int>; + using TiledMma = TiledMMA; + + // ====== load gmem -> smem ====== + using GmemTiledCopyLoad = + decltype(make_tiled_copy(Copy_Atom, ElementInput>{}, + Layout, Stride<_8, _1>>{}, Layout>{})); + + using GmemTiledCopyA = GmemTiledCopyLoad; + using GmemTiledCopyB = GmemTiledCopyLoad; + + // ====== load smem -> rf ====== + using SmemAtomLayoutLoad = + decltype(composition(Swizzle<3, 4, 3>{}, Layout, Stride<_128, _1>>{})); + using SmemLayoutA = + decltype(tile_to_shape(SmemAtomLayoutLoad{}, Shape, Int, Int>{})); + using SmemLayoutB = + decltype(tile_to_shape(SmemAtomLayoutLoad{}, Shape, Int, Int>{})); + + using SmemCopyAtomLoad = Copy_Atom; + using SmemCopyAtomA = SmemCopyAtomLoad; + using SmemCopyAtomB = SmemCopyAtomLoad; + + // ====== store rf -> smem ====== + using SmemAtomLayoutStore = decltype(composition( + Swizzle<3, 3, 3>{}, + Layout>, Stride<_8, Stride<_1, _64>>>{})); // 8x64 + + using SmemLayoutO = + decltype(tile_to_shape(SmemAtomLayoutStore{}, Shape, Int>{})); + + using SmemCopyAtomR2S = Copy_Atom; + + // ====== store smem -> gmem ====== + using SmemCopyAtomS2R = Copy_Atom, ElementOutput>; + using GmemCopyAtomR2G = SmemCopyAtomS2R; + + using TiledCopyS2G = + decltype(make_tiled_copy(SmemCopyAtomS2R{}, Layout, Stride<_8, _1>>{}, + Layout>{})); // 16x64 + + // ====== load scale gmem -> smem ====== + using GmemCopyAtomScale = + Copy_Atom, ElementBlockScale>; + using GmemLayoutTVSFA = + Layout, Int>, Shape<_1, _1>>, + Stride, Stride<_1, _1>>>; + using GmemTileShapeSFA = Shape, Int>; + using GmemTiledCopySFA = + decltype(make_tiled_copy_impl(GmemCopyAtomScale{}, GmemLayoutTVSFA{}, GmemTileShapeSFA{})); + + using GmemLayoutTVSFB = + Layout, Shape<_1, _1>>, Stride, Stride<_1, _1>>>; + using GmemTileShapeSFB = Shape, Int>; + using GmemTiledCopySFB = + decltype(make_tiled_copy_impl(GmemCopyAtomScale{}, GmemLayoutTVSFB{}, GmemTileShapeSFB{})); + + // ====== load scale smem -> rf ====== + using SmemCopyAtomScale = Copy_Atom, ElementBlockScale>; + using SmemLayoutTVSFA = Layout, Shape<_2>>, + Stride, Stride<_8, _0>>>; + using SmemTileShapeSFA = Shape, _1>; + using SmemTiledCopySFA = + decltype(make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVSFA{}, SmemTileShapeSFA{})); + + using SmemLayoutSFA = decltype(tile_to_shape( + make_layout(SmemTileShapeSFA{}), + make_shape(shape<0>(ScalePerTileShape{}), shape<2>(ScalePerTileShape{}), + Int{}))); // BLK_M, BLK_K, Stages + + using SmemLayoutTVSFB = Layout, Shape<_1>>, + Stride, Stride<_0, _0>>>; + using SmemTileShapeSFB = Shape<_1, _1>; + using SmemTiledCopySFB = + decltype(make_tiled_copy_impl(SmemCopyAtomScale{}, SmemLayoutTVSFB{}, SmemTileShapeSFB{})); + + using SmemLayoutSFB = decltype(tile_to_shape( + make_layout(SmemTileShapeSFB{}), + make_shape(shape<1>(ScalePerTileShape{}), shape<2>(ScalePerTileShape{}), + Int{}))); // BLK_N, BLK_K, Stages + + // we need at least 2 stages.. 
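+  // (double buffering: one shared-memory stage can be filled by cp.async while
+  // another is consumed by the MMA pipeline, so loads overlap with math)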
+ static_assert(Stages >= 2); + + struct SharedStorageLoad : aligned_struct<128> { + array_aligned> smem_a; + array_aligned> smem_b; + array_aligned> smem_sfa; + array_aligned> smem_sfb; + }; + + struct SharedStorageStore : aligned_struct<128> { + array_aligned> smem_o; + }; + + union SharedStorage { + SharedStorageLoad load; + SharedStorageStore store; + }; + + static constexpr int kSmemSize = static_cast(sizeof(SharedStorage)); +}; + +} // namespace ada_blockwise_gemm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu new file mode 100644 index 0000000000..12e9d8c5b2 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu @@ -0,0 +1,349 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "fp8_blockscale_gemm.h" +#ifdef ENABLE_FP8_BLOCK_SCALE +#include "fp8_blockscale_gemm_kernel.cuh" +#endif +#include "tensorrt_llm/common/logger.h" + +namespace tensorrt_llm::kernels::fp8_blockscale_gemm { + +template +void CutlassFp8BlockScaleGemmRunner::gemm( + void* mat_d, void const* mat_a, void const* mat_b, int shape_m, int shape_n, int shape_k, + cudaStream_t stream, float const* scales_a, float const* scales_b) { +#ifdef ENABLE_FP8_BLOCK_SCALE + constexpr bool internal_quantize_a = !std::is_same_v; + constexpr bool internal_quantize_b = !std::is_same_v; + __nv_fp8_e4m3* fp8_mat_a; + __nv_fp8_e4m3* fp8_mat_b; + float* per_token_per_128c_scales; + float* per_block_scales; + + auto* ws_ptr = workspace_; + if constexpr (internal_quantize_a || internal_quantize_b) { + TLLM_CHECK(ws_ptr != nullptr); + } + + if constexpr (internal_quantize_a) { + fp8_mat_a = reinterpret_cast<__nv_fp8_e4m3*>(ws_ptr); + ws_ptr += max_shape_m_4_align_ * shape_k * sizeof(__nv_fp8_e4m3); + per_token_per_128c_scales = reinterpret_cast(ws_ptr); + ws_ptr += max_shape_m_4_align_ * div_up(shape_k, 128) * sizeof(float); + } + + if constexpr (internal_quantize_b) { + fp8_mat_b = reinterpret_cast<__nv_fp8_e4m3*>(ws_ptr); + ws_ptr += shape_n * shape_k * sizeof(__nv_fp8_e4m3); + per_block_scales = reinterpret_cast(ws_ptr); + ws_ptr += div_up(shape_n, 128) * div_up(shape_k, 128) * sizeof(float); + } + +#ifdef COMPILE_HOPPER_TMA_GEMMS + if constexpr (internal_quantize_a && internal_quantize_b) { + fp8_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, shape_k, + per_token_per_128c_scales, reinterpret_cast<__nv_bfloat16 const*>(mat_b), + fp8_mat_b, shape_k, per_block_scales, reinterpret_cast<__nv_bfloat16*>(mat_d), + shape_n, shape_m, shape_n, shape_k, stream, internal_quantize_a, + internal_quantize_b); + } + + if constexpr (internal_quantize_a && !internal_quantize_b) { + fp8_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, shape_k, + per_token_per_128c_scales, nullptr, + 
reinterpret_cast<__nv_fp8_e4m3*>(const_cast(mat_b)), shape_k, + const_cast(scales_b), reinterpret_cast<__nv_bfloat16*>(mat_d), shape_n, + shape_m, shape_n, shape_k, stream, internal_quantize_a, internal_quantize_b); + } +#else // COMPILE_HOPPER_TMA_GEMMS + TLLM_THROW("fp8 blockscale gemm only supported on Hopper."); +#endif // COMPILE_HOPPER_TMA_GEMMS +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +void CutlassFp8BlockScaleGemmRunner::gemm( + __nv_fp8_e4m3 const* mat_a, int ld_a, __nv_fp8_e4m3 const* mat_b, int ld_b, + __nv_bfloat16* mat_d, int ld_d, int shape_m, int shape_n, int shape_k, float const* scales_a, + float const* scales_b, cudaStream_t stream) { +#ifdef ENABLE_FP8_BLOCK_SCALE + fp8_gemm_run(const_cast<__nv_fp8_e4m3*>(mat_a), ld_a, const_cast<__nv_fp8_e4m3*>(mat_b), ld_b, + mat_d, ld_d, shape_m, shape_n, shape_k, const_cast(scales_a), + const_cast(scales_b), stream); +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +void CutlassFp8BlockScaleGemmRunner::moeGemm( + void* mat_d, void const* mat_a, void const* mat_b, int64_t const* problem_m_offsets, + size_t num_problems, size_t shape_n, size_t shape_k, cudaStream_t stream, float const* scales_a, + float const* scales_b) { +#ifdef ENABLE_FP8_BLOCK_SCALE + constexpr bool internal_quantize_a = !std::is_same_v; + constexpr bool internal_quantize_b = !std::is_same_v; + + __nv_fp8_e4m3* fp8_mat_a; + float* per_token_per_128c_scales; + __nv_fp8_e4m3* fp8_mat_b; + float* per_block_scales; + + auto* ws_ptr = workspace_; + if constexpr (internal_quantize_a || internal_quantize_b) { + TLLM_CHECK(ws_ptr != nullptr); + } + + if constexpr (internal_quantize_a) { + fp8_mat_a = reinterpret_cast<__nv_fp8_e4m3*>(ws_ptr); + ws_ptr += max_shape_m_4_align_ * shape_k * sizeof(__nv_fp8_e4m3); + per_token_per_128c_scales = reinterpret_cast(ws_ptr); + ws_ptr += max_shape_m_32_align_padded_ * div_up(shape_k, 128) * sizeof(float); + } else { + fp8_mat_a = reinterpret_cast<__nv_fp8_e4m3*>(const_cast(mat_a)); + per_token_per_128c_scales = const_cast(scales_a); + } + + if constexpr (internal_quantize_b) { + fp8_mat_b = reinterpret_cast<__nv_fp8_e4m3*>(ws_ptr); + ws_ptr += num_problems * shape_n * shape_k * sizeof(__nv_fp8_e4m3); + per_block_scales = reinterpret_cast(ws_ptr); + } else { + for (int i = 0; i < num_problems; i++) { + fp8_mat_b = reinterpret_cast<__nv_fp8_e4m3*>(const_cast(mat_b)); + per_block_scales = const_cast(scales_b); + } + } + +#ifdef COMPILE_HOPPER_TMA_GEMMS + if constexpr (std::is_same_v && + std::is_same_v) { + fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, + per_token_per_128c_scales, reinterpret_cast<__nv_bfloat16 const*>(mat_b), + fp8_mat_b, per_block_scales, reinterpret_cast<__nv_bfloat16*>(mat_d), + problem_m_offsets, num_problems, expected_m_, max_shape_m_4_align_, + max_shape_m_32_align_padded_, shape_n, shape_k, stream, + internal_quantize_a, internal_quantize_b); + } else if constexpr (std::is_same_v && + std::is_same_v) { + fp8_grouped_gemm_run(reinterpret_cast<__nv_bfloat16 const*>(mat_a), fp8_mat_a, + per_token_per_128c_scales, nullptr, fp8_mat_b, per_block_scales, + reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, + expected_m_, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, + shape_k, stream, internal_quantize_a, 
internal_quantize_b); + } else if constexpr (std::is_same_v && + std::is_same_v) { + fp8_grouped_gemm_run(nullptr, fp8_mat_a, per_token_per_128c_scales, + reinterpret_cast<__nv_bfloat16 const*>(mat_b), fp8_mat_b, per_block_scales, + reinterpret_cast<__nv_bfloat16*>(mat_d), problem_m_offsets, num_problems, + expected_m_, max_shape_m_4_align_, max_shape_m_32_align_padded_, shape_n, + shape_k, stream, internal_quantize_a, internal_quantize_b); + } else { + TLLM_THROW("fp8 blockscale gemm only support __nv_fp8_e4m3 or bfloat16 as dataType."); + } +#else // COMPILE_HOPPER_TMA_GEMMS + TLLM_THROW("fp8 blockscale gemm only support Hopper."); +#endif // COMPILE_HOPPER_TMA_GEMMS +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +void CutlassFp8BlockScaleGemmRunner::strideBatchGemm( + __nv_bfloat16* mat_d, int ld_d, int stride_d, __nv_fp8_e4m3* mat_a, int ld_a, int stride_a, + __nv_fp8_e4m3* mat_b, int ld_b, int stride_b, int num_problems, int shape_m, int shape_n, + int shape_k, cudaStream_t stream, float* scales_a, int stride_scales_a, float* scales_b) { +#ifdef ENABLE_FP8_BLOCK_SCALE + fp8_stride_batch_gemm_run(nullptr, mat_a, scales_a, ld_a, stride_a, stride_scales_a, nullptr, + mat_b, scales_b, ld_b, stride_b, mat_d, ld_d, stride_d, num_problems, + shape_m, shape_n, shape_k, stream, false, false); +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +void CutlassFp8BlockScaleGemmRunner::fp8CS1x128( + __nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, int shape_y, + cudaStream_t stream) { +#ifdef ENABLE_FP8_BLOCK_SCALE + fp8_1x128_cs(mat_quant, scales, mat, shape_x, shape_y, stream); +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +void CutlassFp8BlockScaleGemmRunner::fp8CS1x128Reshape( + __nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, int shape_h, + int shape_y, int stride_x, cudaStream_t stream) { +#ifdef ENABLE_FP8_BLOCK_SCALE + fp8_1x128_cs_reshape(mat_quant, scales, mat, shape_x, shape_h, shape_y, stride_x, stream); +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +void CutlassFp8BlockScaleGemmRunner::fp8CS128x128( + __nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, int shape_y, + cudaStream_t stream) { +#ifdef ENABLE_FP8_BLOCK_SCALE + fp8_128x128_cs(mat_quant, scales, mat, shape_x, shape_y, stream); +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +size_t CutlassFp8BlockScaleGemmRunner::getWorkspaceSizeBase( + size_t max_shape_m, size_t shape_n, size_t shape_k, size_t num_problems) { +#ifdef ENABLE_FP8_BLOCK_SCALE + max_shape_m_4_align_ = std::max(max_shape_m_4_align_, int64_t(div_up(max_shape_m, 4) * 4)); + if (expected_m_ == 0) { + expected_m_ = div_up(max_shape_m_4_align_, num_problems); + } + max_shape_m_32_align_padded_ = deep_gemm::compute_padded_offset(max_shape_m, num_problems); + + constexpr bool internal_quantize_a = !std::is_same_v; + constexpr bool internal_quantize_b = !std::is_same_v; + size_t total_workspace_size = 
0; + if constexpr (internal_quantize_a) { + // fp8_mat_a + total_workspace_size += max_shape_m_4_align_ * shape_k * sizeof(__nv_fp8_e4m3); + // scales_a + total_workspace_size += max_shape_m_32_align_padded_ * div_up(shape_k, 128) * sizeof(float); + } + + if constexpr (internal_quantize_b) { + // fp8_mat_b + total_workspace_size += num_problems * shape_n * shape_k * sizeof(__nv_fp8_e4m3); + // scales_b + total_workspace_size += + num_problems * div_up(shape_k, 128) * div_up(shape_n, 128) * sizeof(float); + } + + return total_workspace_size; +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); + return 0; +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +size_t CutlassFp8BlockScaleGemmRunner::getWorkspaceSize( + size_t shape_m, size_t shape_n, size_t shape_k, size_t top_k, size_t num_problems) { +#ifdef ENABLE_FP8_BLOCK_SCALE + expected_m_ = shape_m; + return getWorkspaceSizeBase(shape_m * top_k, shape_n, shape_k, num_problems); +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); + return 0; +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +size_t CutlassFp8BlockScaleGemmRunner::getFP8DataSize(int shape_m, + int shape_n, + bool is_act) { +#ifdef ENABLE_FP8_BLOCK_SCALE + int shape_m_4_align = div_up(shape_m, 4) * 4; + constexpr bool internal_quantize_a = !std::is_same_v; + constexpr bool internal_quantize_b = !std::is_same_v; + if (is_act && internal_quantize_a) { + return div_up(shape_m_4_align * shape_n * sizeof(__nv_fp8_e4m3), 128) * 128; + } + + if ((!is_act) && internal_quantize_b) { + return div_up(shape_m * shape_n * sizeof(__nv_fp8_e4m3), 128) * 128; + } + return 0; +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); + return 0; +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +size_t CutlassFp8BlockScaleGemmRunner::getActScaleSize(int shape_m, + int shape_k) { +#ifdef ENABLE_FP8_BLOCK_SCALE + int shape_m_4_align = div_up(shape_m, 4) * 4; + constexpr bool internal_quantize_a = !std::is_same_v; + size_t total_workspace_size = 0; + if constexpr (internal_quantize_a) { + // scales_a + total_workspace_size += + div_up(shape_m_4_align * div_up(shape_k, 128) * sizeof(float), 128) * 128; + } + return total_workspace_size; +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); + return 0; +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +size_t CutlassFp8BlockScaleGemmRunner::getWeightScaleSize( + int shape_n, int shape_k) { +#ifdef ENABLE_FP8_BLOCK_SCALE + constexpr bool internal_quantize_b = !std::is_same_v; + size_t total_workspace_size = 0; + if constexpr (internal_quantize_b) { + // scales_b + total_workspace_size += + div_up(div_up(shape_k, 128) * div_up(shape_n, 128) * sizeof(float), 128) * 128; + } + + return total_workspace_size; +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); + return 0; +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +size_t CutlassFp8BlockScaleGemmRunner::getActWorkspaceSize( + int shape_m, int shape_k) { +#ifdef ENABLE_FP8_BLOCK_SCALE + return getFP8DataSize(shape_m, shape_k, true) + getActScaleSize(shape_m, shape_k); +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); + return 0; +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template +size_t 
CutlassFp8BlockScaleGemmRunner::getWeightWorkspaceSize( + int shape_n, int shape_k) { +#ifdef ENABLE_FP8_BLOCK_SCALE + return getFP8DataSize(shape_n, shape_k, false) + getWeightScaleSize(shape_n, shape_k); +#else // ENABLE_FP8_BLOCK_SCALE + TLLM_THROW("fp8 blockscale gemm only supported on cuda version 12.8 or higher."); + return 0; +#endif // ENABLE_FP8_BLOCK_SCALE +} + +template class CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>; +template class CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>; +template class CutlassFp8BlockScaleGemmRunner<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>; +template class CutlassFp8BlockScaleGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; + +} // namespace tensorrt_llm::kernels::fp8_blockscale_gemm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h index 6532e6e4b9..040a1463b9 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h @@ -15,10 +15,10 @@ */ #pragma once -// #include #include #include +#include #include // non-persistent-cooperative GEMM diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh new file mode 100644 index 0000000000..6f694b7c8c --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_kernel.cuh @@ -0,0 +1,1676 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh" +#include "fp8_blockscale_mma_utils.cuh" +#include "fp8_blockscale_tma_utils.cuh" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/deep_gemm/fp8_gemm.cuh" + +namespace kernel_utils { + +inline void find_divisor(uint32_t& mul, uint32_t& shr, int x) { + auto find_log_2 = [](int x, bool round_up = false) { + auto clz = [](int x) { + for (int i = 31; i >= 0; --i) { + if ((1 << i) & x) { + return 31 - i; + } + } + return 32; + }; + + int a = 31 - clz(x); + if (round_up) { + a += (x & (x - 1)) ? 1 : 0; + } + return a; + }; + + assert(x != 0); + if (x == 1) { + // If dividing by 1, reduced math doesn't work because mul_coeff would need + // to be 2^32, which doesn't fit into unsigned int. the div() routine + // handles this special case separately. 
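+    // For the general case handled below, e.g. x = 3 gives p = 33,
+    // mul = ceil(2^33 / 3) = 0xAAAAAAAB and shr = 1, so 10 / 3 is computed as
+    // (__umulhi(10, mul) >> shr) = (6 >> 1) = 3.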
+ mul = 0; + shr = 0; + } else { + // To express the division N/D in terms of a multiplication, what we first + // imagine is simply N*(1/D). However, 1/D will always evaluate to 0 (for + // D>1), so we need another way. There's nothing that says we have to use + // exactly the fraction 1/D; instead it could be any X/Y that reduces to 1/D + // (i.e., Y=X*D), or at least to "close enough" to it. If we pick Y that is + // a power of two, then the N*(X/Y) can be N*X followed by a right-shift by + // some amount. The power of two we should pick should be at least 2^32, + // because in the div() routine we'll use umulhi(), which returns only the + // upper 32 bits -- this being equivalent to a right-shift by 32. But we + // might want a higher power of two for better accuracy depending on the + // magnitude of the denominator. Once we've picked Y, then X [our mul_coeff + // value] is simply Y/D, rounding up, and we save shift_coeff as whatever + // further shift we have to do beyond what the umulhi() implies. + uint32_t p = 31 + find_log_2(x, true); + uint32_t m = (uint32_t)(((1ull << p) + (uint32_t)x - 1) / (uint32_t)x); + + mul = m; + shr = p - 32; + } +} + +__device__ __forceinline__ void fast_divmod(uint32_t& div, uint32_t& mod, int x, int y, + uint32_t mul, uint32_t shr) { + if (y == 1) { + div = x; + mod = 0; + } else { + div = __umulhi((uint32_t)x, mul) >> shr; + mod = x - div * y; + } +} + +template +__inline__ __device__ T warpReduceSum(T val) { + constexpr uint32_t FINAL_MASK = 0xffffffff; +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} + +template <> +__inline__ __device__ __nv_bfloat16 warpReduceSum(__nv_bfloat16 val) { + constexpr uint32_t FINAL_MASK = 0xffffffff; +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = __hmax(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} + +__inline__ __device__ uint32_t elect_one_sync([[maybe_unused]] int lane_id) { + uint32_t pred = 0; +#if __CUDA_ARCH__ >= 900 + uint32_t laneid = 0; + asm volatile( + "\n\ + {\n\ + .reg .b32 %rx;\n\ + .reg .pred %px;\n\ + elect.sync %rx|%px, %2;\n\ + @%px mov.s32 %1, 1;\n\ + mov.s32 %0, %rx;\n\ + }\n\ + " + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); +#else + return lane_id == 0; +#endif + return pred; +} + +} // namespace kernel_utils + +namespace tensorrt_llm::kernels::fp8_blockscale_gemm { + +template +__device__ __host__ constexpr T div_up(T a, int b) { + return (a + b - 1) / b; +} + +using TileShape = std::tuple; +enum class Layout { RowMajor, ColMajor }; +enum class ScaleType { PerTensor, PerBlock, PerChannel, PerSubChannel }; + +template +struct GroupedGemmProblemVisitor { + struct Input { + int64_t const* problem_m_offsets; + }; + + static __host__ __device__ dim3 grid_dim(int shape_m, int shape_n, int num_problems) { + return dim3(div_up(shape_m, TILE_M), div_up(shape_n, TILE_N), num_problems); + } + + static __device__ int tile_m_idx() { return blockIdx.x; } + + static __device__ int tile_n_idx() { return blockIdx.y; } + + static __device__ int problem_idx() { return blockIdx.z; } + + static __device__ int m_offset(Input const& input) { + int problem_idx_ = problem_idx(); + return input.problem_m_offsets[problem_idx_]; + } + + static __device__ int n_offset(Input const& input) { + int problem_idx_ = problem_idx(); + return problem_idx_ * TILE_N * gridDim.y; + } + + static __device__ int m_boundary(Input const& input) { + int problem_idx_ = problem_idx(); + return 
input.problem_m_offsets[problem_idx_ + 1] - input.problem_m_offsets[problem_idx_]; + } +}; + +template +struct PlainGemmProblemVisitor { + struct Input { + int shape_m; + }; + + static __host__ __device__ dim3 grid_dim(int shape_m, int shape_n) { + return dim3(div_up(shape_m, TILE_M), div_up(shape_n, TILE_N)); + } + + static __device__ int tile_m_idx() { return blockIdx.x; } + + static __device__ int tile_n_idx() { return blockIdx.y; } + + static __device__ int problem_idx() { return 0; } + + static __device__ int m_offset(Input const& input) { return 0; } + + static __device__ int n_offset(Input const& input) { return 0; } + + static __device__ int m_boundary(Input const& input) { return input.shape_m; } +}; + +template +struct StridedBatchedGemmProblemVisitor { + struct Input { + int shape_m; + int ld_a; + int stride_a; + int ld_b; + int stride_b; + int stride_d; + int stride_scales_a; + // stride_a % ld_a must be 0 + // stride_b % ld_b must be 0 + }; + + static __host__ __device__ dim3 grid_dim(int shape_m, int shape_n, int num_problems) { + return dim3(div_up(shape_m, TILE_M), div_up(shape_n, TILE_N), num_problems); + } + + static __device__ int tile_m_idx() { return blockIdx.x; } + + static __device__ int tile_n_idx() { return blockIdx.y; } + + static __device__ int problem_idx() { return blockIdx.z; } + + static __device__ int m_offset(Input const& input) { + int problem_idx_ = problem_idx(); + return input.stride_a / input.ld_a * problem_idx_; + } + + static __device__ int n_offset(Input const& input) { + int problem_idx_ = problem_idx(); + return input.stride_b / input.ld_b * problem_idx_; + } + + static __device__ int m_boundary(Input const& input) { return input.shape_m; } +}; + +namespace cde = cuda::device::experimental; + +template +__global__ void __launch_bounds__(TILE_M == 64 ? 256 : 384, 1) + cooperative_1x128_by_128x128_fp8_gemm_kernel( + ElementD* gmem_d, int ld_d, float const* scales_b, + typename ProblemVisitor::Input problem_input, int shape_n, int shape_k, + __grid_constant__ const CUtensorMap tensor_map_a, + __grid_constant__ const CUtensorMap tensor_map_b, + __grid_constant__ const CUtensorMap tensor_map_scales_a, int guessed_m) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + static_assert(sizeof(ElementA) == 1 && sizeof(ElementB) == 1); + static_assert(TILE_K == 128); + constexpr int ScaleGranMA = 1; + constexpr int ScaleGranKA = 128; + constexpr int ScaleGranNB = 128; + constexpr int ScaleGranKB = 128; + static_assert(TILE_K % ScaleGranKA == 0); + + static constexpr int SMEM_A_SIZE_PER_STAGE = TILE_M * TILE_K * sizeof(ElementA); + static constexpr int SMEM_B_SIZE_PER_STAGE = TILE_N * TILE_K * sizeof(ElementB); + static constexpr int SMEM_SCALES_A_SIZE_PER_STAGE = + div_up(TILE_M, ScaleGranMA) * div_up(TILE_K, ScaleGranKA) * sizeof(float); + static constexpr bool IS_UNIFORM_SCALE_B = ScaleGranNB % TILE_N == 0; + + constexpr int BLOCK_SIZE = TILE_M == 64 ? 
256 : 384; + constexpr int TMA_ISSUE_INTERVAL = 1; + using Barrier = cuda::barrier; + + int tile_m_idx = ProblemVisitor::tile_m_idx(); + int m_boundary = ProblemVisitor::m_boundary(problem_input); + if (tile_m_idx * TILE_M >= m_boundary) return; + + int tile_n_idx = ProblemVisitor::tile_n_idx(); + int problem_idx = ProblemVisitor::problem_idx(); + int problem_m_offset = ProblemVisitor::m_offset(problem_input); + int problem_m_padded_offset = 0; + if constexpr (std::is_same_v>) { + problem_m_padded_offset = deep_gemm::compute_padded_offset(problem_m_offset, problem_idx); + } + int problem_n_offset = ProblemVisitor::n_offset(problem_input); + + int scales_b_ld = ScaleGranKB != 0 ? div_up(shape_k, ScaleGranKB) : 1; + scales_b += problem_idx * div_up(shape_n, ScaleGranNB) * scales_b_ld; + + int iters_in_former_scales_b = TILE_N / 8; // assuming divisible + if constexpr (ScaleGranNB != 0) { + scales_b += ((tile_n_idx * TILE_N) / ScaleGranNB) * scales_b_ld; + iters_in_former_scales_b = + min(TILE_N, ScaleGranNB - (tile_n_idx * TILE_N) % ScaleGranNB) / 8; // assuming divisible + } + + // Align to 1024 byte for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + ElementA* smem_a[NUM_STAGES]; + ElementB* smem_b[NUM_STAGES]; + float* smem_scales_a[NUM_STAGES]; + + Barrier* full_bars[NUM_STAGES]; + // NUM_EMPTY_BARS must be a const expression, otherwise it will cost too many registers. + constexpr int NUM_EMPTY_BARS = div_up(NUM_STAGES, TMA_ISSUE_INTERVAL); + Barrier* empty_bars[NUM_EMPTY_BARS]; + + float* smem_scales_b; + + for (int i = 0; i < NUM_STAGES; i++) { + smem_a[i] = reinterpret_cast(smem_buffer + i * SMEM_A_SIZE_PER_STAGE); + smem_b[i] = reinterpret_cast(smem_buffer + NUM_STAGES * SMEM_A_SIZE_PER_STAGE + + i * SMEM_B_SIZE_PER_STAGE); + smem_scales_a[i] = reinterpret_cast( + smem_buffer + NUM_STAGES * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + i * SMEM_SCALES_A_SIZE_PER_STAGE); + full_bars[i] = + reinterpret_cast(smem_buffer + + NUM_STAGES * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + + SMEM_SCALES_A_SIZE_PER_STAGE) + + i * sizeof(Barrier)); + } + for (int i = 0; i < NUM_EMPTY_BARS; i++) { + empty_bars[i] = i ? empty_bars[i - 1] + 1 : full_bars[NUM_STAGES - 1] + 1; + } + smem_scales_b = reinterpret_cast(empty_bars[NUM_EMPTY_BARS - 1] + 1); + + int lane_predicate = cute::elect_one_sync(); + if (threadIdx.x < 32 && lane_predicate == 1) { + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); + cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); + cute::prefetch_tma_descriptor( + reinterpret_cast(&tensor_map_scales_a)); + + for (int i = 0; i < NUM_STAGES; i++) { + init(full_bars[i], 1); + } + for (int i = 0; i < NUM_EMPTY_BARS; i++) { + init(empty_bars[i], BLOCK_SIZE - 128); + } + cutlass::arch::fence_view_async_shared(); + } + int math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128 - 1, 0); + + float scale_b_r, scale_b_r_second_part; + if constexpr (ScaleGranKB != 0) { + int end_index = !IS_UNIFORM_SCALE_B && iters_in_former_scales_b < TILE_N / 8 ? 
+#pragma unroll
+    for (int i = threadIdx.x; i < end_index; i += BLOCK_SIZE) {
+      float gmem_scale_b = __ldg(scales_b + i);
+      asm volatile("st.shared.f32 [%0], %1;" ::"l"(smem_scales_b + i), "f"(gmem_scale_b));
+    }
+  } else {
+    scale_b_r = scales_b[0];
+  }
+
+  __syncthreads();
+
+  while (true) {
+    constexpr int NUM_ACCUMS = WGMMA_OP::NUM_ACCUM;
+    float accum[NUM_ACCUMS] = {0};
+    float final_accum[NUM_ACCUMS] = {0};
+    constexpr int K_PER_ITER = NUM_STAGES * TILE_K;
+
+    if (threadIdx.x < 128) {
+      for (int k_iter = 0; k_iter < div_up(shape_k, K_PER_ITER); k_iter++) {
+        auto copy_func = [&](Barrier& empty_bar, int stage_range_start, int stage_range_end) {
+          empty_bar.wait_parity((k_iter + 1) & 1);
+          for (int i = stage_range_start; i < stage_range_end; i++) {
+            auto& bar = *full_bars[i];
+            int k_idx = k_iter * K_PER_ITER + i * TILE_K;
+            cde::cp_async_bulk_tensor_2d_global_to_shared(
+                smem_a[i], &tensor_map_a, k_idx, tile_m_idx * TILE_M + problem_m_offset, bar);
+            cde::cp_async_bulk_tensor_2d_global_to_shared(
+                smem_b[i], &tensor_map_b, k_idx, tile_n_idx * TILE_N + problem_n_offset, bar);
+            if constexpr (std::is_same_v>) {
+              int scale_y_offset = problem_idx * (problem_input.stride_scales_a /
+                                                  (div_up(problem_input.shape_m, 4) * 4));
+              // The scales have been aligned to 16 bytes
+              cde::cp_async_bulk_tensor_2d_global_to_shared(
+                  smem_scales_a[i], &tensor_map_scales_a, (tile_m_idx * TILE_M) / ScaleGranMA,
+                  scale_y_offset + k_idx / ScaleGranKA, bar);
+            } else {
+              // The scales have been aligned to 16 bytes
+              cde::cp_async_bulk_tensor_2d_global_to_shared(
+                  smem_scales_a[i], &tensor_map_scales_a,
+                  (problem_m_padded_offset + tile_m_idx * TILE_M) / ScaleGranMA,
+                  k_idx / ScaleGranKA, bar);
+            }
+          }
+          for (int i = stage_range_start; i < stage_range_end; i++) {
+            auto no_use = mbarrier_arrive_1_expect_tx_cta(
+                full_bars[i],
+                SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
+          }
+        };
+        if (threadIdx.x == 0) {
+          int num_stages = div_up((shape_k - k_iter * K_PER_ITER), TILE_K);
+          for (int i = 0; i < NUM_EMPTY_BARS; i++) {
+            int range_start = i * TMA_ISSUE_INTERVAL;
+            int range_end = (i + 1) * TMA_ISSUE_INTERVAL;
+            range_end = range_end > NUM_STAGES ? NUM_STAGES : range_end;
+            range_end = range_end > num_stages ? num_stages : range_end;
+            copy_func(*empty_bars[i], range_start, range_end);
+          }
+        }
+      }
+    } else {
+      int thr_id_in_wg = threadIdx.x % 128;
+      int base_r = thr_id_in_wg / 32 * 16 + thr_id_in_wg % 32 / 4;
+      int r_0 = base_r + math_wg_idx * WGMMA_OP::M;
+      int r_1 = base_r + math_wg_idx * WGMMA_OP::M + 8;
+
+      struct DivisibleK {};
+
+      struct NotDivisibleK {};
+
+      auto mma_func = [&](int k_iter, auto type) {
+        constexpr bool K_IS_DIVISIBLE = std::is_same_v ? true : false;
+        int num_stages;
+        if constexpr (K_IS_DIVISIBLE) {
+          num_stages = NUM_STAGES;
+        } else {
+          num_stages = div_up(shape_k % K_PER_ITER, TILE_K);
+          num_stages = !num_stages ?
NUM_STAGES : num_stages; + } + +#pragma unroll + for (int s = 0; s < num_stages; s++) { + if constexpr (ScaleGranKB != 0) { + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(scale_b_r) : "l"(smem_scales_b)); + if (!IS_UNIFORM_SCALE_B && iters_in_former_scales_b < TILE_N / 8) { + asm volatile("ld.shared.f32 %0, [%1];" + : "=f"(scale_b_r_second_part) + : "l"(smem_scales_b + scales_b_ld)); + } + smem_scales_b++; + } + (*full_bars[s]).wait_parity(k_iter & 1); + for (int _ = 0; _ < NUM_ACCUMS; _++) { + warpgroup_fence_operand(accum[_]); + } + warpgroup_arrive(); + for (int k = 0; k < TILE_K / WGMMA_OP::K; k++) { + auto desc_a = + make_smem_desc(smem_a[s] + math_wg_idx * WGMMA_OP::M * TILE_K + k * WGMMA_OP::K, 1); + auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA_OP::K, 1); + WGMMA_OP::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + for (int _ = 0; _ < NUM_ACCUMS; _++) { + warpgroup_fence_operand(accum[_]); + } + warpgroup_wait<0>(); + + float scale_0 = smem_scales_a[s][r_0] * scale_b_r; + float scale_1 = smem_scales_a[s][r_1] * scale_b_r; + + bool cross_0 = tile_m_idx * TILE_M + r_0 >= m_boundary; + bool cross_1 = tile_m_idx * TILE_M + r_1 >= m_boundary; + + if (cross_0) { + scale_0 = 0; + } + if (cross_1) { + scale_1 = 0; + } + + if constexpr (K_IS_DIVISIBLE) { + if (s % TMA_ISSUE_INTERVAL == TMA_ISSUE_INTERVAL - 1 || s == NUM_STAGES - 1) { + int tma_group_idx = s / TMA_ISSUE_INTERVAL; + auto no_use = (*empty_bars[tma_group_idx]).arrive(); + } + } + + float scale_0_second_part = smem_scales_a[s][r_0] * scale_b_r_second_part; + float scale_1_second_part = smem_scales_a[s][r_1] * scale_b_r_second_part; + + if (!IS_UNIFORM_SCALE_B && iters_in_former_scales_b < TILE_N / 8) { + for (int i = 0; i < iters_in_former_scales_b; i++) { + final_accum[i * 4 + 0] += scale_0 * accum[i * 4]; + final_accum[i * 4 + 1] += scale_0 * accum[i * 4 + 1]; + } + for (int i = 0; i < iters_in_former_scales_b; i++) { + final_accum[i * 4 + 2] += scale_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1 * accum[i * 4 + 3]; + } + + for (int i = iters_in_former_scales_b; i < WGMMA_OP::NUM_ACCUM / 4; i++) { + final_accum[i * 4 + 0] += scale_0_second_part * accum[i * 4]; + final_accum[i * 4 + 1] += scale_0_second_part * accum[i * 4 + 1]; + } + for (int i = iters_in_former_scales_b; i < WGMMA_OP::NUM_ACCUM / 4; i++) { + final_accum[i * 4 + 2] += scale_1_second_part * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1_second_part * accum[i * 4 + 3]; + } + } else { + for (int i = 0; i < WGMMA_OP::NUM_ACCUM / 4; i++) { + final_accum[i * 4 + 0] += scale_0 * accum[i * 4]; + final_accum[i * 4 + 1] += scale_0 * accum[i * 4 + 1]; + } + for (int i = 0; i < WGMMA_OP::NUM_ACCUM / 4; i++) { + final_accum[i * 4 + 2] += scale_1 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_1 * accum[i * 4 + 3]; + } + } + } + }; + + int num_iterations = div_up(shape_k, K_PER_ITER); + for (int k_iter = 0; k_iter < num_iterations - 1; k_iter++) { + mma_func(k_iter, DivisibleK{}); + } + mma_func(num_iterations - 1, NotDivisibleK{}); + } + + if constexpr (LayoutD == Layout::RowMajor) { + __syncthreads(); + ElementD* smem_c = reinterpret_cast(smem_buffer); + constexpr int SMEM_C_PADDING = 8; + + if (threadIdx.x >= 128) { + int thr_id_in_wg = threadIdx.x % 128; + int base_r = thr_id_in_wg / 32 * 16 + thr_id_in_wg % 32 / 4; + int base_c = thr_id_in_wg % 4 * 2; + int r_0 = base_r; + int r_1 = base_r + 8; + int c_0 = base_c; + + for (int i = 0; i < WGMMA_OP::NUM_ACCUM / 4; i++) { + int c_1 = c_0 + 1; + smem_c[(r_0 + math_wg_idx * 
WGMMA_OP::M) * (TILE_N + SMEM_C_PADDING) + c_0] = + static_cast(final_accum[i * 4]); + smem_c[(r_0 + math_wg_idx * WGMMA_OP::M) * (TILE_N + SMEM_C_PADDING) + c_1] = + static_cast(final_accum[i * 4 + 1]); + smem_c[(r_1 + math_wg_idx * WGMMA_OP::M) * (TILE_N + SMEM_C_PADDING) + c_0] = + static_cast(final_accum[i * 4 + 2]); + smem_c[(r_1 + math_wg_idx * WGMMA_OP::M) * (TILE_N + SMEM_C_PADDING) + c_1] = + static_cast(final_accum[i * 4 + 3]); + c_0 += 8; + } + } + __syncthreads(); + ElementD* gmem_d_this_block; + if constexpr (std::is_same_v>) { + gmem_d_this_block = + gmem_d + problem_idx * problem_input.stride_d + (tile_m_idx * TILE_M) * ld_d; + } else { + gmem_d_this_block = gmem_d + (problem_m_offset + tile_m_idx * TILE_M) * ld_d; + } + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + constexpr int int4_per_tile_line = TILE_N * sizeof(ElementD) / sizeof(int4); + // assert(shape_n * sizeof(ElementD) % sizeof(int4) == 0) + int int4_per_global_line = shape_n * sizeof(ElementD) / sizeof(int4); + constexpr int num_lines = TILE_M; + constexpr int num_warps = BLOCK_SIZE / 32; + int4* smem_c_int4 = reinterpret_cast(smem_c); + bool is_last_tile_n = (tile_n_idx + 1) * TILE_N > shape_n; + int int4_per_line = + is_last_tile_n ? int4_per_global_line % int4_per_tile_line : int4_per_tile_line; + + for (int line_idx = warp_idx; line_idx < num_lines; line_idx += num_warps) { + if (tile_m_idx * TILE_M + line_idx >= m_boundary) { + break; + } + for (int elem_idx = lane_idx; elem_idx < int4_per_line; elem_idx += 32) { + int4* g_data_addr = + reinterpret_cast(&gmem_d_this_block[line_idx * ld_d + tile_n_idx * TILE_N]) + + elem_idx; + int4* s_data_addr = + &smem_c_int4[line_idx * (int4_per_tile_line + + SMEM_C_PADDING * sizeof(ElementD) / sizeof(int4)) + + elem_idx]; + *g_data_addr = *s_data_addr; + } + __syncwarp(); + } + } else if constexpr (LayoutD == Layout::ColMajor) { + } + + if constexpr (!IsPersistentKernel) { + return; + } + + tile_m_idx += guessed_m / TILE_M; + if (tile_m_idx * TILE_M >= m_boundary) return; + + if (threadIdx.x < 32 && lane_predicate == 1) { + for (int i = 0; i < NUM_STAGES; i++) { + full_bars[i]->~Barrier(); + init(full_bars[i], 1); + } + for (int i = 0; i < NUM_EMPTY_BARS; i++) { + empty_bars[i]->~Barrier(); + init(empty_bars[i], BLOCK_SIZE - 128); + } + cutlass::arch::fence_view_async_shared(); + } + __syncthreads(); + smem_scales_b = reinterpret_cast(empty_bars[NUM_EMPTY_BARS - 1] + 1); + } +#else + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("This kernel requires SM90a\n"); + asm volatile("trap;"); + } +#endif +} + +template +class Fp8Gemm { + public: + static constexpr int MAX_SHAPE_K = 20480; + + private: + using Barrier = cuda::barrier; + static constexpr int SMEM_A_SIZE_PER_STAGE = TILE_M * TILE_K * sizeof(ElementA); + static constexpr int SMEM_B_SIZE_PER_STAGE = TILE_N * TILE_K * sizeof(ElementB); + static constexpr bool IS_UNIFORM_SCALE_B = ScaleGranNB % TILE_N == 0; + + public: + static constexpr int get_smem_size(int num_stages, int max_shape_k = MAX_SHAPE_K) { + auto smem_size = num_stages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + sizeof(Barrier) + + sizeof(Barrier)); + + if constexpr (ScaleTypeA == ScaleType::PerSubChannel) { + auto scale_smem_size = num_stages * div_up(TILE_M, ScaleGranMA) * + div_up(TILE_K, ScaleGranKA) * sizeof(ElementScalar); + smem_size += scale_smem_size; + } + if constexpr (ScaleTypeB != ScaleType::PerTensor) { + auto scale_smem_size = + (IS_UNIFORM_SCALE_B ? 
1 : 2) * div_up(max_shape_k, ScaleGranKB) * sizeof(ElementScalar);
+      smem_size += scale_smem_size;
+    }
+    return smem_size;
+  }
+
+ private:
+  static constexpr int get_num_stages() {
+    constexpr auto sm90_capacity = 232448;
+
+    if constexpr (get_smem_size(8) <= sm90_capacity) return 8;
+    if constexpr (get_smem_size(7) <= sm90_capacity) return 7;
+    if constexpr (get_smem_size(6) <= sm90_capacity) return 6;
+    if constexpr (get_smem_size(5) <= sm90_capacity) return 5;
+    static_assert(get_smem_size(4) <= sm90_capacity,
+                  "The required shared memory size is too large");
+    return 4;
+  }
+
+  static constexpr int NUM_STAGES = NUM_OF_STAGES == 0 ? get_num_stages() : NUM_OF_STAGES;
+  static constexpr int BLOCK_SIZE = TILE_M == 64 ? 256 : 384;
+
+ public:
+  Fp8Gemm() {
+    static_assert(
+        !(ScaleTypeA == ScaleType::PerSubChannel && (ScaleGranMA == 0 || ScaleGranKA == 0)));
+    static_assert(TILE_M % ScaleGranMA == 0 && TILE_K % ScaleGranKA == 0);
+  }
+
+  // GroupedGemm
+  static void run(ElementA* gmem_a, ElementB* gmem_b, ElementD* gmem_d, ElementScalar* scales_a,
+                  ElementScalar const* scales_b, int num_problems, int64_t const* problem_m_offsets,
+                  int shape_n, int shape_k, int max_shape_m, cudaStream_t stream = 0,
+                  int guessed_m = TILE_M, int max_shape_m_padded = 0) {
+    using ProblemVisitor = GroupedGemmProblemVisitor;
+    // Need a factory for selecting WGMMA_OP; an E5M2 op should be added if needed.
+    using WGMMA_OP = typename Fp8MmaSelector::Type;
+#define Kernel \
+  cooperative_1x128_by_128x128_fp8_gemm_kernel
+    assert(shape_n % TILE_N == 0);
+    auto tma_a_desc = make_2d_tma_a_desc(gmem_a, max_shape_m, shape_k);
+    auto tma_b_desc = make_2d_tma_b_desc(gmem_b, shape_k, num_problems * shape_n);
+    auto tma_scales_a_desc = make_2d_tma_scales_a_desc(scales_a, max_shape_m_padded, shape_k);
+    static_assert(TILE_N == WGMMA_OP::N);
+    guessed_m = div_up(guessed_m, TILE_M) * TILE_M;
+    int smem_size = get_smem_size(NUM_STAGES, shape_k);
+    cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+
+    typename ProblemVisitor::Input problem_input{problem_m_offsets};
+    auto grid_size = ProblemVisitor::grid_dim(guessed_m, shape_n, num_problems);
+
+    Kernel<<>>(gmem_d, shape_n, scales_b, problem_input,
+                                                    shape_n, shape_k, tma_a_desc, tma_b_desc,
+                                                    tma_scales_a_desc, guessed_m);
+#undef Kernel
+  }
+
+  // PlainGemm
+  static void run(ElementA* gmem_a, int ld_a, ElementB* gmem_b, int ld_b, ElementD* gmem_d,
+                  int ld_d, ElementScalar* scales_a, ElementScalar const* scales_b, int shape_m,
+                  int shape_n, int shape_k, cudaStream_t stream = 0, int guessed_m = TILE_M) {
+    using ProblemVisitor = PlainGemmProblemVisitor;
+    // Need a factory for selecting WGMMA_OP; an E5M2 op should be added if needed.
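+    // The WGMMA wrapper is chosen so that WGMMA_OP::N == TILE_N (see the static_assert below);
+    // TILE_K is fixed to 128 by the kernel. A minimal host-side sketch of this plain-GEMM entry
+    // point, as used by gemm_dispatch_old() later in this file (remaining scale-granularity
+    // template arguments elided for brevity):
+    //
+    //   using Gemm = Fp8Gemm<__nv_fp8_e4m3, Layout::RowMajor,   // A
+    //                        __nv_fp8_e4m3, Layout::ColMajor,   // B
+    //                        __nv_bfloat16, Layout::RowMajor,   // D
+    //                        float, float, float, TILE_M, TILE_N, 128, ...>;
+    //   Gemm::run(a_fp8, ld_a, b_fp8, ld_b, d_bf16, ld_d,
+    //             scales_a, scales_b, shape_m, shape_n, shape_k, stream);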
+ using WGMMA_OP = typename Fp8MmaSelector::Type; +#define Kernel \ + cooperative_1x128_by_128x128_fp8_gemm_kernel + assert(shape_n % TILE_N == 0); + auto tma_a_desc = make_2d_tma_a_desc(gmem_a, shape_m, shape_k, ld_a * sizeof(*gmem_a)); + auto tma_b_desc = make_2d_tma_b_desc(gmem_b, shape_k, shape_n, ld_b * sizeof(*gmem_b)); + auto tma_scales_a_desc = make_2d_tma_scales_a_desc(scales_a, div_up(shape_m, 4) * 4, shape_k); + static_assert(TILE_N == WGMMA_OP::N); + guessed_m = div_up(guessed_m, TILE_M) * TILE_M; + int smem_size = get_smem_size(NUM_STAGES, shape_k); + cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + typename ProblemVisitor::Input problem_input{shape_m}; + auto grid_size = ProblemVisitor::grid_dim(guessed_m, shape_n); + + Kernel<<>>(gmem_d, ld_d, scales_b, problem_input, + shape_n, shape_k, tma_a_desc, tma_b_desc, + tma_scales_a_desc, guessed_m); +#undef Kernel + } + + // StridedBatchedGemm + static void run(ElementA* gmem_a, int ld_a, int stride_a, ElementB* gmem_b, int ld_b, + int stride_b, ElementD* gmem_d, int ld_d, int stride_d, ElementScalar* scales_a, + int stride_scales_a, ElementScalar const* scales_b, int shape_m, int shape_n, + int shape_k, int num_problems, cudaStream_t stream = 0) { + using ProblemVisitor = StridedBatchedGemmProblemVisitor; + // Need a factory for selecting WGMMA_OP, need to add E5M2 op if needed. + using WGMMA_OP = typename Fp8MmaSelector::Type; +#define Kernel \ + cooperative_1x128_by_128x128_fp8_gemm_kernel + assert(shape_n % TILE_N == 0); + auto tma_a_desc = + make_2d_tma_a_desc(gmem_a, shape_m * num_problems, shape_k, ld_a * sizeof(*gmem_a)); + auto tma_b_desc = + make_2d_tma_b_desc(gmem_b, shape_k, shape_n * num_problems, ld_b * sizeof(*gmem_b)); + auto tma_scales_a_desc = make_2d_tma_scales_a_desc(scales_a, shape_m, shape_k, num_problems); + static_assert(TILE_N == WGMMA_OP::N); + typename ProblemVisitor::Input problem_input{shape_m, ld_a, stride_a, ld_b, + stride_b, stride_d, stride_scales_a}; + + int guessed_m = div_up(shape_m, TILE_M) * TILE_M; + int smem_size = get_smem_size(NUM_STAGES, shape_k); + cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + auto grid_size = ProblemVisitor::grid_dim(shape_m, shape_n, num_problems); + + Kernel<<>>(gmem_d, ld_d, scales_b, problem_input, + shape_n, shape_k, tma_a_desc, tma_b_desc, + tma_scales_a_desc, guessed_m); +#undef Kernel + } + + template + static CUtensorMap make_2d_tma_a_desc(T* global_address, uint64_t gmem_rows, uint64_t gmem_cols, + uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, LayoutA, gmem_rows, gmem_cols, global_stride_in_bytes, + TILE_M, TILE_K); + } + + template + static CUtensorMap make_2d_tma_b_desc(T* global_address, uint64_t gmem_rows, uint64_t gmem_cols, + uint64_t global_stride_in_bytes = 0) { + return make_2d_tma_desc(global_address, LayoutB, gmem_rows, gmem_cols, global_stride_in_bytes, + TILE_K, TILE_N); + } + + template + static CUtensorMap make_2d_tma_scales_a_desc(T* global_address, uint64_t shape_m, + uint64_t shape_k, int num_problems = 1, + uint64_t global_stride_in_bytes = 0) { + static_assert(TILE_M % ScaleGranMA == 0); + static_assert(TILE_K % ScaleGranKA == 0); + + constexpr auto tma_alignment_bytes = 16; + constexpr auto alignment = tma_alignment_bytes / sizeof(T); + static_assert(sizeof(T) * alignment == tma_alignment_bytes); + + shape_m = div_up(shape_m, alignment) * alignment; + return make_2d_tma_desc(global_address, Layout::ColMajor, div_up(shape_m, 
ScaleGranMA), + div_up(shape_k, ScaleGranKA) * num_problems, global_stride_in_bytes, + TILE_M / ScaleGranMA, TILE_K / ScaleGranKA, + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); + } + + template + static CUtensorMap make_2d_tma_desc( + T* global_address, Layout layout, uint64_t gmem_rows, uint64_t gmem_cols, + uint64_t global_stride_in_bytes, int smem_rows, int smem_cols, + CUtensorMapSwizzle swizzle_type = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B, + int smem_padding = 0) { + if (layout == Layout::RowMajor) { + uint64_t gmem_dim[2] = {gmem_cols, gmem_rows}; + uint32_t smem_dim[2] = {uint32_t(smem_cols), uint32_t(smem_rows)}; + if (!global_stride_in_bytes) { + global_stride_in_bytes = gmem_cols * sizeof(T); + } + return make_2d_tma_copy_desc(global_address, gmem_dim, global_stride_in_bytes, smem_dim, + swizzle_type); + } else { + uint64_t gmem_dim[2] = {gmem_rows, gmem_cols}; + uint32_t smem_dim[2] = {uint32_t(smem_rows), uint32_t(smem_cols)}; + + if (!global_stride_in_bytes) { + global_stride_in_bytes = gmem_rows * sizeof(T); + } + return make_2d_tma_copy_desc(global_address, gmem_dim, global_stride_in_bytes, smem_dim, + swizzle_type); + } + } +}; + +template +__forceinline__ __device__ T find_max_elem_in_warp(T value) { + for (int offset = 16; offset > 0; offset /= 2) { + value = T(std::max(float(value), __shfl_down_sync(0xFFFFFFFF, float(value), offset))); + } + value = T(__shfl_sync(0xffffffff, float(value), 0)); + return value; +} + +template +__global__ void scale_1x128_kernel(OutputType* output, ScaleType* scales, + InputType const* const input, int dim_x, int dim_y) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)) + size_t scales_along_dim_x = div_up(dim_x, 128); + size_t scales_along_dim_y = div_up(dim_y, 1); + size_t stride_scale_dim_y = div_up(dim_y, 4) * 4; + using Input2Type = + typename std::conditional::value, half2, __nv_bfloat162>::type; + for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + warp_idx < scales_along_dim_x * scales_along_dim_y; + warp_idx += gridDim.x * blockDim.x / 32) { + int scales_idx_y = warp_idx / scales_along_dim_x; + int scales_idx_x = warp_idx % scales_along_dim_x; + + InputType const* input_line = input + (size_t)scales_idx_y * dim_x + scales_idx_x * 128; + InputType input_amax = InputType(0); + // Each thread reads 2 elements from input_line + int lane_id = threadIdx.x % 32 * 2; + + Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)}; +#pragma unroll + for (int i = 0; i < 2; i++) { + if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x) { + break; + } else { + input_frag2[i] = *((Input2Type*)(input_line) + lane_id / 2); + } + input_line += 64; + } +#pragma unroll + for (int i = 0; i < 2; i++) { + if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x) { + break; + } else { + input_amax = InputType( + __hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y)))); + } + } + + InputType amax = find_max_elem_in_warp(input_amax); + ScaleType scale = amax != InputType(0.f) ? 
448.f / ScaleType(amax) : 1.f; + + if (lane_id == 0) { + scales[(size_t)scales_idx_x * stride_scale_dim_y + scales_idx_y] = ScaleType(1.f / scale); + } + + OutputType* output_line = output + (size_t)scales_idx_y * dim_x + scales_idx_x * 128; +#pragma unroll + for (int i = 0; i < 2; i++) { + if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x) { + break; + } else { + ScaleType value_1 = ScaleType(input_frag2[i].x) * scale; + ScaleType value_2 = ScaleType(input_frag2[i].y) * scale; + output_line[lane_id] = OutputType(value_1); + output_line[lane_id + 1] = OutputType(value_2); + } + output_line += 64; + } + } +#endif +} + +template +__global__ void scale_1x128_kernel(OutputType* output, float* scales, InputType const* input, + int64_t const* problem_m_offsets, int num_problems, int dim_x, + int64_t scale_leading_dim, uint32_t scale_dim_x_mul, + uint32_t scale_dim_x_shr) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + extern __shared__ char shared_memory[]; + int64_t* smem_problem_m_boundaries = reinterpret_cast(shared_memory); + + // problem_m_offsets[0] is omitted because its value is known to be 0 + for (int i = threadIdx.x; i < num_problems; i += blockDim.x) { + smem_problem_m_boundaries[i] = problem_m_offsets[i + 1]; + } + __syncthreads(); + + size_t scales_along_dim_x = div_up(dim_x, 128); + size_t scales_along_dim_y = smem_problem_m_boundaries[num_problems - 1]; + size_t total_scales = scales_along_dim_x * scales_along_dim_y; + + int problem_idx = 0; + int64_t padded_offset = 0; + int64_t boundary_left, boundary_right; + if constexpr (UseBinarySearch) { + boundary_left = smem_problem_m_boundaries[0]; + boundary_right = scales_along_dim_y; + } else { + boundary_left = 0; + boundary_right = smem_problem_m_boundaries[0]; + } + + for (size_t warp_idx = (threadIdx.x + blockIdx.x * blockDim.x) / 32; warp_idx < total_scales; + warp_idx += (blockDim.x * gridDim.x) / 32) { + uint32_t scales_idx_y; // = warp_idx / scales_along_dim_x; + uint32_t scales_idx_x; // = warp_idx % scales_along_dim_x; + kernel_utils::fast_divmod(scales_idx_y, scales_idx_x, warp_idx, scales_along_dim_x, + scale_dim_x_mul, scale_dim_x_shr); + + if constexpr (UseBinarySearch) { + int idx_right = num_problems - 1; + int64_t val_right = boundary_right; + if (scales_idx_y >= boundary_left) { + while (problem_idx + 1 < idx_right) { + int idx_mid = (problem_idx + idx_right) >> 1; + int64_t val_mid = smem_problem_m_boundaries[idx_mid]; + if (scales_idx_y < val_mid) { + idx_right = idx_mid; + val_right = val_mid; + } else { + problem_idx = idx_mid; + boundary_left = val_mid; + } + } + padded_offset = + deep_gemm::compute_padded_offset(boundary_left, problem_idx + 1) - boundary_left; + boundary_left = val_right; + } + } else { + if (boundary_right <= scales_idx_y) { + while (problem_idx < num_problems - 1) { + boundary_left = boundary_right; + boundary_right = smem_problem_m_boundaries[++problem_idx]; + if (scales_idx_y < boundary_right) { + break; + } + } + padded_offset = + deep_gemm::compute_padded_offset(boundary_left, problem_idx) - boundary_left; + } + } + + auto warp_offset = (size_t)scales_idx_y * dim_x + scales_idx_x * 128; + InputType const* input_line = input + warp_offset; + OutputType* output_line = output + warp_offset; + auto& scale_output = + scales[(size_t)scales_idx_x * scale_leading_dim + scales_idx_y + padded_offset]; + + int lane_id = threadIdx.x % 32; + InputType input_frag[4]; + + for (int i = 0; i < 4; i++) { + input_frag[i] = + (scales_idx_x * 128 + i * 32 + lane_id < dim_x) ? 
input_line[lane_id] : InputType(0); + input_line += 32; + } + + InputType amax = kernel_utils::warpReduceSum( + max(max(fabs(float(input_frag[0])), fabs(float(input_frag[1]))), + max(fabs(float(input_frag[2])), fabs(float(input_frag[3]))))); + + // Half seems to be slower, probably because we need float values below + // anyway. InputType amax = kernel_utils::warpReduceSum( + // __hmax(__hmax(__habs(input_frag[0]), __habs(input_frag[1])), + // __hmax(__habs(input_frag[2]), __habs(input_frag[3])))); + + float scale = amax != InputType(0.f) ? 448.f / float(amax) : 1.f; + + if (kernel_utils::elect_one_sync(lane_id)) { + scale_output = float(1.f / scale); + } + + for (int i = 0; i < 4; i++) { + float value = float(input_frag[i]) * scale; + if (scales_idx_x * 128 + i * 32 + lane_id < dim_x) { + output_line[lane_id] = OutputType(value); + } + output_line += 32; + } + } +#endif +} + +// input: [dim_y, dim_h, dim_x] +// output: [dim_h, dim_y, dim_x], cs[dim_h, dim_x/128, padding(dim_y)] +template +__global__ void scale_1x128_reshape_kernel(OutputType* output, ScaleType* scales, + InputType const* const input, int dim_x, int dim_h, + int dim_y, int stride_x) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)) + size_t scales_along_dim_x = div_up(dim_x, 128); + size_t scales_along_dim_y = div_up(dim_y, 1); + size_t scales_along_dim_h = div_up(dim_h, 1); + size_t stride_scale_dim_y = div_up(dim_y, 4) * 4; + using Input2Type = + typename std::conditional::value, half2, __nv_bfloat162>::type; + for (size_t warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + warp_idx < scales_along_dim_x * scales_along_dim_y * scales_along_dim_h; + warp_idx += gridDim.x * blockDim.x / 32) { + int scales_idx_y = warp_idx / (scales_along_dim_x * scales_along_dim_h); + int scales_idx_h = (warp_idx % (scales_along_dim_x * scales_along_dim_h)) / scales_along_dim_x; + int scales_idx_x = warp_idx % scales_along_dim_x; + + InputType const* input_line = input + (size_t)scales_idx_y * stride_x * dim_h + + (size_t)scales_idx_h * stride_x + scales_idx_x * 128; + InputType input_amax = InputType(0); + int lane_id = threadIdx.x % 32 * 2; + + Input2Type input_frag2[2] = {Input2Type(0, 0), Input2Type(0, 0)}; +#pragma unroll + for (int i = 0; i < 2; i++) { + if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x) { + break; + } else { + input_frag2[i] = *((Input2Type*)(input_line) + lane_id / 2); + } + input_line += 64; + } +#pragma unroll + for (int i = 0; i < 2; i++) { + if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x) { + break; + } else { + input_amax = InputType( + __hmax(input_amax, __hmax(__habs(input_frag2[i].x), __habs(input_frag2[i].y)))); + } + } + + InputType amax = find_max_elem_in_warp(input_amax); + ScaleType scale = amax != InputType(0.f) ? 
448.f / ScaleType(amax) : 1.f; + + if (lane_id == 0) { + scales[(size_t)scales_idx_h * scales_along_dim_x * stride_scale_dim_y + + (size_t)scales_idx_x * stride_scale_dim_y + scales_idx_y] = ScaleType(1.f / scale); + } + + OutputType* output_line = output + (size_t)scales_idx_h * dim_y * dim_x + + (size_t)scales_idx_y * dim_x + scales_idx_x * 128; +#pragma unroll + for (int i = 0; i < 2; i++) { + if (scales_idx_x * 128 + i * 64 + lane_id >= dim_x) { + break; + } else { + ScaleType value_1 = ScaleType(input_frag2[i].x) * scale; + ScaleType value_2 = ScaleType(input_frag2[i].y) * scale; + output_line[lane_id] = OutputType(value_1); + output_line[lane_id + 1] = OutputType(value_2); + } + output_line += 64; + } + } +#endif +} + +template +__global__ void scale_128x128_kernel(OutputType* output, ScaleType* scales, + InputType const* const input, int dim_x, int dim_y) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + int scales_along_dim_x = div_up(dim_x, 128); + int scales_along_dim_y = div_up(dim_y, 128); + + for (int warp_idx = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + warp_idx < scales_along_dim_x * scales_along_dim_y; + warp_idx += gridDim.x * blockDim.x / 32) { + int scales_idx_y = warp_idx / scales_along_dim_x; + int scales_idx_x = warp_idx % scales_along_dim_x; + + InputType const* input_line = input + scales_idx_y * 128 * dim_x + scales_idx_x * 128; + InputType input_amax = InputType(0); + int lane_id = threadIdx.x % 32; + + for (int i = 0; i < 128; i++) { + if (scales_idx_y * 128 + i >= dim_y) { + break; + } + InputType const* input_d = input_line; + + for (int j = 0; j < 4; j++) { + if (scales_idx_x * 128 + i * 32 + lane_id >= dim_x) { + break; + } else { + input_amax = InputType(std::max(float(input_amax), std::fabs(float(input_d[lane_id])))); + } + input_d += 32; + } + input_line += dim_x; + } + + InputType amax = find_max_elem_in_warp(input_amax); + ScaleType scale = amax != InputType(0.f) ? 
448.f / ScaleType(amax) : 1.f; + + if (lane_id == 0) { + scales[scales_idx_y * scales_along_dim_x + scales_idx_x] = ScaleType(1.f / scale); + } + + input_line = input + scales_idx_y * 128 * dim_x + scales_idx_x * 128; + OutputType* output_line = output + scales_idx_y * 128 * dim_x + scales_idx_x * 128; + + for (int i = 0; i < 128; i++) { + if (scales_idx_y * 128 + i >= dim_y) { + break; + } + InputType const* input_d = input_line; + OutputType* output_d = output_line; + + for (int j = 0; j < 4; j++) { + if (scales_idx_x * 128 + j * 32 + lane_id >= dim_x) { + break; + } else { + output_d[lane_id] = OutputType(ScaleType(input_d[lane_id]) * scale); + } + input_d += 32; + output_d += 32; + } + + input_line += dim_x; + output_line += dim_x; + } + } +#endif +} + +template +__global__ void fill_kernel(OutputType* output, size_t num_elems, float value) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elems; + idx += gridDim.x * blockDim.x) { + output[idx] = OutputType(value); + } +} + +template +__global__ void convert_kernel(OutputType* output, InputType const* const input, size_t num_elems) { + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elems; + idx += gridDim.x * blockDim.x) { + float value = float(input[idx]); + if (std::isnan(value)) { + output[idx] = OutputType(448); + } else { + output[idx] = OutputType(value); + } + } +} + +static int kNumDeviceSMs = -1; +static bool kDeepGemmEnabled = []() -> bool { + char const* env_var = std::getenv("TRTLLM_DG_ENABLED"); + return deep_gemm::jit::getGlobalCompiler().isValid() && (!env_var || std::string(env_var) != "0"); +}(); + +void fp8_1x128_cs(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, + int shape_y, cudaStream_t stream) { + if (kNumDeviceSMs < 0) { + kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); + } + scale_1x128_kernel<<>>(mat_quant, scales, mat, shape_x, + shape_y); +} + +void fp8_1x128_cs_reshape(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, + int shape_x, int shape_h, int shape_y, int stride_x, + cudaStream_t stream) { + if (kNumDeviceSMs < 0) { + kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); + } + scale_1x128_reshape_kernel<<>>(mat_quant, scales, mat, shape_x, + shape_h, shape_y, stride_x); +} + +void fp8_128x128_cs(__nv_fp8_e4m3* mat_quant, float* scales, __nv_bfloat16 const* mat, int shape_x, + int shape_y, cudaStream_t stream) { + if (kNumDeviceSMs < 0) { + kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); + } + convert_kernel<<>>(mat_quant, mat, shape_x * shape_y); + fill_kernel<<>>(scales, + div_up(shape_x, 128) * div_up(shape_y, 128), 1); +} + +void gemm_dispatch_old(void* mat_a, int ld_a, void* mat_b, int ld_b, void* mat_d, int ld_d, + float* scales_a, float* scales_b, int shape_m, int shape_n, int shape_k, + cudaStream_t stream) { + if (kNumDeviceSMs < 0) { + kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); + } + auto get_status = [=](int tile_n) -> std::pair { + int num_blocks = div_up(shape_n, tile_n); + int num_waves = div_up(num_blocks, kNumDeviceSMs); + return {num_waves, num_blocks % kNumDeviceSMs}; + }; + + auto compare = [=](int tile_n, int old_block_n) -> bool { + if (old_block_n == 0) return true; + + auto status = get_status(tile_n); + auto old_status = get_status(old_block_n); + if (status.first != old_status.first) return status.first < old_status.first; + if (status.first == 1) return status.second > old_status.second; + return tile_n > old_block_n; + }; + + int 
best_tile_m = shape_m <= 64 ? 64 : 128, best_block_n = 0; + for (auto const& tile_n : {32, 64, 128}) + if (compare(tile_n, best_block_n)) best_block_n = tile_n; + +#define DISPATCH_BLOCK_SIZE(TILE_M, TILE_N) \ + { \ + using GemmType = \ + Fp8Gemm<__nv_fp8_e4m3, Layout::RowMajor, __nv_fp8_e4m3, Layout::ColMajor, __nv_bfloat16, \ + Layout::RowMajor, float, float, float, TILE_M, TILE_N, 128, \ + ScaleType::PerSubChannel, ScaleType::PerBlock, 1, 128, 128, 128>; \ + GemmType::run(reinterpret_cast<__nv_fp8_e4m3*>(mat_a), ld_a, \ + reinterpret_cast<__nv_fp8_e4m3*>(mat_b), ld_b, \ + reinterpret_cast<__nv_bfloat16*>(mat_d), ld_d, scales_a, scales_b, shape_m, \ + shape_n, shape_k, stream \ + \ + ); \ + } \ + break + +#define DISPATCH_BLOCK_SIZE_M(TILE_N) \ + { \ + switch (best_tile_m) { \ + case 64: \ + DISPATCH_BLOCK_SIZE(64, TILE_N); \ + case 128: \ + DISPATCH_BLOCK_SIZE(128, TILE_N); \ + } \ + } \ + break + + switch (best_block_n) { + case 16: + DISPATCH_BLOCK_SIZE_M(16); + case 32: + DISPATCH_BLOCK_SIZE_M(32); + case 64: + DISPATCH_BLOCK_SIZE_M(64); + case 128: + DISPATCH_BLOCK_SIZE_M(128); + } +#undef DISPATCH_BLOCK_SIZE +#undef DISPATCH_BLOCK_SIZE_M +} + +void gemm_dispatch_old(void* mat_a, void* mat_b, void* mat_d, float* scales_a, float* scales_b, + int num_problems, int64_t const* problem_m_offsets, int max_shape_m, + int shape_n, int shape_k, cudaStream_t stream) { + if (kNumDeviceSMs < 0) { + kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); + } + auto get_status = [=](int tile_n) -> std::pair { + int num_blocks = div_up(shape_n, tile_n); + int num_waves = div_up(num_blocks, kNumDeviceSMs); + return {num_waves, num_blocks % kNumDeviceSMs}; + }; + + auto compare = [=](int tile_n, int old_block_n) -> bool { + if (old_block_n == 0) return true; + + auto status = get_status(tile_n), old_status = get_status(old_block_n); + if (status.first != old_status.first) return status.first < old_status.first; + if (status.first == 1) return status.second > old_status.second; + return tile_n > old_block_n; + }; + + int shape_m = 128; + int best_tile_m = shape_m <= 64 ? 
64 : 128, best_block_n = 0; + for (auto const& tile_n : {64, 128}) + if (compare(tile_n, best_block_n)) best_block_n = tile_n; + +#define DISPATCH_BLOCK_SIZE(TILE_M, TILE_N) \ + { \ + using GemmType = \ + Fp8Gemm<__nv_fp8_e4m3, Layout::RowMajor, __nv_fp8_e4m3, Layout::ColMajor, __nv_bfloat16, \ + Layout::RowMajor, float, float, float, TILE_M, TILE_N, 128, \ + ScaleType::PerSubChannel, ScaleType::PerBlock, 1, 128, 128, 128>; \ + GemmType::run(reinterpret_cast<__nv_fp8_e4m3*>(mat_a), \ + reinterpret_cast<__nv_fp8_e4m3*>(mat_b), \ + reinterpret_cast<__nv_bfloat16*>(mat_d), scales_a, scales_b, num_problems, \ + problem_m_offsets, shape_n, shape_k, max_shape_m, stream \ + \ + ); \ + } \ + break + +#define DISPATCH_BLOCK_SIZE_M(TILE_N) \ + { \ + switch (best_tile_m) { \ + case 64: \ + DISPATCH_BLOCK_SIZE(64, TILE_N); \ + case 128: \ + DISPATCH_BLOCK_SIZE(128, TILE_N); \ + } \ + } \ + break + + switch (best_block_n) { + case 16: + DISPATCH_BLOCK_SIZE_M(16); + case 32: + DISPATCH_BLOCK_SIZE_M(32); + case 64: + DISPATCH_BLOCK_SIZE_M(64); + case 128: + DISPATCH_BLOCK_SIZE_M(128); + } +#undef DISPATCH_BLOCK_SIZE +#undef DISPATCH_BLOCK_SIZE_M +} + +void gemm_dispatch(void* mat_a, int ld_a, void* mat_b, int ld_b, void* mat_d, int ld_d, + float* scales_a, float* scales_b, uint32_t shape_m, uint32_t shape_n, + uint32_t shape_k, cudaStream_t stream, int num_device_sms = kNumDeviceSMs) { + if (num_device_sms < 0) { + num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); + } + + constexpr uint32_t block_k = 128; + constexpr uint32_t num_problems = 1; + + uint32_t m_threshold = 32; + if (shape_m >= m_threshold) { + // Select the best configuration based on shape dimensions + auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size] = + deep_gemm::jit::get_best_gemm_config(shape_m, shape_n, shape_k, num_problems, + num_device_sms); + + auto runtime = deep_gemm::jit::getGlobalCompiler().build( + shape_n, shape_k, best_block_m, best_block_n, block_k, num_problems, best_num_stages, + best_num_tma_multicast, deep_gemm::GemmType::Normal); + auto kernel = reinterpret_cast(runtime->getKernel()); + deep_gemm::runGemm(kernel, mat_a, ld_a, mat_b, ld_b, mat_d, ld_d, scales_a, scales_b, shape_m, + shape_n, shape_k, best_block_m, best_block_n, block_k, num_problems, + best_num_tma_multicast, deep_gemm::GemmType::Normal, + static_cast(nullptr), stream, num_device_sms, + static_cast(best_smem_size)); + } else { + auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size] = + deep_gemm::jit::get_best_gemm_config(shape_n, shape_m, shape_k, num_problems, + num_device_sms, false, true); + auto runtime = deep_gemm::jit::getGlobalCompiler().build( + shape_n, shape_k, best_block_m, best_block_n, block_k, num_problems, best_num_stages, + best_num_tma_multicast, deep_gemm::GemmType::Normal, true); + auto kernel = reinterpret_cast(runtime->getKernel()); + deep_gemm::runGemmSwapAB(kernel, mat_b, ld_b, mat_a, ld_a, mat_d, ld_d, scales_b, scales_a, + shape_n, shape_m, shape_k, best_block_m, best_block_n, block_k, + num_problems, best_num_tma_multicast, deep_gemm::GemmType::Normal, + static_cast(nullptr), stream, num_device_sms, + static_cast(best_smem_size)); + } +} + +void gemm_dispatch_sm89(void* mat_a, void* mat_b, void* mat_d, float* scales_a, float* scales_b, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, + int num_device_sms = kNumDeviceSMs) { + if (num_device_sms < 0) { + num_device_sms = kNumDeviceSMs = 
tensorrt_llm::common::getMultiProcessorCount();
+  }
+  using ElementInput = cute::float_e4m3_t;
+  using ElementOutput = cute::bfloat16_t;
+  using ElementAccum = float;
+  using ElementBlockScale = float;
+  static constexpr int Stages = 3;
+  using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
+  using KT = ada_blockwise_gemm::AdaBlockwiseGemmTraits;
+  using GemmKernel = ada_blockwise_gemm::AdaBlockwiseGemmKernel;
+
+  static constexpr int kSmemSize = KT::kSmemSize;
+  static constexpr int kThreadCount = KT::kThreadCount;
+  int grid_m = (shape_m + KT::kTileM - 1) / KT::kTileM;
+  int grid_n = (shape_n + KT::kTileN - 1) / KT::kTileN;
+  int grid_k = 1;
+  dim3 grid = dim3(grid_m, grid_n, grid_k);
+  dim3 block = dim3(kThreadCount, 1, 1);
+
+  if (kSmemSize > (48 << 10)) {
+    cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl,
+                         cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize);
+    auto result = cudaGetLastError();
+    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s",
+                         cudaGetErrorString(result));
+  }
+
+  ada_blockwise_gemm::sm89_fp8_gemm_1d1d_impl<<>>(
+      shape_m, shape_n, shape_k, mat_a, mat_b, mat_d, scales_a, scales_b);
+}
+
+void fp8_gemm_run(__nv_fp8_e4m3* mat_a, int ld_a, __nv_fp8_e4m3* mat_b, int ld_b,
+                  __nv_bfloat16* mat_d, int ld_d, uint32_t shape_m, uint32_t shape_n,
+                  uint32_t shape_k, float* scales_a, float* scales_b, cudaStream_t stream) {
+  if (shape_m == 0) {
+    return;
+  }
+#ifndef PLACEHOLDER_KERNELS
+  int arch = tensorrt_llm::common::getSMVersion();
+  if (arch == 89) {
+    gemm_dispatch_sm89(mat_a, mat_b, mat_d, scales_a, scales_b, shape_m, shape_n, shape_k, stream);
+    return;
+  }
+  if (kDeepGemmEnabled) {
+    gemm_dispatch(mat_a, ld_a, mat_b, ld_b, mat_d, ld_d, scales_a, scales_b, shape_m, shape_n,
+                  shape_k, stream);
+  } else {
+    gemm_dispatch_old(mat_a, ld_a, mat_b, ld_b, mat_d, ld_d, scales_a, scales_b,
+                      static_cast(shape_m), static_cast(shape_n),
+                      static_cast(shape_k), stream);
+  }
+#endif
+}
+
+void fp8_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, int ld_a, float* scales_a,
+                  __nv_bfloat16 const* mat_b, __nv_fp8_e4m3* fp8_mat_b, int ld_b, float* scales_b,
+                  __nv_bfloat16* mat_d, int ld_d, uint32_t shape_m, uint32_t shape_n,
+                  uint32_t shape_k, cudaStream_t stream, bool internal_quantize_a = true,
+                  bool internal_quantize_b = true) {
+  if (shape_m == 0) {
+    return;
+  }
+  if (kNumDeviceSMs < 0) {
+    kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
+  }
+
+  if (internal_quantize_a) {
+    scale_1x128_kernel<<>>(fp8_mat_a, scales_a, mat_a, shape_k,
+                                                shape_m);
+  }
+  if (internal_quantize_b) {
+    scale_128x128_kernel<<>>(fp8_mat_b, scales_b, mat_b, shape_k,
+                                                  shape_n);
+  }
+  fp8_gemm_run(fp8_mat_a, ld_a, fp8_mat_b, ld_b, mat_d, ld_d, shape_m, shape_n, shape_k, scales_a,
+               scales_b, stream);
+}
+
+void grouped_gemm_dispatch(__nv_fp8_e4m3* mat_a, __nv_fp8_e4m3* mat_b, __nv_bfloat16* mat_d,
+                           uint32_t num_problems, int64_t const* problem_m_offsets,
+                           uint32_t expected_m, uint32_t max_shape_m, uint32_t max_shape_m_padded,
+                           uint32_t shape_n, uint32_t shape_k, float* scales_a, float* scales_b,
+                           cudaStream_t stream, int num_device_sms = kNumDeviceSMs) {
+  if (num_device_sms < 0) {
+    num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
+  }
+
+  constexpr uint32_t block_k = 128;
+  uint32_t m_per_expert_threshold =
+      num_device_sms == 78 ? 64 : 32;  // 64 for H20 (78 SMs), 32 for H100/H200
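+  // Dispatch heuristic: once the expected number of rows per expert reaches the threshold
+  // above, the regular grouped kernel is used. Otherwise A and B are swapped (swapAB) so that
+  // the small M dimension maps onto the kernel's N dimension, which is meant to keep the WGMMA
+  // tiles better utilized for skinny problems; the config lookup and the launch below therefore
+  // pass shape_n and expected_m / max_shape_m in swapped order. For example, with
+  // expected_m = 16 on an H100/H200 (threshold 32), the swapAB path is taken.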
+  if (expected_m >= m_per_expert_threshold) {
+    auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size] =
+        deep_gemm::jit::get_best_gemm_config(expected_m, shape_n, shape_k, num_problems,
+                                             num_device_sms);
+
+    auto runtime = deep_gemm::jit::getGlobalCompiler().build(
+        shape_n, shape_k, best_block_m, best_block_n, block_k, num_problems, best_num_stages,
+        best_num_tma_multicast, deep_gemm::GemmType::GroupedWithOffset);
+    auto kernel = reinterpret_cast(runtime->getKernel());
+    deep_gemm::runGemm(kernel, mat_a, 0, mat_b, 0, mat_d, 0, scales_a, scales_b, max_shape_m,
+                       shape_n, shape_k, best_block_m, best_block_n, block_k, num_problems,
+                       best_num_tma_multicast, deep_gemm::GemmType::GroupedWithOffset,
+                       const_cast(problem_m_offsets), stream, num_device_sms,
+                       static_cast(best_smem_size), max_shape_m_padded);
+  } else {
+    auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size] =
+        deep_gemm::jit::get_best_gemm_config(shape_n, expected_m, shape_k, num_problems,
+                                             num_device_sms, false, true);
+    auto runtime = deep_gemm::jit::getGlobalCompiler().build(
+        shape_n, shape_k, best_block_m, best_block_n, block_k, num_problems, best_num_stages,
+        best_num_tma_multicast, deep_gemm::GemmType::GroupedWithOffset, true);
+    auto kernel = reinterpret_cast(runtime->getKernel());
+
+    deep_gemm::runGemmSwapAB(
+        kernel, mat_b, 0, mat_a, 0, mat_d, 0, scales_b, scales_a, shape_n, max_shape_m, shape_k,
+        best_block_m, best_block_n, block_k, num_problems, best_num_tma_multicast,
+        deep_gemm::GemmType::GroupedWithOffset, const_cast(problem_m_offsets), stream,
+        num_device_sms, static_cast(best_smem_size), max_shape_m_padded);
+  }
+}
+
+void fp8_grouped_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a,
+                          __nv_bfloat16 const* mat_b, __nv_fp8_e4m3* fp8_mat_b, float* scales_b,
+                          __nv_bfloat16* mat_d, int64_t const* problem_m_offsets, int num_problems,
+                          int64_t expected_m, int64_t max_shape_m, int64_t max_shape_m_padded,
+                          int shape_n, int shape_k, cudaStream_t stream,
+                          bool internal_quantize_a = true, bool internal_quantize_b = true) {
+  if (kNumDeviceSMs < 0) {
+    kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
+  }
+
+  if (internal_quantize_a) {
+    constexpr int NumThreads = 256;
+    int scales_dim_x = div_up(shape_k, 128);
+    uint32_t scale_dim_x_mul, scale_dim_x_shr;
+    kernel_utils::find_divisor(scale_dim_x_mul, scale_dim_x_shr, scales_dim_x);
+
+    int smem_size = num_problems * sizeof(int64_t);
+    int num_blocks = std::min(static_cast(kNumDeviceSMs),
+                              div_up(max_shape_m * scales_dim_x, NumThreads / 32));
+    // Binary search is expected to have lower complexity when max_shape_m is small
+    bool use_binary_search =
+        static_cast(max_shape_m) * scales_dim_x /
+            static_cast(NumThreads * num_blocks / 32) <=
+        static_cast(num_problems) / std::log2(static_cast(num_problems));
+    auto kernel = use_binary_search ?
scale_1x128_kernel + : scale_1x128_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + kernel<<>>( + fp8_mat_a, scales_a, mat_a, problem_m_offsets, num_problems, shape_k, max_shape_m_padded, + scale_dim_x_mul, scale_dim_x_shr); + } + + if (internal_quantize_b) { + __nv_fp8_e4m3* fp8_mat_b_tmp = fp8_mat_b; + float* scales_b_tmp = scales_b; + __nv_bfloat16 const* mat_b_tmp = mat_b; + + for (int i = 0; i < num_problems; i++) { + scale_128x128_kernel<<>>(fp8_mat_b_tmp, scales_b_tmp, + mat_b_tmp, shape_k, shape_n); + fp8_mat_b_tmp += shape_n * shape_k; + mat_b_tmp += shape_n * shape_k; + scales_b_tmp += div_up(shape_n, 128) * div_up(shape_k, 128); + } + } + + if (kDeepGemmEnabled) { + grouped_gemm_dispatch(fp8_mat_a, fp8_mat_b, mat_d, num_problems, problem_m_offsets, expected_m, + max_shape_m, max_shape_m_padded, shape_n, shape_k, scales_a, scales_b, + stream); + } else { + using GemmType = Fp8Gemm<__nv_fp8_e4m3, Layout::RowMajor, __nv_fp8_e4m3, Layout::ColMajor, + __nv_bfloat16, Layout::RowMajor, float, float, float, 128, 64, 128, + ScaleType::PerSubChannel, ScaleType::PerBlock, 1, 128, 128, 128>; + GemmType::run(fp8_mat_a, fp8_mat_b, mat_d, scales_a, scales_b, num_problems, problem_m_offsets, + shape_n, shape_k, static_cast(max_shape_m), stream, 128, + static_cast(max_shape_m_padded)); + } +} + +void strided_batch_gemm_dispatch(__nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, + int ld_b, int stride_b, __nv_bfloat16* mat_d, int ld_d, + int stride_d, float* scales_a, float* scales_b, + uint32_t num_problems, uint32_t shape_m, uint32_t shape_n, + uint32_t shape_k, cudaStream_t stream, + int num_device_sms = kNumDeviceSMs) { + if (num_device_sms < 0) { + num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); + } + + constexpr uint32_t block_k = 128; + + // Select the best configuration based on shape dimensions + auto [best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size] = + deep_gemm::jit::get_best_gemm_config(shape_m, shape_n, shape_k, num_problems, num_device_sms); + + auto runtime = deep_gemm::jit::getGlobalCompiler().build( + shape_n, shape_k, best_block_m, best_block_n, block_k, num_problems, best_num_stages, + best_num_tma_multicast, deep_gemm::GemmType::StridedBatched); + auto kernel = reinterpret_cast(runtime->getKernel()); + deep_gemm::runGemm(kernel, mat_a, static_cast(ld_a), static_cast(stride_a), + mat_b, static_cast(ld_b), static_cast(stride_b), mat_d, + static_cast(ld_d), static_cast(stride_d), scales_a, + scales_b, shape_m, shape_n, shape_k, best_block_m, best_block_n, block_k, + num_problems, best_num_tma_multicast, deep_gemm::GemmType::StridedBatched, + stream, num_device_sms, static_cast(best_smem_size)); +} + +void strided_batch_gemm_dispatch_sm89(__nv_fp8_e4m3* mat_a, int ld_a, int stride_a, + __nv_fp8_e4m3* mat_b, int ld_b, int stride_b, + __nv_bfloat16* mat_d, int ld_d, int stride_d, float* scales_a, + int stride_scales_a, float* scales_b, uint32_t num_problems, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + cudaStream_t stream, int num_device_sms = kNumDeviceSMs) { + if (num_device_sms < 0) { + num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); + } + using ElementInput = cute::float_e4m3_t; + using ElementOutput = cute::bfloat16_t; + using ElementAccum = float; + using ElementBlockScale = float; + static constexpr int Stages = 3; + using TileShape = cutlass::gemm::GemmShape<32, 128, 128>; + using KT = 
ada_blockwise_gemm::AdaBlockwiseGemmTraits; + using GemmKernel = ada_blockwise_gemm::AdaBlockwiseGemmKernel; + + static constexpr int kSmemSize = KT::kSmemSize; + static constexpr int kThreadCount = KT::kThreadCount; + int grid_m = (shape_m + KT::kTileM - 1) / KT::kTileM; + int grid_n = (shape_n + KT::kTileN - 1) / KT::kTileN; + int grid_k = num_problems; + dim3 grid = dim3(grid_m, grid_n, grid_k); + dim3 block = dim3(kThreadCount, 1, 1); + + int stride_scales_b = ((shape_n + 128 - 1) / 128) * ((shape_k + 128 - 1) / 128); + + if (kSmemSize > (48 << 10)) { + cudaFuncSetAttribute(ada_blockwise_gemm::sm89_fp8_bmm_1d1d_impl, + cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize); + auto result = cudaGetLastError(); + TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm89 gemm kernel cannot launch: %s", + cudaGetErrorString(result)); + } + ada_blockwise_gemm::sm89_fp8_bmm_1d1d_impl<<>>( + shape_m, shape_n, shape_k, mat_a, mat_b, mat_d, scales_a, scales_b, stride_a, stride_b, + stride_d, stride_scales_a, stride_scales_b); +} + +void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, + float* scales_a, int ld_a, int stride_a, int stride_scales_a, + __nv_bfloat16 const* mat_b, __nv_fp8_e4m3* fp8_mat_b, + float* scales_b, int ld_b, int stride_b, __nv_bfloat16* mat_d, + int ld_d, int stride_d, uint32_t num_problems, uint32_t shape_m, + uint32_t shape_n, uint32_t shape_k, cudaStream_t stream, + bool internal_quantize_a = true, bool internal_quantize_b = true) { + if (shape_m == 0) { + return; + } + + if (kNumDeviceSMs < 0) { + kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount(); + } + if (internal_quantize_a) { + scale_1x128_kernel<<>>(fp8_mat_a, scales_a, mat_a, shape_k, + shape_m * num_problems); + } + if (internal_quantize_b) { + scale_128x128_kernel<<>>(fp8_mat_b, scales_b, mat_b, shape_k, + shape_n * num_problems); + } + + int arch = tensorrt_llm::common::getSMVersion(); + if (arch == 89) { + strided_batch_gemm_dispatch_sm89(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, + ld_d, stride_d, scales_a, stride_scales_a, scales_b, + num_problems, shape_m, shape_n, shape_k, stream); + return; + } + if (kDeepGemmEnabled) { + strided_batch_gemm_dispatch(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, + stride_d, scales_a, scales_b, num_problems, shape_m, shape_n, + shape_k, stream); + } else { + using GemmType = Fp8Gemm<__nv_fp8_e4m3, Layout::RowMajor, __nv_fp8_e4m3, Layout::ColMajor, + __nv_bfloat16, Layout::RowMajor, float, float, float, 128, 64, 128, + ScaleType::PerSubChannel, ScaleType::PerBlock, 1, 128, 128, 128>; + GemmType::run(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d, + scales_a, stride_scales_a, scales_b, static_cast(shape_m), + static_cast(shape_n), static_cast(shape_k), + static_cast(num_problems), stream); + } +} + +} // namespace tensorrt_llm::kernels::fp8_blockscale_gemm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_stub.cu b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_stub.cu deleted file mode 100644 index 39be8ed42b..0000000000 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_stub.cu +++ /dev/null @@ -1,93 +0,0 @@ -#include "fp8_blockscale_gemm.h" - -namespace tensorrt_llm { -namespace kernels { -namespace fp8_blockscale_gemm { - -// Explicit instantiation of the template -template class 
CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>; - -template <> -void CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>::gemm( - void*, const void*, const void*, int, int, int, cudaStream_t, const float*, const float*) { - // stub -} - -template <> -void CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>::gemm( - const __nv_fp8_e4m3*, int, const __nv_fp8_e4m3*, int, __nv_bfloat16*, int, int, int, int, - const float*, const float*, cudaStream_t) { - // stub -} - -template <> -void CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>::moeGemm( - void*, const void*, const void*, const int64_t*, size_t, size_t, size_t, cudaStream_t, - const float*, const float*) {} - -template <> -void CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>::strideBatchGemm( - __nv_bfloat16*, int, int, __nv_fp8_e4m3*, int, int, __nv_fp8_e4m3*, int, int, int, int, int, - int, cudaStream_t, float*, int, float*) {} - -template <> -void CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>::fp8CS1x128( - __nv_fp8_e4m3*, float*, const __nv_bfloat16*, int, int, cudaStream_t) {} - -template <> -void CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>::fp8CS1x128Reshape( - __nv_fp8_e4m3*, float*, const __nv_bfloat16*, int, int, int, int, cudaStream_t) {} - -template <> -void CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>::fp8CS128x128( - __nv_fp8_e4m3*, float*, const __nv_bfloat16*, int, int, cudaStream_t) {} - -template <> -size_t CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, - __nv_bfloat16>::getWorkspaceSizeBase(size_t, size_t, size_t, - size_t) { - return 0; -} - -template <> -size_t CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, - __nv_bfloat16>::getWorkspaceSize(size_t, size_t, size_t, - size_t, size_t) { - return 0; -} - -template <> -size_t CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>::getFP8DataSize( - int, int, bool) { - return 0; -} - -template <> -size_t CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>::getActScaleSize( - int, int) { - return 0; -} - -template <> -size_t CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, - __nv_bfloat16>::getWeightScaleSize(int, int) { - return 0; -} - -template <> -size_t CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, - __nv_bfloat16>::getActWorkspaceSize(int, int) { - return 0; -} - -template <> -size_t CutlassFp8BlockScaleGemmRunner<__nv_bfloat16, __nv_fp8_e4m3, - __nv_bfloat16>::getWeightWorkspaceSize(int, int) { - return 0; -} - -// Add other method stubs if linker demands more - -} // namespace fp8_blockscale_gemm -} // namespace kernels -} // namespace tensorrt_llm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_mma_utils.cuh b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_mma_utils.cuh new file mode 100644 index 0000000000..641ef8f100 --- /dev/null +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_mma_utils.cuh @@ -0,0 +1,594 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include + +namespace tensorrt_llm::kernels::fp8_blockscale_gemm { + +struct SM90_64x16x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, bool scale_d) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 16; + static constexpr int K = 32; + static constexpr int NUM_ACCUM = M * N / 128; +}; + +struct SM90_64x32x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + bool scale_d) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 32; + static constexpr int K = 32; + static constexpr int NUM_ACCUM = M * N / 128; +}; + +struct SM90_64x48x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, bool scale_d) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, 
%4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 48; + static constexpr int K = 32; + static constexpr int NUM_ACCUM = M * N / 128; +}; + +struct SM90_64x56x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, bool scale_d) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}, " + " %28," + " %29," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 56; + static constexpr int K = 32; + static constexpr int NUM_ACCUM = M * N / 128; +}; + +struct SM90_64x64x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, bool scale_d) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " 
%8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}, " + " %32," + " %33," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 64; + static constexpr int K = 32; + static constexpr int NUM_ACCUM = M * N / 128; +}; + +struct SM90_64x96x32_F32E4M3E4M3_SS { + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float& d00, + float& d01, float& d02, float& d03, float& d04, float& d05, + float& d06, float& d07, float& d08, float& d09, float& d10, + float& d11, float& d12, float& d13, float& d14, float& d15, + float& d16, float& d17, float& d18, float& d19, float& d20, + float& d21, float& d22, float& d23, float& d24, float& d25, + float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, + float& d36, float& d37, float& d38, float& d39, float& d40, + float& d41, float& d42, float& d43, float& d44, float& d45, + float& d46, float& d47, bool scale_d) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}, " + " %48," + " %49," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], 
scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 96; + static constexpr int K = 32; + static constexpr int NUM_ACCUM = M * N / 128; +}; + +struct SM90_64x112x32_F32E4M3E4M3_SS { + __device__ static void wgmma( + uint64_t const& desc_a, uint64_t const& desc_b, float& d00, float& d01, float& d02, + float& d03, float& d04, float& d05, float& d06, float& d07, float& d08, float& d09, + float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, float& d16, + float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, + float& d38, float& d39, float& d40, float& d41, float& d42, float& d43, float& d44, + float& d45, float& d46, float& d47, float& d48, float& d49, float& d50, float& d51, + float& d52, float& d53, float& d54, float& d55, bool scale_d) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}, " + " %56," + " %57," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], d[48], d[49], d[50], + d[51], d[52], d[53], d[54], d[55], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 112; + static constexpr int K = 32; + static constexpr int NUM_ACCUM = M * N / 128; +}; + +struct SM90_64x128x32_F32E4M3E4M3_SS { + __device__ static void wgmma( + uint64_t const& desc_a, uint64_t const& desc_b, float& d00, float& d01, float& d02, + float& d03, float& d04, float& d05, float& d06, float& d07, float& d08, float& d09, + float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, float& d16, + float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, float& d36, 
float& d37, + float& d38, float& d39, float& d40, float& d41, float& d42, float& d43, float& d44, + float& d45, float& d46, float& d47, float& d48, float& d49, float& d50, float& d51, + float& d52, float& d53, float& d54, float& d55, float& d56, float& d57, float& d58, + float& d59, float& d60, float& d61, float& d62, float& d63, bool scale_d) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}, " + " %64," + " %65," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], d[48], d[49], d[50], + d[51], d[52], d[53], d[54], d[55], d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 128; + static constexpr int K = 32; + static constexpr int NUM_ACCUM = M * N / 128; +}; + +struct SM90_64x192x32_F32E4M3E4M3_SS { + __device__ static void wgmma( + uint64_t const& desc_a, uint64_t const& desc_b, float& d00, float& d01, float& d02, + float& d03, float& d04, float& d05, float& d06, float& d07, float& d08, float& d09, + float& d10, float& d11, float& d12, float& d13, float& d14, float& d15, float& d16, + float& d17, float& d18, float& d19, float& d20, float& d21, float& d22, float& d23, + float& d24, float& d25, float& d26, float& d27, float& d28, float& d29, float& d30, + float& d31, float& d32, float& d33, float& d34, float& d35, float& d36, float& d37, + float& d38, float& d39, float& d40, float& d41, float& d42, float& d43, float& d44, + float& d45, float& d46, float& d47, float& d48, float& d49, float& d50, float& d51, + float& d52, float& d53, float& d54, float& d55, float& d56, float& d57, float& d58, + float& d59, float& d60, float& d61, float& d62, float& d63, float& d64, float& d65, + float& d66, float& d67, float& d68, float& d69, float& d70, float& d71, float& d72, + float& 
d73, float& d74, float& d75, float& d76, float& d77, float& d78, float& d79, + float& d80, float& d81, float& d82, float& d83, float& d84, float& d85, float& d86, + float& d87, float& d88, float& d89, float& d90, float& d91, float& d92, float& d93, + float& d94, float& d95, bool scale_d) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}, " + " %96," + " %97," + " p , 1, 1;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), "l"(desc_b), "r"(int32_t(scale_d))); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif + } + + __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, + bool scale_d) { + wgmma(desc_a, desc_b, d[0], d[1], d[2], d[3], d[4], d[5], d[6], d[7], d[8], d[9], d[10], d[11], + d[12], d[13], d[14], d[15], d[16], d[17], d[18], d[19], d[20], d[21], d[22], d[23], d[24], + d[25], d[26], d[27], d[28], d[29], d[30], d[31], d[32], d[33], d[34], d[35], d[36], d[37], + d[38], d[39], d[40], d[41], d[42], d[43], d[44], d[45], d[46], d[47], d[48], d[49], d[50], + d[51], d[52], d[53], d[54], d[55], d[56], d[57], d[58], d[59], d[60], d[61], d[62], d[63], + d[64], d[65], d[66], d[67], d[68], d[69], d[70], d[71], d[72], d[73], d[74], d[75], d[76], + d[77], d[78], d[79], d[80], d[81], d[82], d[83], d[84], d[85], d[86], d[87], d[88], d[89], + d[90], d[91], d[92], d[93], d[94], d[95], scale_d); + } + + static constexpr int M = 64; + static constexpr int N = 192; + static constexpr int K = 32; + static constexpr int NUM_ACCUM = M * N / 128; +}; + +__device__ void warpgroup_arrive() { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif +} + +__device__ void warpgroup_commit_batch() { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm 
volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif +} + +__device__ void warpgroup_fence_operand(float& reg) { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + asm volatile("" : "+f"(reg)::"memory"); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif +} + +template +__device__ void warpgroup_wait() { +#ifdef CUTLASS_ARCH_MMA_SM90A_ENABLED + static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("wgmma is only available on SM90a"); +#endif +} + +union GmmaDescriptor { + __host__ __device__ constexpr GmmaDescriptor() noexcept : desc_(0) {} + + __host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept : desc_(desc) {} + + __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const& t) noexcept : desc_(t.desc_) {} + + __host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor&& t) noexcept : desc_(t.desc_) {} + + __host__ __device__ constexpr GmmaDescriptor& operator=(GmmaDescriptor const& t) noexcept { + desc_ = t.desc_; + return *this; + } + + __host__ __device__ constexpr GmmaDescriptor& operator=(GmmaDescriptor&& t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + struct { + uint16_t start_address_ : 14, : 2; + uint16_t leading_byte_offset_ : 14, : 2; + uint16_t stride_byte_offset_ : 14, : 2; + uint8_t : 1, base_offset_ : 3, : 4; + uint8_t : 6, layout_type_ : 2; + } bitfield; + + // Decay to a uint64_t + __host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; } +}; + +template +__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type, + int leading_byte_offset = 0, + int stride_byte_offset = 1024) { + GmmaDescriptor desc; + uint32_t uint_ptr = static_cast(cute::cast_smem_ptr_to_uint(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +struct Fp8MmaSelector { + static constexpr auto select_type() { + if constexpr (std::is_same_v && + std::is_same_v) { + if constexpr (N == 16) { + return SM90_64x16x32_F32E4M3E4M3_SS(); + } + if constexpr (N == 32) { + return SM90_64x32x32_F32E4M3E4M3_SS(); + } + if constexpr (N == 48) { + return SM90_64x48x32_F32E4M3E4M3_SS(); + } + if constexpr (N == 56) { + return SM90_64x56x32_F32E4M3E4M3_SS(); + } + if constexpr (N == 64) { + return SM90_64x64x32_F32E4M3E4M3_SS(); + } + if constexpr (N == 96) { + return SM90_64x96x32_F32E4M3E4M3_SS(); + } + if constexpr (N == 112) { + return SM90_64x112x32_F32E4M3E4M3_SS(); + } + if constexpr (N == 128) { + return SM90_64x128x32_F32E4M3E4M3_SS(); + } + if constexpr (N == 192) { + return SM90_64x192x32_F32E4M3E4M3_SS(); + } + } + } + + using Type = decltype(select_type()); +}; + +} // namespace tensorrt_llm::kernels::fp8_blockscale_gemm diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh new file mode 100644 index 0000000000..aacdcf9a82 --- /dev/null +++ 
b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_tma_utils.cuh @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include + +#include +#include +#include + +namespace tensorrt_llm::kernels::fp8_blockscale_gemm { + +template +inline CUtensorMapDataType get_CUtensorMapDataType() { + if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT16; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT32; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_UINT64; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_INT32; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_INT64; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + } else if constexpr (std::is_same::value) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; + } else { + static_assert(sizeof(T) < 0, "Unknown TMA Format!"); + } +} + +PFN_cuTensorMapEncodeTiled_v12000 get_cuTensorMapEncodeTiled() { + // Get pointer to cuTensorMapEncodeTiled + cudaDriverEntryPointQueryResult driver_status; + void* cuTensorMapEncodeTiled_ptr = nullptr; +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 5) + cudaGetDriverEntryPointByVersion("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, 12000, + cudaEnableDefault, &driver_status); +#else + cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuTensorMapEncodeTiled_ptr, cudaEnableDefault, + &driver_status); +#endif + + if (driver_status != cudaDriverEntryPointSuccess) { + throw std::runtime_error("driver_status != cudaDriverEntryPointSuccess"); + } + + return reinterpret_cast(cuTensorMapEncodeTiled_ptr); +} + +template +CUtensorMap make_2d_tma_copy_desc(data_type* global_address, uint64_t gmem_dim[2], + uint64_t stride_in_bytes, uint32_t smem_dim[2], + CUtensorMapSwizzle swizzle_type, + PFN_cuTensorMapEncodeTiled_v12000 encode_func = nullptr) { + CUtensorMap tensor_map{}; + constexpr uint32_t rank = 2; + uint64_t global_stride[rank - 1] = {stride_in_bytes}; + uint32_t elem_strides[rank] = {1, 1}; + + if (encode_func == nullptr) { + encode_func = get_cuTensorMapEncodeTiled(); + } + + CUresult res = + encode_func(&tensor_map, get_CUtensorMapDataType::type>(), + rank, global_address, gmem_dim, global_stride, smem_dim, elem_strides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle_type, + 
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
+                  CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+
+  if (res != CUDA_SUCCESS) {
+    std::cerr << "cuTensorMapEncodeTiled failed (CUresult " << int(res) << ") for gmem_dim ["
+              << gmem_dim[0] << ", " << gmem_dim[1] << "]" << std::endl;
+  }
+  return tensor_map;
+}
+
+__device__ uint64_t mbarrier_arrive_1_expect_tx_cta(void* smem_ptr, uint32_t tx_count) {
+  uint64_t state;
+  asm("mbarrier.arrive.expect_tx.release.cta.shared::cta.b64 %0, [%1], %2; // 8. "
+      : "=l"(state)
+      : "r"(static_cast<uint32_t>(cute::cast_smem_ptr_to_uint(smem_ptr))), "r"(tx_count)
+      : "memory");
+  return state;
+}
+
+}  // namespace tensorrt_llm::kernels::fp8_blockscale_gemm
diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py
index 364d4182f1..bb45d3b8cb 100644
--- a/flashinfer/fused_moe/core.py
+++ b/flashinfer/fused_moe/core.py
@@ -29,6 +29,7 @@
     TunableRunner,
     TuningConfig,
 )
+from ..jit.cpp_ext import is_cuda_version_at_least
 from ..jit.core import logger
 from ..jit import (
     setup_cubin_loader,
@@ -746,8 +747,6 @@ def cutlass_fused_moe(
     ------
     NotImplementedError: If any of the following features are requested but not
     implemented:
-        - FP8 Block Scaling
-        - W4A8 Group Scaling
         - Minimum Latency Mode
 
     Note
@@ -757,13 +756,22 @@ def cutlass_fused_moe(
     - Currently, some advanced features like FP8 block scaling and minimum latency mode are
       not implemented for Blackwell architecture.
     """
-    if use_deepseek_fp8_block_scale:
-        raise NotImplementedError(
-            "DeepSeek FP8 Block Scaling is not yet implemented in CUTLASS for Blackwell."
-        )
+    major, minor = torch.cuda.get_device_capability()
+    device_arch = f"{major * 10 + minor}"
+
     if min_latency_mode:
         raise NotImplementedError("min latency mode not yet implemented for Blackwell.")
 
+    if use_deepseek_fp8_block_scale:
+        if device_arch != "90":
+            raise NotImplementedError(
+                "DeepSeek FP8 block scaling is only supported on SM90 (Hopper)."
+            )
+        elif not is_cuda_version_at_least("12.8"):
+            raise NotImplementedError(
+                "DeepSeek FP8 block scaling requires CUDA 12.8 or newer."
+ ) + if enable_pdl is None: enable_pdl = device_support_pdl(input.device) @@ -780,9 +788,6 @@ def cutlass_fused_moe( output, output_shape, output_dtype, input.device, "output" ) - major, minor = torch.cuda.get_device_capability() - device_arch = f"{major * 10 + minor}" - return get_cutlass_fused_moe_module(device_arch).cutlass_fused_moe( output, input, diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index 56c7d2e751..f03200c3f8 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -64,6 +64,7 @@ def gen_cutlass_fused_moe_sm90_module(use_fast_build: bool = False) -> JitSpec: "-DCOMPILE_HOPPER_TMA_GROUPED_GEMMS", "-DENABLE_BF16", "-DENABLE_FP8", + "-DENABLE_FP8_BLOCK_SCALE" if is_cuda_version_at_least("12.8") else "", "-DENABLE_FP4" if is_cuda_version_at_least("12.8") else "", "-DUSING_OSS_CUTLASS_MOE_GEMM", ] @@ -127,7 +128,7 @@ def gen_cutlass_fused_moe_module( jit_env.FLASHINFER_CSRC_DIR / "nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu", jit_env.FLASHINFER_CSRC_DIR - / "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm_stub.cu", + / "nv_internal/tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.cu", jit_env.FLASHINFER_CSRC_DIR / "fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu", jit_env.FLASHINFER_CSRC_DIR @@ -148,6 +149,7 @@ def gen_cutlass_fused_moe_module( ], extra_cuda_cflags=nvcc_flags, extra_cflags=["-DFAST_BUILD"] if use_fast_build else [], + extra_ldflags=["-lcuda", "-lnvrtc"], extra_include_paths=[ jit_env.FLASHINFER_CSRC_DIR / "nv_internal", jit_env.FLASHINFER_CSRC_DIR / "nv_internal" / "include", diff --git a/tests/moe/test_trtllm_cutlass_fused_moe.py b/tests/moe/test_trtllm_cutlass_fused_moe.py index ecdb3453da..b9b79a4028 100644 --- a/tests/moe/test_trtllm_cutlass_fused_moe.py +++ b/tests/moe/test_trtllm_cutlass_fused_moe.py @@ -14,6 +14,8 @@ limitations under the License. 
""" +from contextlib import nullcontext + import pytest import torch from torch.nn import functional as F @@ -938,17 +940,13 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor: @pytest.mark.parametrize("num_experts", NUM_EXPERTS) @pytest.mark.parametrize("top_k", TOP_K_VALUES) @pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES) -@pytest.mark.skipif( - torch.cuda.get_device_capability()[0] not in [10, 11, 12], - reason="FP8 block scaling is only supported on SM100, SM110 and SM120", -) def test_moe_fp8_block_scaling( batch_size, hidden_size, num_experts, top_k, intermediate_size ): """ Test MoE with FP8 block scaling (Deepseek style): - - Activation: 128x1 blocks - - Weights: 128x128 blocks + - Activation: BF16 (unquantized) + - Weights: FP8 with 128x128 block scaling - Each block has its own scaling factor Args: @@ -957,7 +955,6 @@ def test_moe_fp8_block_scaling( num_experts: Number of experts top_k: Number of experts to route to per token intermediate_size: Intermediate dimension size - Only support bf16 for hidden_states """ torch.manual_seed(42) otype = torch.bfloat16 @@ -981,11 +978,6 @@ def test_moe_fp8_block_scaling( routing_weights = torch.randn((batch_size, top_k)).cuda() routing_weights = F.softmax(routing_weights, dim=1) - # Run reference implementation (no quantization) - _ref_output = compute_with_experts( - num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights - ) - # Quantize input and weights x_quant, x_scales = per_token_group_quant_fp8(x, group_size=128) @@ -993,13 +985,13 @@ def test_moe_fp8_block_scaling( w2_dequant = torch.empty_like(w2_weight) w31_quant = torch.empty_like(w31_weight).to(torch.float8_e4m3fn) w2_quant = torch.empty_like(w2_weight).to(torch.float8_e4m3fn) - w31_scales = torch.randn( + w31_scales = torch.zeros( num_experts, ceil_div(2 * intermediate_size, 128), ceil_div(hidden_size, 128), dtype=torch.float32, ).cuda() - w2_scales = torch.randn( + w2_scales = torch.zeros( num_experts, ceil_div(hidden_size, 128), ceil_div(intermediate_size, 128), @@ -1013,7 +1005,7 @@ def test_moe_fp8_block_scaling( w31_scales.data[expert_id].copy_(w31_s) w2_quant.data[expert_id].copy_(w2) w2_scales.data[expert_id].copy_(w2_s) - # Dequantize for verificationa + # Dequantize for verification x_dequant = dequantize_block(x_quant, x_scales, x.dtype, x.shape) w31_dequant = dequantize_block( w31_quant, w31_scales, w31_weight.dtype, w31_weight.shape @@ -1021,7 +1013,7 @@ def test_moe_fp8_block_scaling( w2_dequant = dequantize_block(w2_quant, w2_scales, w2_weight.dtype, w2_weight.shape) # Run reference implementation with dequantized tensors - _ref_output = compute_with_experts( + ref_output = compute_with_experts( num_experts, x_dequant, w31_dequant, @@ -1029,16 +1021,16 @@ def test_moe_fp8_block_scaling( selected_experts, routing_weights, ) - quant_scales = [ - w31_scales, # .view(-1), # W31 scales - w2_scales, # .view(-1), # W2 scales - ] - # Call flashinfer implementation with block scaling and expect NotImplementedError - with pytest.raises( - NotImplementedError, - match="DeepSeek FP8 Block Scaling is not yet implemented in CUTLASS for Blackwell", - ): + flash_output = torch.zeros_like(x) + + execption_context = ( + pytest.raises(NotImplementedError) + if torch.cuda.get_device_capability()[0] != 9 + else nullcontext() + ) + + with execption_context: _ = fused_moe.cutlass_fused_moe( x.contiguous(), selected_experts.to(torch.int), @@ -1046,12 +1038,13 @@ def test_moe_fp8_block_scaling( w31_quant.contiguous(), 
w2_quant.contiguous(), otype, - tp_size=1, - tp_rank=0, use_deepseek_fp8_block_scale=True, - quant_scales=quant_scales, + quant_scales=[w31_scales.contiguous(), w2_scales.contiguous()], + output=flash_output, ) + torch.testing.assert_close(flash_output, ref_output, rtol=1e-1, atol=1e-1) + def quant_mxfp4_batches(a, num_experts): quant_a = []