@@ -27,10 +27,10 @@
 #include <string>
 #include <vector>
 
-#include "6kd_blockwise_gemm/sm120_fp8_gemm_1d2d.cuh"
 #include "ada_blockwise_gemm/sm89_fp8_gemm_1d1d.cuh"
 #include "fp8_blockscale_mma_utils.cuh"
 #include "fp8_blockscale_tma_utils.cuh"
+#include "sm120_blockwise_gemm/sm120_fp8_gemm_1d1d.cuh"
 #include "tensorrt_llm/common/config.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
@@ -1653,14 +1653,28 @@ void gemm_dispatch_sm120(void* mat_a, void* mat_b, void* mat_d, float* scales_a, |
     using Params = typename GemmKernel::Params;
     using Arguments = typename GemmKernel::Arguments;
     using ProblemShape = typename GemmKernel::ProblemShape;
+    ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, 1);
 
     auto ptr_A = reinterpret_cast<ElementInput*>(mat_a);
     auto ptr_B = reinterpret_cast<ElementInput*>(mat_b);
     auto ptr_SFA = reinterpret_cast<ElementBlockScale*>(scales_a);
     auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
     auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);
-    Arguments args = {ptr_A, ptr_B, ptr_SFA, ptr_SFB, ptr_D};
-    ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, 1);
+
+    int32_t ld_a = shape_k;
+    int32_t stride_a = shape_m * shape_k;
+    int32_t ld_b = shape_k;
+    int32_t stride_b = shape_n * shape_k;
+    int32_t ld_d = shape_n;
+    int32_t stride_d = shape_m * shape_n;
+
+    typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
+    typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
+    typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
+    typename KT::StrideSFB dSFB = KT::deduce_sfb_layout(problem_shape).stride();
+    typename KT::StrideD dD = make_stride(ld_d, Int<1>{}, stride_d);
+
+    Arguments args = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD};
 
     Params kernel_params = GemmKernel::to_underlying_arguments(problem_shape, args);
     auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;
@@ -1914,6 +1928,65 @@ void strided_batch_gemm_dispatch_sm89(__nv_fp8_e4m3* mat_a, int ld_a, int stride |
         stride_scales_b);
 }
 
+void strided_batch_gemm_dispatch_sm120(__nv_fp8_e4m3* mat_a, int ld_a, int stride_a, __nv_fp8_e4m3* mat_b, int ld_b,
+    int stride_b, __nv_bfloat16* mat_d, int ld_d, int stride_d, float* scales_a, int stride_scales_a, float* scales_b,
+    uint32_t num_problems, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, cudaStream_t stream,
+    int num_device_sms = kNumDeviceSMs)
+{
+    if (num_device_sms < 0)
+    {
+        num_device_sms = kNumDeviceSMs = tensorrt_llm::common::getMultiProcessorCount();
+    }
+    using ElementInput = cute::float_e4m3_t;
+    using ElementOutput = cute::bfloat16_t;
+    using ElementAccum = float;
+    using ElementBlockScale = int32_t;
+    using KT = sm120_blockscaled_gemm::SM120BlockScaledBuilder<32, 128>;
+    using GemmKernel = sm120_blockscaled_gemm::SM120BlockScaledKernel<KT>;
+    using Params = typename GemmKernel::Params;
+    using Arguments = typename GemmKernel::Arguments;
+    using ProblemShape = typename GemmKernel::ProblemShape;
+    ProblemShape problem_shape = make_shape((int) shape_m, (int) shape_n, (int) shape_k, (int) num_problems);
+
+    auto ptr_A = reinterpret_cast<ElementInput*>(mat_a);
+    auto ptr_B = reinterpret_cast<ElementInput*>(mat_b);
+    auto ptr_SFA = reinterpret_cast<ElementBlockScale*>(scales_a);
+    auto ptr_SFB = reinterpret_cast<ElementBlockScale*>(scales_b);
+    auto ptr_D = reinterpret_cast<ElementOutput*>(mat_d);
+
+    typename KT::StrideA dA = make_stride(ld_a, Int<1>{}, stride_a);
+    typename KT::StrideB dB = make_stride(ld_b, Int<1>{}, stride_b);
+    typename KT::StrideSFA dSFA = KT::deduce_sfa_layout(problem_shape).stride();
+    typename KT::StrideSFB dSFB = KT::deduce_sfb_layout(problem_shape).stride();
+    typename KT::StrideD dD = make_stride(ld_d, Int<1>{}, stride_d);
+
+    Arguments args = {ptr_A, dA, ptr_B, dB, ptr_SFA, dSFA, ptr_SFB, dSFB, ptr_D, dD};
+
+    Params kernel_params = GemmKernel::to_underlying_arguments(problem_shape, args);
+    auto kernel_ptr = &cutlass::device_kernel<GemmKernel>;
+
+    cudaFuncSetAttribute(kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmKernel::kSmemSize);
+    auto result = cudaGetLastError();
+    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel cannot launch: %s", cudaGetErrorString(result));
+
+    cudaLaunchConfig_t launch_config;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = 1;
+
+    launch_config.gridDim = GemmKernel::get_grid_shape(kernel_params);
+    launch_config.blockDim = GemmKernel::get_block_shape();
+    launch_config.dynamicSmemBytes = GemmKernel::kSmemSize;
+    launch_config.stream = stream;
+    launch_config.attrs = attrs;
+    launch_config.numAttrs = 1;
+
+    cudaLaunchKernelEx(&launch_config, kernel_ptr, kernel_params);
+
+    result = cudaGetLastError();
+    TLLM_CHECK_WITH_INFO(result == cudaSuccess, "sm120 gemm kernel runtime error: %s", cudaGetErrorString(result));
+}
+
 void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_mat_a, float* scales_a, int ld_a,
     int stride_a, int stride_scales_a, __nv_bfloat16 const* mat_b, __nv_fp8_e4m3* fp8_mat_b, float* scales_b, int ld_b,
     int stride_b, __nv_bfloat16* mat_d, int ld_d, int stride_d, uint32_t num_problems, uint32_t shape_m,
@@ -1941,12 +2014,18 @@ void fp8_stride_batch_gemm_run(__nv_bfloat16 const* mat_a, __nv_fp8_e4m3* fp8_ma |
     }
 
     int arch = tensorrt_llm::common::getSMVersion();
-    if (arch == 89 || arch == 120)
+    if (arch == 89)
     {
         strided_batch_gemm_dispatch_sm89(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
             scales_a, stride_scales_a, scales_b, num_problems, shape_m, shape_n, shape_k, stream);
         return;
     }
+    if (arch == 120)
+    {
+        strided_batch_gemm_dispatch_sm120(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
+            scales_a, stride_scales_a, scales_b, num_problems, shape_m, shape_n, shape_k, stream);
+        return;
+    }
     if (kDeepGemmEnabled)
     {
         strided_batch_gemm_dispatch(fp8_mat_a, ld_a, stride_a, fp8_mat_b, ld_b, stride_b, mat_d, ld_d, stride_d,
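Note on the stride convention used above: both the single-problem and the batched sm120 paths build CuTe stride tuples of the form `(ld, 1, batch_stride)` over an `(M/N, K, L)` shape, so A and B are K-contiguous and D is N-contiguous, with the third mode stepping between problems. The following minimal host-side sketch is not part of the change; it only reproduces the `make_stride(ld_a, Int<1>{}, stride_a)` construction for A with CuTe and prints the flat offset one coordinate maps to. The shape values are made up, and it assumes a CUTLASS/CuTe include path plus nvcc.

```cpp
// Hypothetical illustration only: how a (ld, 1, batch_stride) stride tuple
// addresses a batched, K-contiguous A tensor of shape (M, K, L).
#include <cstdio>
#include <cute/layout.hpp>

int main()
{
    using namespace cute;

    int shape_m = 4, shape_k = 8, num_problems = 2;
    int ld_a = shape_k;               // one row of A spans shape_k elements
    int stride_a = shape_m * shape_k; // one batch of A spans M * K elements

    // Same construction as StrideA in the dispatch functions above.
    auto layout_a = make_layout(
        make_shape(shape_m, shape_k, num_problems), make_stride(ld_a, Int<1>{}, stride_a));

    // Element (m=1, k=3) of batch l=1 lives at 1*ld_a + 3*1 + 1*stride_a = 8 + 3 + 32 = 43.
    std::printf("offset of (1, 3, 1): %d\n", static_cast<int>(layout_a(1, 3, 1)));
    return 0;
}
```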