Commit 2be1dcf

CUDA JIT Support (#1071)
Implement runtime compilation (RTC/JIT) with C++20 logging and operator improvements. Note that with this commit C++20 is now *required* to use MatX. This commit series introduces major infrastructure changes to support Just-In-Time (JIT) compilation via NVRTC, along with a modern logging system and extensive operator refactoring.

### Major Changes:

**Runtime Compilation (RTC/JIT):**
- Add comprehensive NVRTC support with new jit_cuda.h executor (573 lines)
- Create jit_kernel infrastructure for runtime kernel compilation
- Implement cuda_executor_common.h with shared executor functionality
- Add operator_options.h for JIT compilation control
- Introduce cuFFTDx support with new FFT implementation (389 lines)
- Add FindMathDx.cmake module for cuFFTDx library detection
- Create jit_includes.h to manage headers for JIT compilation
- Enhance cache.h with JIT compilation caching (471 lines, +408 additions)

**C++20 Logging System:**
- Implement zero-overhead logging system in log.h (334 lines)
- Add comprehensive logging documentation (544 lines)
- Create test_logging.cu and test_logging_comprehensive.cu test suites
- Refactor logging calls throughout operators and transforms
- Support compile-time log level filtering for zero runtime overhead

**Type System Refactoring:**
- Split type_utils.h into type_utils.h and type_utils_both.h
- Move 1000+ lines to type_utils_both.h for host/device compatibility
- Add reduce_utils.h (104 lines) for reduction operations
- Fix missing <cuda/std/utility> includes

**Operator Enhancements:**
- Add new apply_idx operator with comprehensive test suite (706 test lines)
- Create scalar_internal.h (283 lines) for internal scalar operations
- Refactor 100+ operators to support JIT compilation
- Update FFT operators with cuFFTDx integration
- Enhance cumsum, sort, sum, and other reduction operators
- Improve operator_utils.h with better JIT support (368 lines)

**Infrastructure Improvements:**
- Refactor cuda.h executor (426 lines restructured)
- Update tensor_impl.h with better JIT support (+302 lines)
- Enhance capabilities.h for better compile-time feature detection
- Improve get_grid_dims.h for kernel launch configuration
- Add cub_device.h (152 lines) for CUB device operations

**Documentation & Examples:**
- Update fusion.rst with JIT compilation examples
- Revise build.rst with new build options
- Update developer_guide for operator development
- Enhance black_scholes.cu example with JIT features

This represents a significant architectural enhancement enabling runtime code generation and compilation, improving performance through JIT optimization, and providing a modern, efficient logging framework.
1 parent ad55c6b commit 2be1dcf
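
As a quick orientation for the new executor path, here is a minimal usage sketch assembled from the fusion.rst and sync.rst changes in this commit. The tensor types and sizes are illustrative, and `CUDAJITExecutor`, `cudaExecutor`, and `jit_supported()` follow the names used in the documentation added here rather than a verified public header:

```cpp
// Illustrative usage only, based on the fusion.rst and sync.rst changes in this
// commit; MATX_EN_MATHDX / MATX_EN_JIT must be enabled at build time for the
// JIT path to be available.
#include "matx.h"

int main() {
  using namespace matx;

  auto A = make_tensor<cuda::std::complex<float>>({1024});
  auto B = make_tensor<cuda::std::complex<float>>({1024});
  auto C = make_tensor<cuda::std::complex<float>>({1024});

  // Non-JIT path: the FFT runs through cuFFT into a temporary, then the
  // multiply runs as a separate element-wise kernel (at least two kernels).
  (A = B * fft(C)).run(cudaExecutor{});

  // JIT path: when supported, the FFT and multiply are fused into a single
  // NVRTC/cuFFTDx-compiled kernel; otherwise the JIT executor falls back to
  // the non-JIT path.
  auto expr = B * fft(C);
  if (jit_supported(expr)) {
    (A = expr).run(CUDAJITExecutor{});
  } else {
    (A = expr).run(cudaExecutor{});
  }

  return 0;
}
```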

196 files changed: +10300 additions, -3906 deletions

CMakeLists.txt

Lines changed: 32 additions & 4 deletions

@@ -79,6 +79,7 @@ option(MATX_EN_COVERAGE OFF "Enable code coverage reporting")
 option(MATX_EN_COMPLEX_OP_NAN_CHECKS "Enable full NaN/Inf handling for complex multiplication and division" OFF)
 option(MATX_EN_CUDA_LINEINFO "Enable line information for CUDA kernels via -lineinfo nvcc flag" OFF)
 option(MATX_EN_EXTENDED_LAMBDA "Enable extended lambda support for device/host lambdas" ON)
+option(MATX_EN_MATHDX "Enable MathDx support for kernel fusion" OFF)
 
 set(MATX_EN_PYBIND11 OFF CACHE BOOL "Enable pybind11 support")
 
@@ -96,9 +97,9 @@ if (MATX_BUILD_DOCS)
   add_subdirectory(docs_input)
 endif()
 
-# MatX requires C++17 to build. Enforce on all libraries pulled in as well
-set(CMAKE_CXX_STANDARD 17)
-set(CUDA_CXX_STANDARD 17)
+# MatX requires C++20 to build. Enforce on all libraries pulled in as well
+set(CMAKE_CXX_STANDARD 20)
+set(CUDA_CXX_STANDARD 20)
 
 if ("${CMAKE_CXX_COMPILER_ID}" MATCHES "GNU")
   execute_process(COMMAND ${CMAKE_CXX_COMPILER} -dumpversion OUTPUT_VARIABLE GCC_VERSION)
@@ -124,6 +125,8 @@ target_include_directories(matx INTERFACE "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOU
                                           "$<INSTALL_INTERFACE:include>")
 target_include_directories(matx INTERFACE "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include/matx/kernels>"
                                           "$<INSTALL_INTERFACE:include/matx/kernels>")
+
+
 target_compile_features(matx INTERFACE cxx_std_17 $<BUILD_INTERFACE:cuda_std_17>)
 
 # 11.2 and above required for async allocation
@@ -141,7 +144,7 @@ target_link_libraries(matx INTERFACE CCCL::CCCL)
 
 # Set flags for compiling tests faster (only for nvcc)
 if (NOT CMAKE_CUDA_COMPILER_ID STREQUAL "Clang")
-    set(MATX_CUDA_FLAGS ${CMAKE_CUDA_FLAGS} --threads 0 -ftemplate-backtrace-limit=0)
+    set(MATX_CUDA_FLAGS ${CMAKE_CUDA_FLAGS} --threads 0 -ftemplate-backtrace-limit=0 --extended-lambda)
 endif()
 
 # Hack because CMake doesn't have short circult evaluation
@@ -304,6 +307,31 @@ if (MATX_EN_CUTENSOR)
   target_link_libraries(matx INTERFACE "-Wl,--disable-new-dtags")
 endif()
 
+if (MATX_EN_MATHDX)
+  set(MathDx_VERSION 25.06)
+  set(MathDx_NANO 0)
+  include(cmake/FindMathDx.cmake)
+  target_compile_definitions(matx INTERFACE MATX_EN_MATHDX)
+  target_compile_definitions(matx INTERFACE MATX_EN_JIT)
+
+  # Add NVRTC configuration as compiler definitions
+  list(GET CMAKE_CUDA_ARCHITECTURES 0 NVRTC_CUDA_ARCH)
+  # Strip -real or -virt postfix if present
+  string(REGEX REPLACE "-real$" "" NVRTC_CUDA_ARCH "${NVRTC_CUDA_ARCH}")
+  string(REGEX REPLACE "-virtual$" "" NVRTC_CUDA_ARCH "${NVRTC_CUDA_ARCH}")
+  target_compile_definitions(matx INTERFACE NVRTC_CUDA_ARCH="${NVRTC_CUDA_ARCH}")
+  target_compile_definitions(matx INTERFACE NVRTC_CXX_STANDARD="${CMAKE_CXX_STANDARD}")
+
+  # Link libmathdx if available
+  if(TARGET libmathdx::libmathdx)
+    target_link_libraries(matx INTERFACE libmathdx::libmathdx)
+    message(STATUS "Linked libmathdx to matx target")
+  endif()
+
+  # Link mathdx components
+  target_link_libraries(matx INTERFACE mathdx::cufftdx CUDA::nvrtc)
+endif()
+
 if (MATX_EN_CUDSS)
   set(cuDSS_VERSION 0.7.0.20)
   include(cmake/FindcuDSS.cmake)
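
The `NVRTC_CUDA_ARCH` and `NVRTC_CXX_STANDARD` definitions exported above are plain strings handed to the consuming code. Below is a hedged, hypothetical sketch of how an NVRTC-based JIT path could consume them; it is not the jit_cuda.h code from this commit, the `scale` kernel source is a placeholder, and the fallback defaults are assumptions:

```cpp
// Standalone NVRTC sketch: build compile options from the CMake-exported
// NVRTC_CUDA_ARCH / NVRTC_CXX_STANDARD strings and compile a toy kernel.
// Requires linking against CUDA::nvrtc as done in this commit.
#include <nvrtc.h>
#include <cstdio>
#include <string>
#include <vector>

#ifndef NVRTC_CUDA_ARCH
#define NVRTC_CUDA_ARCH "90"   // assumption: first entry of CMAKE_CUDA_ARCHITECTURES
#endif
#ifndef NVRTC_CXX_STANDARD
#define NVRTC_CXX_STANDARD "20"
#endif

int main() {
  const char *src =
      "extern \"C\" __global__ void scale(float *x, float s, int n) {\n"
      "  int i = blockIdx.x * blockDim.x + threadIdx.x;\n"
      "  if (i < n) x[i] *= s;\n"
      "}\n";

  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, src, "scale.cu", 0, nullptr, nullptr);

  // Turn the exported strings into NVRTC command-line options.
  const std::string arch    = std::string("--gpu-architecture=compute_") + NVRTC_CUDA_ARCH;
  const std::string std_opt = std::string("--std=c++") + NVRTC_CXX_STANDARD;
  const char *opts[] = {arch.c_str(), std_opt.c_str()};

  if (nvrtcCompileProgram(prog, 2, opts) != NVRTC_SUCCESS) {
    size_t log_size = 0;
    nvrtcGetProgramLogSize(prog, &log_size);
    std::vector<char> log(log_size);
    nvrtcGetProgramLog(prog, log.data());
    std::printf("NVRTC log:\n%s\n", log.data());
  } else {
    size_t ptx_size = 0;
    nvrtcGetPTXSize(prog, &ptx_size);
    std::printf("Compiled PTX size: %zu bytes\n", ptx_size);
  }

  nvrtcDestroyProgram(&prog);
  return 0;
}
```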

README.md

Lines changed: 2 additions & 2 deletions

@@ -50,9 +50,9 @@ are necessary
 ## Requirements
 MatX support is currently limited to **Linux only** due to the time to test Windows. If you'd like to voice your support for native Windows support using Visual Studio, please comment on the issue here: https://github.com/NVIDIA/MatX/issues/153.
 
-**Note**: CUDA 12.0.0 through 12.2.0 have an issue that causes building MatX unit tests to show a compiler error or cause a segfault in the compiler. Please use CUDA 11.8 or CUDA 12.2.1+ with MatX.
+**Note**: CUDA 12.0.0 through 12.2.0 have an issue that causes building MatX unit tests to show a compiler error or cause a segfault in the compiler. Please use CUDA 12.2.1+ with MatX.
 
-MatX is using features in C++17 and the latest CUDA compilers and libraries. For this reason, when running with GPU support, CUDA 11.8 and g++9, nvc++ 24.5, or clang 17 or newer is required. You can download the CUDA Toolkit [here](https://developer.nvidia.com/cuda-downloads).
+MatX is using features in C++20 and the latest CUDA compilers and libraries. For this reason, when running with GPU support, CUDA 12.2.1 and g++9, nvc++ 24.5, or clang 17 or newer is required. You can download the CUDA Toolkit [here](https://developer.nvidia.com/cuda-downloads).
 
 MatX has been tested on and supports Volta, Ampere, Ada, Hopper, and Blackwell GPU architectures. Jetson products are supported with Jetpack 5.0 or above.

cmake/FindMathDx.cmake

Lines changed: 165 additions & 0 deletions

@@ -0,0 +1,165 @@
+#=============================================================================
+# Copyright (c) 2021, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#=============================================================================
+
+#[=======================================================================[.rst:
+FindMathDx
+--------
+
+Find MathDx
+
+Imported targets
+^^^^^^^^^^^^^^^^
+
+This module defines the following :prop_tgt:`IMPORTED` target(s):
+
+``MathDx::MathDx``
+  The MathDx library, if found.
+
+Result variables
+^^^^^^^^^^^^^^^^
+
+This module will set the following variables in your project:
+
+``MathDx_FOUND``
+  True if MathDx is found.
+``MathDx_INCLUDE_DIRS``
+  The include directories needed to use MathDx.
+``MathDx_VERSION_STRING``
+  The version of the MathDx library found. [OPTIONAL]
+
+#]=======================================================================]
+set(MathDx_VERSION_FULL ${MathDx_VERSION}.${MathDx_NANO})
+
+# Prefer using a Config module if it exists for this project
+set(MathDx_NO_CONFIG FALSE)
+if(NOT MathDx_NO_CONFIG)
+  find_package(MathDx CONFIG QUIET HINTS ${MathDx_DIR})
+  if(MathDx_FOUND)
+    find_package_handle_standard_args(MathDx DEFAULT_MSG MathDx_CONFIG)
+    return()
+  endif()
+endif()
+
+find_path(MathDx_INCLUDE_DIR NAMES MathDx.h)
+
+# Search for the MathDx library
+find_library(MathDx_LIBRARY
+  NAMES MathDx mathdx
+  HINTS ${MathDx_DIR}
+  PATH_SUFFIXES lib lib64
+)
+
+include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)
+
+find_package_handle_standard_args(MathDx
+  REQUIRED_VARS MathDx_LIBRARY MathDx_INCLUDE_DIR
+  VERSION_VAR )
+
+if(NOT MathDx_FOUND)
+  set(MathDx_FILENAME libMathDx-linux-x86_64-${MathDx_VERSION}-archive)
+
+  message(STATUS "MathDx not found. Downloading library. By continuing this download you accept to the license terms of MathDx")
+
+  CPMAddPackage(
+    NAME MathDx
+    VERSION ${MathDx_VERSION}
+    URL https://developer.download.nvidia.com/compute/cuFFTDx/redist/cuFFTDx/nvidia-mathdx-${MathDx_VERSION_FULL}.tar.gz
+    DOWNLOAD_ONLY YES
+  )
+endif()
+
+# Download libmathdx based on CUDA version and platform
+# Detect CUDA version (12 or 13)
+if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
+  set(LIBMATHDX_CUDA_VERSION "cuda13")
+  set(LIBMATHDX_CUDA_SUFFIX "cuda13.0")
+else()
+  set(LIBMATHDX_CUDA_VERSION "cuda12")
+  set(LIBMATHDX_CUDA_SUFFIX "cuda12.0")
+endif()
+
+# Detect platform
+if(WIN32)
+  set(LIBMATHDX_PLATFORM "win32-x86_64")
+  set(LIBMATHDX_EXT "zip")
+elseif(CMAKE_SYSTEM_NAME STREQUAL "Linux")
+  if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
+    set(LIBMATHDX_PLATFORM "Linux-aarch64")
+  else()
+    set(LIBMATHDX_PLATFORM "Linux-x86_64")
+  endif()
+  set(LIBMATHDX_EXT "tar.gz")
+else()
+  message(WARNING "Unsupported platform for libmathdx download")
+endif()
+
+# Set libmathdx version
+set(LIBMATHDX_VERSION "0.2.3")
+
+# Download libmathdx if platform is supported
+if(DEFINED LIBMATHDX_PLATFORM)
+  set(LIBMATHDX_URL "https://developer.nvidia.com/downloads/compute/cublasdx/redist/cublasdx/${LIBMATHDX_CUDA_VERSION}/libmathdx-${LIBMATHDX_PLATFORM}-${LIBMATHDX_VERSION}-${LIBMATHDX_CUDA_SUFFIX}.${LIBMATHDX_EXT}")
+
+  message(STATUS "Downloading libmathdx for ${LIBMATHDX_PLATFORM} with ${LIBMATHDX_CUDA_VERSION}")
+  message(STATUS "libmathdx URL: ${LIBMATHDX_URL}")
+
+  CPMAddPackage(
+    NAME libmathdx
+    VERSION ${LIBMATHDX_VERSION}
+    URL ${LIBMATHDX_URL}
+    DOWNLOAD_ONLY YES
+  )
+
+  # Add libmathdx to the search paths
+  set(LIBMATHDX_ROOT "${PROJECT_BINARY_DIR}/_deps/libmathdx-src")
+  list(APPEND CMAKE_PREFIX_PATH "${LIBMATHDX_ROOT}")
+
+  # Find libmathdx library file
+  find_library(LIBMATHDX_LIBRARY
+    NAMES mathdx libmathdx
+    PATHS "${LIBMATHDX_ROOT}/lib"
+    NO_DEFAULT_PATH
+  )
+
+  # Set include directories (in both local and parent scope)
+  set(LIBMATHDX_INCLUDE_DIR "${LIBMATHDX_ROOT}/include")
+  set(LIBMATHDX_INCLUDE_DIR "${LIBMATHDX_INCLUDE_DIR}" PARENT_SCOPE)
+
+  if(LIBMATHDX_LIBRARY AND EXISTS ${LIBMATHDX_INCLUDE_DIR})
+    message(STATUS "Found libmathdx library: ${LIBMATHDX_LIBRARY}")
+    message(STATUS "Found libmathdx include dir: ${LIBMATHDX_INCLUDE_DIR}")
+
+    # Create libmathdx target
+    if(NOT TARGET libmathdx::libmathdx)
+      add_library(libmathdx::libmathdx INTERFACE IMPORTED)
+      set_target_properties(libmathdx::libmathdx PROPERTIES
+        INTERFACE_INCLUDE_DIRECTORIES "${LIBMATHDX_INCLUDE_DIR}"
+        INTERFACE_LINK_LIBRARIES "${LIBMATHDX_LIBRARY}"
+      )
+    endif()
+  else()
+    message(WARNING "Could not find libmathdx library or include directory after download")
+  endif()
+endif()
+
+find_package(mathdx REQUIRED COMPONENTS cufftdx CONFIG
+  PATHS
+    "${PROJECT_BINARY_DIR}/_deps/mathdx-src/nvidia/mathdx/${MathDx_VERSION}/lib/cmake/mathdx/"
+    "${PROJECT_BINARY_DIR}/_deps/libmathdx-src/lib/cmake/libmathdx/"
+    "${PROJECT_BINARY_DIR}/_deps/libmathdx-src"
+    "/opt/nvidia/mathdx/${MathDx_VERSION_FULL}"
+)
+

cmake/versions.json

Lines changed: 2 additions & 2 deletions

@@ -1,10 +1,10 @@
 {
   "packages": {
     "CCCL": {
-      "version": "3.0.0",
+      "version": "3.2.0",
       "git_shallow": false,
       "git_url": "https://github.com/NVIDIA/cccl.git",
-      "git_tag": "e944297"
+      "git_tag": "0320434"
     },
     "nvbench" : {
       "version" : "0.0",

docs_input/api/synchronization/sync.rst

Lines changed: 2 additions & 4 deletions

@@ -3,10 +3,8 @@
 sync
 ====
 
-Wait for any code running on an executor to complete.
-
-.. doxygenfunction:: matx::cudaExecutor::sync()
-.. doxygenfunction:: matx::HostExecutor::sync()
+Wait for any code running on an executor to complete. For CUDA executors this typically synchronizes
+the stream backing the executor, while host executors wait until the calling thread completes.
 
 Examples
 ~~~~~~~~
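
A minimal sketch of the sync semantics described in the updated text above; `cudaExecutor` and `sync()` follow the doxygen references removed in this diff, while the tensor names, sizes, and expression are illustrative:

```cpp
// Queue asynchronous work on a CUDA executor, then block until it completes.
#include "matx.h"

int main() {
  using namespace matx;

  auto A = make_tensor<float>({1024});
  auto B = make_tensor<float>({1024});

  cudaExecutor exec{};        // backed by a CUDA stream (the default stream here)

  (A = B * 2.0f).run(exec);   // launches asynchronously with respect to the host
  exec.sync();                // blocks until the work queued on that stream finishes

  return 0;
}
```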

docs_input/basics/fusion.rst

Lines changed: 61 additions & 1 deletion

@@ -3,6 +3,13 @@
 Operator Fusion
 ###############
 
+MatX supports operator fusion for all element-wise operators, and CUDA JIT kernel fusion for math functions with a
+supporting MathDx function. JIT kernel fusion is considered *experimental* currently and may contain bugs that don't
+occur with JIT disabled.
+
+Element-wise Operator Fusion
+============================
+
 When writing a simple arithmetic expression like the following:
 
 .. code-block:: cpp
@@ -43,4 +50,57 @@ expressions, this opens the possibility to selectively fuse more complex express
 
 The type system can see that we have a multiply where the right-hand side is an FFT transform and the left side is another
 operator. This allows MatX to potentially fuse the output of the FFT with a multiply of B at compile-time. In general, the
-more information it can deduce during compilation and runtime, the better the performance will be.
+more information it can deduce during compilation and runtime, the better the performance will be.
+
+CUDA JIT Kernel Fusion
+======================
+
+.. note::
+
+   CUDA JIT kernel fusion is considered an experimental feature. There may be bugs that don't occur with JIT disabled, and new features are being added over time.
+
+MatX supports CUDA JIT kernel fusion that compiles the entire expression into a single kernel. Currently this is enabled
+for all standard MatX element-wise operators and FFT operations via MathDx. To enable fusion with MathDx,
+the following option must be enabled: ``-DMATX_EN_MATHDX=ON``. Once enabled, the ``CUDAJITExecutor`` can be used to perform JIT compilation
+in supported situations. If the expression cannot be JIT compiled, the JIT executor will fall back to the normal non-JIT path.
+
+While JIT compilation can provide a large performance boost, there are two overheads that occur when using JIT compilation:
+- The first pass to JIT the code takes time. The first time a ``run()`` statement is executed on a new operator, MatX identifies this and performs JIT compilation. Depending on the complexity of the operator, this could be anywhere from milliseconds to seconds to complete. Once finished, MatX will cache the compiled kernel so that subsequent runs of the same operator will not require JIT compilation.
+- A lookup is done to find kernels that have already been compiled. This is a small overhead and may not be noticeable.
+
+As mentioned above, there is no difference in syntax between MatX statements that perform JIT compilation and those that do not. The executor
+is the only change, just as it would be with a host executor. For example, in the following code:
+
+.. code-block:: cpp
+
+   (A = B * fft(C)).run(CUDAExecutor{});
+   (A = B * fft(C)).run(CUDAJITExecutor{});
+
+The first statement will execute the FFT into a temporary buffer, then the multiply will be executed. This results
+in a minimum of 2 kernels (one for MatX and at least one for cuFFT). The second statement will execute the FFT and multiply in a single kernel if
+possible.
+
+Some operators cannot be JIT compiled. For example, if the FFT above is a size not compatible with the cuFFTDx library, or if MathDx is disabled,
+the expression will not be JIT compiled. To determine if an operator can be JIT compiled, use the ``matx::jit_supported(op)`` function:
+
+.. code-block:: cpp
+
+   auto my_op = (fft(b) + c);
+   if (matx::jit_supported(my_op)) {
+     printf("FFT is supported by JIT\n");
+   } else {
+     printf("FFT is not supported by JIT\n");
+   }
+
+Even if the MathDx library supports a particular operation, other operators in the expression may prevent JIT compilation. For
+example:
+
+.. code-block:: cpp
+
+   auto my_op = (fftshift1D(fft(b)));
+
+In this case the MathDx library requires at least 2 elements per thread for the FFT, but the ``fftshift1D`` operator requires
+only 1 element per thread. Therefore, the entire expression cannot be JIT-compiled and will fall back to the non-JIT path. Some of
+these restrictions may be relaxed in newer versions of MatX or the MathDx library.
+
+
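
The compile-then-cache behavior described in the new fusion.rst text can be observed with a simple warm-up pattern. The sketch below is only illustrative: `CUDAJITExecutor` follows the name used in the documentation, the expression and sizes are placeholders, and synchronization is done through the CUDA runtime rather than any MatX-specific call:

```cpp
// Time the first (JIT-compiling) run against a second (cache-hit) run.
#include "matx.h"
#include <cuda_runtime.h>
#include <chrono>
#include <cstdio>

int main() {
  using namespace matx;

  auto A = make_tensor<float>({1 << 20});
  auto B = make_tensor<float>({1 << 20});
  auto C = make_tensor<float>({1 << 20});

  auto op = (A = B * C + 1.0f);

  // First run: MatX detects a new operator, compiles a fused kernel with
  // NVRTC, and caches the result.
  auto t0 = std::chrono::steady_clock::now();
  op.run(CUDAJITExecutor{});
  cudaDeviceSynchronize();
  auto t1 = std::chrono::steady_clock::now();

  // Second run: a cache lookup finds the already-compiled kernel, so only the
  // small lookup overhead remains on top of the kernel launch.
  op.run(CUDAJITExecutor{});
  cudaDeviceSynchronize();
  auto t2 = std::chrono::steady_clock::now();

  std::printf("first run: %.1f ms, cached run: %.1f ms\n",
              std::chrono::duration<double, std::milli>(t1 - t0).count(),
              std::chrono::duration<double, std::milli>(t2 - t1).count());
  return 0;
}
```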
