Skip to content

Commit 9d6546e

Browse files
authored
[CUDA] fp16 intB gemm (microsoft#24854)
### Description

* Add the fpA intB GEMM kernel from the WeightOnlyGroupwiseQuantGemmPlugin of TensorRT-LLM.
* Add prepacking that converts weight/scales/zero_points so that MatMulNBits can use the kernel.

Limitations:

* Only the fp16 kernel is enabled. BF16 support will be added later.
* Zero points are required. Scales-only support might be added later.
* Bias is not enabled, since the previous MatMulNBits kernel does not support bias.

### Motivation and Context

To improve LLM performance. Initial results show 2.2x throughput on prompt processing and 1.25x throughput on token generation, measured with the onnxruntime-genai benchmark_e2e.py on phi-4-mini-instruct on an A100.
1 parent cd9d5fc commit 9d6546e

File tree

80 files changed

+15173
-80
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

80 files changed

+15173
-80
lines changed

cmake/CMakeLists.txt

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,10 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu)
859859
set(ORT_PROVIDER_FLAGS)
860860

861861
if (onnxruntime_USE_CUDA)
862+
include(cuda_configuration)
863+
setup_cuda_compiler()
864+
setup_cuda_architectures()
865+
862866
enable_language(CUDA)
863867
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")
864868

@@ -878,9 +882,6 @@ if (onnxruntime_USE_CUDA)
878882
set(onnxruntime_USE_FLASH_ATTENTION OFF)
879883
endif()
880884

881-
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
882-
message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
883-
endif()
884885
if (WIN32)
885886
message( STATUS "Lean Attention unsupported in Windows")
886887
set(onnxruntime_USE_LEAN_ATTENTION OFF)
@@ -1590,25 +1591,17 @@ if (onnxruntime_USE_CUDA)
15901591
file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME})
15911592
endif()
15921593
find_package(CUDAToolkit REQUIRED)
1593-
if (NOT CMAKE_CUDA_ARCHITECTURES)
1594-
# Note that we generate SASS+PTX code for specified cuda architectures by assigning "xy"
1595-
# To add SASS only, assign "xy-real"
1596-
# To add PTX only, assign "xy-virtual"
1597-
if (CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu")
1598-
# Support for Jetson/Tegra ARM devices
1599-
set(CMAKE_CUDA_ARCHITECTURES "53-real;62-real;72-real;87") # TX1/Nano, TX2, Xavier, Orin
1600-
else()
1601-
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
1602-
# 37, 50 still work in CUDA 11 but are marked deprecated and will be removed in future CUDA version.
1603-
set(CMAKE_CUDA_ARCHITECTURES "37-real;50-real;52-real;60-real;70-real;75-real;80-real;86-real;89")
1604-
elseif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8)
1605-
set(CMAKE_CUDA_ARCHITECTURES "52-real;60-real;70-real;75-real;80-real;86-real;89-real;90")
1606-
else()
1607-
# https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html
1608-
set(CMAKE_CUDA_ARCHITECTURES "all") # Supporting all, including latest Blackwell B series & RTX 50 series
1609-
endif()
1610-
endif()
1594+
1595+
# Turn on the optional low-precision type macros that the detected CUDA toolkit version allows.
# ENABLE_FP8 is defined from CUDA 11.8 onwards, ENABLE_FP4 from CUDA 12.8 onwards.
if(NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.8)
  message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_FP8 flag")
  add_definitions("-DENABLE_FP8")
endif()

if(NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8)
  message(STATUS "CUDA Toolkit version is greater or equal than 12.8, enable -DENABLE_FP4 flag")
  add_definitions("-DENABLE_FP4")
endif()
1604+
16121605
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xfatbin=-compress-all")
16131606
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
16141607
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --Werror default-stream-launch")
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
6+
# the License. You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
11+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
12+
# specific language governing permissions and limitations under the License.
13+
#
14+
15+
# setup_cuda_compiler()
#
# Locates the CUDA compiler and determines its version BEFORE enable_language(CUDA) is called.
# This is a macro on purpose: it sets variables in the caller's scope.
#
# Sets:
#   CMAKE_CUDA_COMPILER_VERSION - "<major>.<minor>.<patch>" parsed from `<compiler> --version`.
#   CUDA_REQUIRED_VERSION       - the minimum CUDA version accepted ("11.4").
#   CMAKE_CUDA_HOST_COMPILER    - restored to the user's value if check_language(CUDA) cleared it.
#
# Fails the configure step (FATAL_ERROR) when no CUDA compiler is found, when its version cannot be
# determined, or when the version is lower than CUDA_REQUIRED_VERSION.
macro(setup_cuda_compiler)
  include(CheckLanguage)

  # check_language(CUDA) clears CMAKE_CUDA_HOST_COMPILER if CMAKE_CUDA_COMPILER is not set, so keep a copy of the
  # requested host compiler, restore it afterwards, and run the check again with it in place.
  if(NOT CMAKE_CUDA_COMPILER AND CMAKE_CUDA_HOST_COMPILER)
    set(CMAKE_CUDA_HOST_COMPILER_BACKUP ${CMAKE_CUDA_HOST_COMPILER})
  endif()
  check_language(CUDA)
  if(CMAKE_CUDA_HOST_COMPILER_BACKUP)
    set(CMAKE_CUDA_HOST_COMPILER ${CMAKE_CUDA_HOST_COMPILER_BACKUP})
    check_language(CUDA)
  endif()

  if(NOT CMAKE_CUDA_COMPILER)
    message(FATAL_ERROR "No CUDA compiler found")
  endif()
  message(STATUS "CUDA compiler: ${CMAKE_CUDA_COMPILER}")

  # Ask the compiler for its version and parse the "V<major>.<minor>.<patch>" token with CMake itself.
  # The same code path is used on every platform: piping through bash/grep/cut would report the exit status
  # of the last pipeline stage only (hiding a failing compiler) and would break on paths containing spaces.
  execute_process(
    COMMAND ${CMAKE_CUDA_COMPILER} --version
    OUTPUT_VARIABLE versionString
    RESULT_VARIABLE versionResult)

  # `versionResult` is referenced by name (not expanded), so a non-numeric error string simply compares unequal.
  if(versionResult EQUAL 0 AND versionString MATCHES "V([0-9]+\\.[0-9]+\\.[0-9]+)")
    set(CMAKE_CUDA_COMPILER_VERSION "${CMAKE_MATCH_1}")
  else()
    message(FATAL_ERROR "Failed to determine CUDA version")
  endif()

  set(CUDA_REQUIRED_VERSION "11.4")
  if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS CUDA_REQUIRED_VERSION)
    message(FATAL_ERROR "CUDA version ${CMAKE_CUDA_COMPILER_VERSION} must be at least ${CUDA_REQUIRED_VERSION}")
  endif()
endmacro()
62+
63+
# setup_cuda_architectures()
#
# Normalizes CMAKE_CUDA_ARCHITECTURES before CUDA is enabled. This is a macro on purpose: results are
# written to the caller's scope.
#
# Reads:
#   CMAKE_CUDA_ARCHITECTURES    - user request (may be unset, empty, `native`, `all`, `all-major` or a list).
#   CMAKE_CUDA_COMPILER_VERSION - selects the default architecture list.
#   CMAKE_CUDA_COMPILER         - used to build the detection helper for `native`.
#
# Sets:
#   CMAKE_CUDA_ARCHITECTURES      - normalized list (`NN-real`, `NNa-real`, plus at most one `-virtual` entry).
#   CMAKE_CUDA_ARCHITECTURES_ORIG - enabled architectures as plain numbers, without `-real` / `a` suffixes.
#
# Side effects: adds -DEXCLUDE_SM_<arch> for every architecture in ARCHITECTURES_WITH_KERNELS that is not enabled.
# Fails the configure step (FATAL_ERROR) on an entry that is not a recognized architecture.
macro(setup_cuda_architectures)
  # cmake-format: off
  # Initialize and normalize CMAKE_CUDA_ARCHITECTURES before enabling CUDA.
  # Special values:
  # (1) `native` is resolved to HIGHEST available architecture. Fallback to `all` if detection failed.
  # (2) `all` / `all-major` / unset is resolved to a default set of architectures we optimized and compiler supports.
  # Numerical architectures:
  # * For `-virtual` architectures, the last one is kept as it is, and the others are ignored.
  # * `-real` suffix is automatically added for other cases.
  # * Always use accelerated (`-a` suffix) target for supported real architectures.
  # cmake-format: on

  # NOTE: the special values are reset with set(... "") rather than unset(). When the value was given on the
  # command line (-DCMAKE_CUDA_ARCHITECTURES=all) it is a cache entry; unset() without CACHE would leave that
  # entry visible, so the keyword itself would reach the parsing loop below and be rejected.
  if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "native")
    # Detect highest available compute capability
    set(OUTPUTFILE ${PROJECT_BINARY_DIR}/detect_cuda_arch)
    set(CUDAFILE ${CMAKE_SOURCE_DIR}/utils/detect_cuda_arch.cu)
    message(VERBOSE "Detecting native CUDA compute capability")
    execute_process(
      COMMAND ${CMAKE_CUDA_COMPILER} -lcuda ${CUDAFILE} -o ${OUTPUTFILE}
      RESULT_VARIABLE CUDA_RETURN_CODE)
    # Only run the helper if it was actually built.
    if("${CUDA_RETURN_CODE}" STREQUAL "0")
      execute_process(
        COMMAND ${OUTPUTFILE}
        RESULT_VARIABLE CUDA_RETURN_CODE
        OUTPUT_VARIABLE CUDA_ARCH_OUTPUT
        OUTPUT_STRIP_TRAILING_WHITESPACE)
    endif()
    # Compared as a quoted string: on failure the result may be an error message instead of a number.
    if(NOT "${CUDA_RETURN_CODE}" STREQUAL "0")
      message(WARNING "Detecting native CUDA compute capability - fail")
      message(WARNING "CUDA compute capability detection failed, compiling for all optimized architectures")
      set(CMAKE_CUDA_ARCHITECTURES "")
    else()
      message(STATUS "Detecting native CUDA compute capability - done")
      set(CMAKE_CUDA_ARCHITECTURES "${CUDA_ARCH_OUTPUT}")
    endif()
  elseif("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "all")
    set(CMAKE_CUDA_ARCHITECTURES "")
    message(STATUS "Setting CMAKE_CUDA_ARCHITECTURES to all enables a list of architectures OnnxRuntime optimized for, "
                   "not all architectures CUDA compiler supports.")
  elseif("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "all-major")
    set(CMAKE_CUDA_ARCHITECTURES "")
    message(
      STATUS "Setting CMAKE_CUDA_ARCHITECTURES to all-major enables a list of architectures OnnxRuntime optimized for, "
             "not all major architectures CUDA compiler supports.")
  else()
    message(STATUS "Original CMAKE_CUDA_ARCHITECTURES : ${CMAKE_CUDA_ARCHITECTURES}")
  endif()

  # Nothing requested (unset, empty, or a special value reset above): pick the default set for this toolkit.
  if("${CMAKE_CUDA_ARCHITECTURES}" STREQUAL "")
    if(CMAKE_LIBRARY_ARCHITECTURE STREQUAL "aarch64-linux-gnu")
      # Support for Jetson/Tegra ARM devices
      set(CMAKE_CUDA_ARCHITECTURES "53;62;72;87") # TX1/Nano, TX2, Xavier, Orin
    else()
      if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
        # 37, 50 still work in CUDA 11 but are marked deprecated and will be removed in future CUDA version.
        set(CMAKE_CUDA_ARCHITECTURES "37;50;52;60;70;75;80;86;89")
      elseif(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8)
        set(CMAKE_CUDA_ARCHITECTURES "52;60;70;75;80;86;89;90")
      else()
        set(CMAKE_CUDA_ARCHITECTURES "60;70;75;80;86;89;90;100;120")
      endif()
    endif()
  endif()

  # Reduce every entry to its bare number; remember the last `-virtual` entry separately.
  unset(CMAKE_CUDA_ARCHITECTURES_CLEAN)
  unset(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL)
  foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
    if(CUDA_ARCH STREQUAL "")
      continue()
    endif()

    if(CUDA_ARCH MATCHES "^([1-9])([0-9])+a?-virtual$")
      set(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL ${CUDA_ARCH})
    elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?-real$")
      list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1})
    elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?$")
      list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1})
    else()
      message(FATAL_ERROR "Unrecognized CUDA architecture: ${CUDA_ARCH}")
    endif()
  endforeach()
  # The list is undefined when only `-virtual` entries were given; skip de-duplication in that case.
  if(CMAKE_CUDA_ARCHITECTURES_CLEAN)
    list(REMOVE_DUPLICATES CMAKE_CUDA_ARCHITECTURES_CLEAN)
  endif()
  set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_CLEAN})

  # CMAKE_CUDA_ARCHITECTURES_ORIG contains all architectures enabled, without automatically added -real or -a suffix.
  set(CMAKE_CUDA_ARCHITECTURES_ORIG "${CMAKE_CUDA_ARCHITECTURES}")
  message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}")

  # Define EXCLUDE_SM_<arch> for each listed architecture that is not being built.
  set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "120")
  foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS)
    if(NOT "${CUDA_ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
      add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}")
      message(STATUS "Excluding SM ${CUDA_ARCH}")
    endif()
  endforeach()

  # Enable accelerated features (like WGMMA, TMA and setmaxnreg) for SM >= 90.
  set(ARCHITECTURES_WITH_ACCEL "90" "100" "101" "120")
  unset(CMAKE_CUDA_ARCHITECTURES_NORMALIZED)
  foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
    if("${CUDA_ARCH}" IN_LIST ARCHITECTURES_WITH_ACCEL)
      list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}a-real")
    else()
      list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}-real")
    endif()
  endforeach()

  # The single retained virtual (PTX) architecture goes last.
  if(DEFINED CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL)
    list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL}")
  endif()

  set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NORMALIZED})

  message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
endmacro()

cmake/onnxruntime_providers_cuda.cmake

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@
179179
set(onnxruntime_NVCC_THREADS "1" CACHE STRING "Number of threads that NVCC can use for compilation.")
180180
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">")
181181
endif()
182-
182+
183183
# Since CUDA 12.8, compiling diagnostics become stricter
184184
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8)
185185
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--relocatable-device-code=true>")
@@ -261,6 +261,11 @@
261261
set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA)
262262
set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime")
263263

264+
# Extra settings applied only when SM 90 is among the enabled architectures
# (CMAKE_CUDA_ARCHITECTURES_ORIG holds the plain architecture numbers).
if("90" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG)
  # NOTE(review): presumably gates the Hopper TMA GEMM sources - confirm against the kernel code.
  target_compile_definitions(${target} PRIVATE COMPILE_HOPPER_TMA_GEMMS)
  # Pass -w to ptxas for CUDA sources only.
  target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xptxas=-w>")
endif()
268+
264269
if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling
265270
target_link_libraries(${target} PRIVATE CUDA::cupti)
266271
endif()

cmake/utils/detect_cuda_arch.cu

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include <algorithm>
2+
#include <cuda_runtime.h>
3+
#include <iomanip>
4+
#include <iostream>
5+
#include <vector>
6+
7+
int main(int argc, char* argv[])
8+
{
9+
int n_devices = 0;
10+
int rc = cudaGetDeviceCount(&n_devices);
11+
if (rc != cudaSuccess)
12+
{
13+
cudaError_t error = cudaGetLastError();
14+
std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl;
15+
return rc;
16+
}
17+
18+
std::vector<std::pair<int, int>> arch(n_devices);
19+
for (int cd = 0; cd < n_devices; ++cd)
20+
{
21+
cudaDeviceProp dev;
22+
int rc = cudaGetDeviceProperties(&dev, cd);
23+
if (rc != cudaSuccess)
24+
{
25+
cudaError_t error = cudaGetLastError();
26+
std::cout << "CUDA error: " << cudaGetErrorString(error) << std::endl;
27+
return rc;
28+
}
29+
else
30+
{
31+
arch[cd] = {dev.major, dev.minor};
32+
}
33+
}
34+
35+
std::pair<int, int> best_cc = *std::max_element(begin(arch), end(arch));
36+
std::cout << best_cc.first << best_cc.second;
37+
38+
return 0;
39+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#pragma once
18+
19+
#include <cuda_runtime_api.h>
20+
#include "core/providers/cuda/shared_inc/cuda_call.h"
21+
22+
namespace onnxruntime::llm::common {
23+
inline int getDevice() {
24+
int deviceID{0};
25+
CUDA_CALL_THROW(cudaGetDevice(&deviceID));
26+
return deviceID;
27+
}
28+
29+
inline int getSMVersion() {
30+
int device{-1};
31+
CUDA_CALL_THROW(cudaGetDevice(&device));
32+
int sm_major = 0;
33+
int sm_minor = 0;
34+
CUDA_CALL_THROW(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
35+
CUDA_CALL_THROW(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
36+
return sm_major * 10 + sm_minor;
37+
}
38+
39+
inline int getMultiProcessorCount() {
40+
int nSM{0};
41+
int deviceID{0};
42+
CUDA_CALL_THROW(cudaGetDevice(&deviceID));
43+
CUDA_CALL_THROW(cudaDeviceGetAttribute(&nSM, cudaDevAttrMultiProcessorCount, deviceID));
44+
return nSM;
45+
}
46+
} // namespace onnxruntime::llm::common
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/shared_library/provider_api.h"
7+
8+
// Logging shortcuts that forward to LOGS_DEFAULT with a fixed severity.
// `msg` is appended with operator<<, so it may itself be a chain of << operands.

// Always-active levels.
#define ORT_LLM_LOG_ERROR(msg) LOGS_DEFAULT(ERROR) << msg
#define ORT_LLM_LOG_WARNING(msg) LOGS_DEFAULT(WARNING) << msg
#define ORT_LLM_LOG_INFO(msg) LOGS_DEFAULT(INFO) << msg

// TRACE and DEBUG both log at VERBOSE severity; when NDEBUG is defined they expand to nothing.
#ifdef NDEBUG
#define ORT_LLM_LOG_DEBUG(msg)
#define ORT_LLM_LOG_TRACE(msg)
#else
#define ORT_LLM_LOG_DEBUG(msg) LOGS_DEFAULT(VERBOSE) << msg
#define ORT_LLM_LOG_TRACE(msg) LOGS_DEFAULT(VERBOSE) << msg
#endif

0 commit comments

Comments
 (0)