
Commit e78cd18

Added JIT capabilities to all operators except transform operators. (#1085)
Tested with standalone unit tests; real tests will be enabled in a subsequent commit.
1 parent 7051e00 · commit e78cd18


66 files changed: +4813, -466 lines

CMakeLists.txt

Lines changed: 15 additions & 6 deletions
@@ -79,6 +79,7 @@ option(MATX_EN_COVERAGE OFF "Enable code coverage reporting")
 option(MATX_EN_COMPLEX_OP_NAN_CHECKS "Enable full NaN/Inf handling for complex multiplication and division" OFF)
 option(MATX_EN_CUDA_LINEINFO "Enable line information for CUDA kernels via -lineinfo nvcc flag" OFF)
 option(MATX_EN_EXTENDED_LAMBDA "Enable extended lambda support for device/host lambdas" ON)
+option(MATX_EN_JIT "Enable CUDA JIT compilation support via NVRTC" OFF)
 option(MATX_EN_MATHDX "Enable MathDx support for kernel fusion" OFF)
 option(MATX_EN_UNSAFE_ALIAS_DETECTION "Enable aliased memory detection" OFF)
 option(MATX_DISABLE_EXCEPTIONS "Disable C++ exceptions and log errors instead" OFF)
@@ -316,11 +317,9 @@ if (MATX_EN_CUTENSOR)
   target_link_libraries(matx INTERFACE "-Wl,--disable-new-dtags")
 endif()

-if (MATX_EN_MATHDX)
-  set(MathDx_VERSION 25.06)
-  set(MathDx_NANO 0)
-  include(cmake/FindMathDx.cmake)
-  target_compile_definitions(matx INTERFACE MATX_EN_MATHDX)
+# Enable JIT compilation support
+if (MATX_EN_JIT OR MATX_EN_MATHDX)
+  message(STATUS "Enabling JIT compilation support via NVRTC")
   target_compile_definitions(matx INTERFACE MATX_EN_JIT)

   # Add NVRTC configuration as compiler definitions
@@ -331,14 +330,24 @@ if (MATX_EN_MATHDX)
   target_compile_definitions(matx INTERFACE NVRTC_CUDA_ARCH="${NVRTC_CUDA_ARCH}")
   target_compile_definitions(matx INTERFACE NVRTC_CXX_STANDARD="${CMAKE_CXX_STANDARD}")

+  # Link NVRTC library
+  target_link_libraries(matx INTERFACE CUDA::nvrtc)
+endif()
+
+if (MATX_EN_MATHDX)
+  set(MathDx_VERSION 25.06)
+  set(MathDx_NANO 0)
+  include(cmake/FindMathDx.cmake)
+  target_compile_definitions(matx INTERFACE MATX_EN_MATHDX)
+
   # Link libmathdx if available
   if(TARGET libmathdx::libmathdx)
     target_link_libraries(matx INTERFACE libmathdx::libmathdx)
     message(STATUS "Linked libmathdx to matx target")
   endif()

   # Link mathdx components
-  target_link_libraries(matx INTERFACE mathdx::cufftdx CUDA::nvrtc)
+  target_link_libraries(matx INTERFACE mathdx::cufftdx)
 endif()

 if (MATX_EN_CUDSS)
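
A usage note (not part of the diff): the new option is enabled at configure time, e.g. `cmake -DMATX_EN_JIT=ON ..`, and is surfaced to consuming code as the MATX_EN_JIT interface compile definition added above. A minimal, hypothetical sketch of detecting it downstream — the macro name comes from the diff, the program itself is illustrative:

    // Illustrative only: reports whether the MATX_EN_JIT interface
    // definition from the CMake change above was set for this build.
    #include <cstdio>

    int main() {
    #ifdef MATX_EN_JIT
      std::puts("Built with MatX JIT (NVRTC) support");
    #else
      std::puts("Built without MatX JIT support");
    #endif
      return 0;
    }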

include/matx/core/capabilities.h

Lines changed: 1 addition & 1 deletion
@@ -243,7 +243,7 @@ namespace detail {
 __MATX_INLINE__ __MATX_HOST__ typename capability_attributes<Cap>::type
 get_operator_capability(const OperatorType& op, InType& in) {
   static_assert(std::is_same_v<remove_cvref_t<InType>, typename capability_attributes<Cap>::input_type>, "Input type mismatch");
-  if constexpr (is_matx_op<OperatorType>()) {
+  if constexpr (is_matx_jit_class<OperatorType>) {
     return op.template get_capability<Cap, InType>(in);
   } else {
     // Default capabilities for non-MatX ops
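
Aside: the switch from is_matx_op<OperatorType>() to is_matx_jit_class<OperatorType> narrows capability queries to types that opt into JIT. A simplified sketch of the trait-dispatch pattern, using stand-in names rather than the real MatX definitions:

    // Stand-in for is_matx_jit_class: JIT-capable ops specialize the
    // variable template; everything else falls through to a default.
    template <typename T> inline constexpr bool is_jit_class_v = false;

    struct JitOp {
      template <int Cap> int get_capability(int in) const { return in + Cap; }
    };
    template <> inline constexpr bool is_jit_class_v<JitOp> = true;

    template <int Cap, typename Op>
    int get_capability(const Op &op, int in) {
      if constexpr (is_jit_class_v<Op>) {
        return op.template get_capability<Cap>(in);  // op reports its own value
      } else {
        return 0;  // default for non-JIT types
      }
    }

    static_assert(is_jit_class_v<JitOp> && !is_jit_class_v<int>, "");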

include/matx/core/get_grid_dims.h

Lines changed: 67 additions & 53 deletions
@@ -185,6 +185,10 @@ inline bool get_grid_dims_jit(dim3 &blocks, dim3 &threads, const cuda::std::arra
   blocks.y = 1;
   blocks.z = 1;

+  if (RANK > 1) {
+    MATX_ASSERT_STR_EXP(sizes[sizes.size() - 2] % groups_per_block, 0, matxInvalidParameter, "Second to last dimension must be divisible by groups_per_block");
+  }
+
   // Dynamic logic to pick thread block size.
   // Fill in order x, y, z up to 1024 threads
   if constexpr (RANK == 0) {
@@ -216,70 +220,80 @@ inline bool get_grid_dims_jit(dim3 &blocks, dim3 &threads, const cuda::std::arra

     // If we have multiple groups per block, we need to adjust the block size
     if (threads.y > 1) {
-      blocks.x = static_cast<int>((static_cast<int64_t>(sizes[0]) + static_cast<int64_t>(threads.y) - 1) / static_cast<int64_t>(threads.y));
+      blocks.x = static_cast<int>(static_cast<int64_t>(sizes[0]) / static_cast<int64_t>(threads.y));
     }
     else {
       blocks.x = static_cast<int>(sizes[0]);
     }
-  }
-  // We don't support JIT with rank 3 or higher yet
-  // else if constexpr (RANK == 3) {
-  //   if (!force_size) {
-  //     while (nt < max_cta_size) {
-  //       if (static_cast<index_t>(threads.x) * ept < sizes[2]) {
-  //         threads.x *= 2;
-  //       }
+  }
+  else if constexpr (RANK == 3) {
+    if (!force_size) {
+      while (nt < max_cta_size) {
+        if (static_cast<index_t>(threads.x) * ept < sizes[2]) {
+          threads.x *= 2;
+        }
+
+        nt *= 2;
+      }
+    }

-  //       nt *= 2;
-  //     }
-  //   }
+    // If we have multiple groups per block, we need to adjust the block size
+    if (threads.y > 1) {
+      blocks.x = static_cast<int>(static_cast<int64_t>(sizes[1]) / static_cast<int64_t>(threads.y));
+    }
+    else {
+      blocks.x = static_cast<int>(sizes[1]);
+    }

-  //   // launch as many blocks as necessary
-  //   blocks.x = static_cast<int>(sizes[1]);
-  //   blocks.y = static_cast<int>(sizes[0]);
+    // launch as many blocks as necessary
+    blocks.y = static_cast<int>(sizes[0]);

-  //   if(blocks.x > 65535) {
-  //     blocks.x = 65535;
-  //     stride = true;
-  //   }
-  //   if(blocks.y > 65535) {
-  //     blocks.y = 65535;
-  //     stride = true;
-  //   }
+    if(blocks.x > 65535) {
+      blocks.x = 65535;
+      stride = true;
+    }
+    if(blocks.y > 65535) {
+      blocks.y = 65535;
+      stride = true;
+    }

-  // }
-  // else if constexpr (RANK == 4) {
-  //   if (!force_size) {
-  //     while (nt < max_cta_size) {
-  //       if (static_cast<index_t>(threads.x) * ept < sizes[3]) {
-  //         threads.x *= 2;
-  //       }
+  }
+  else if constexpr (RANK == 4) {
+    if (!force_size) {
+      while (nt < max_cta_size) {
+        if (static_cast<index_t>(threads.x) * ept < sizes[3]) {
+          threads.x *= 2;
+        }

-  //       nt *= 2;
-  //     }
-  //   }
+        nt *= 2;
+      }
+    }

-  //   // launch as many blocks as necessary
-  //   blocks.x = static_cast<int>(sizes[2]);
-  //   blocks.y = static_cast<int>(sizes[1]);
-  //   blocks.z = static_cast<int>(sizes[0]);
+    // If we have multiple groups per block, we need to adjust the block size
+    if (threads.y > 1) {
+      blocks.x = static_cast<int>(static_cast<int64_t>(sizes[2]) / static_cast<int64_t>(threads.y));
+    }
+    else {
+      blocks.x = static_cast<int>(sizes[2]);
+    }
+
+    // launch as many blocks as necessary
+    blocks.y = static_cast<int>(sizes[1]);
+    blocks.z = static_cast<int>(sizes[0]);

-  //   if(blocks.x > 65535) {
-  //     blocks.x = 65535;
-  //     stride = true;
-  //   }
-  //   if(blocks.y > 65535) {
-  //     blocks.y = 65535;
-  //     stride = true;
-  //   }
-  //   if(blocks.z > 65535) {
-  //     blocks.z = 65535;
-  //     stride = true;
-  //   }
-  // }
-  else {
-    MATX_THROW(matxInvalidParameter, "Rank not supported");
-  }
+    if(blocks.x > 65535) {
+      blocks.x = 65535;
+      stride = true;
+    }
+    if(blocks.y > 65535) {
+      blocks.y = 65535;
+      stride = true;
+    }
+    if(blocks.z > 65535) {
+      blocks.z = 65535;
+      stride = true;
+    }
+  }

   MATX_LOG_DEBUG("Blocks {}x{}x{} Threads {}x{}x{} groups_per_block={}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, groups_per_block);
   return stride;
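
Two details of the new logic are worth calling out. First, the added assert makes exact division safe: the old ceiling division (sizes[0] + threads.y - 1) / threads.y becomes sizes[0] / threads.y because the second-to-last dimension must now be a multiple of groups_per_block. Second, any grid dimension above 65535 is clamped and the kernel falls back to grid-stride looping. A small standalone sketch of the rank-2 arithmetic (assumptions: threads.y carries groups_per_block, as in the diff, and the clamp mirrors the rank-3/4 branches):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Example rank-2 shape; sizes[0] must be divisible by groups_per_block,
      // which is what the new MATX_ASSERT_STR_EXP enforces.
      const int64_t sizes[2] = {1 << 20, 256};
      const int64_t groups_per_block = 4;  // maps to threads.y
      bool stride = false;

      int64_t blocks_x = (groups_per_block > 1) ? sizes[0] / groups_per_block
                                                : sizes[0];
      // Clamp oversized grids and switch to grid-stride looping.
      if (blocks_x > 65535) { blocks_x = 65535; stride = true; }

      std::printf("blocks.x=%lld stride=%d\n",
                  static_cast<long long>(blocks_x), static_cast<int>(stride));
      return 0;  // prints: blocks.x=65535 stride=1
    }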

include/matx/core/iterator.h

Lines changed: 13 additions & 13 deletions
@@ -48,8 +48,8 @@ namespace matx {
 template <typename OperatorType, bool ConvertType = true>
 struct RandomOperatorIterator {
   using self_type = RandomOperatorIterator<OperatorType, ConvertType>;
-  using value_type = typename std::conditional_t<ConvertType, detail::convert_matx_type_t<typename OperatorType::value_type>, typename OperatorType::value_type>;
-  // using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
+  using value_type = typename cuda::std::conditional_t<ConvertType, detail::convert_matx_type_t<typename OperatorType::value_type>, typename OperatorType::value_type>;
+  // using stride_type = cuda::std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
   //                     index_t>;
   using stride_type = index_t;
   using pointer = value_type*;
@@ -66,7 +66,7 @@ struct RandomOperatorIterator {
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}

   template<typename T = OperatorType>
-  requires (!std::is_same_v<T, OperatorBaseType>)
+  requires (!cuda::std::is_same_v<T, OperatorBaseType>)
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

   template<typename T = OperatorType>
@@ -193,8 +193,8 @@ __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ index_t operator-(const RandomOper
 template <typename OperatorType, bool ConvertType = true>
 struct RandomOperatorOutputIterator {
   using self_type = RandomOperatorOutputIterator<OperatorType, ConvertType>;
-  using value_type = typename std::conditional_t<ConvertType, detail::convert_matx_type_t<typename OperatorType::value_type>, typename OperatorType::value_type>;
-  // using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
+  using value_type = typename cuda::std::conditional_t<ConvertType, detail::convert_matx_type_t<typename OperatorType::value_type>, typename OperatorType::value_type>;
+  // using stride_type = cuda::std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
   //                     index_t>;
   using stride_type = index_t;
   using pointer = value_type*;
@@ -211,11 +211,11 @@ struct RandomOperatorOutputIterator {
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}

   template<typename T = OperatorType>
-  requires (!std::is_same_v<T, OperatorBaseType>)
+  requires (!cuda::std::is_same_v<T, OperatorBaseType>)
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

   template<typename T = OperatorType>
-  requires (!std::is_same_v<T, OperatorBaseType>)
+  requires (!cuda::std::is_same_v<T, OperatorBaseType>)
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

   [[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*()
@@ -338,10 +338,10 @@ template <typename OperatorType, bool ConvertType = true>
 struct RandomOperatorThrustIterator {
   using self_type = RandomOperatorThrustIterator<OperatorType, ConvertType>;
   using const_strip_type = remove_cvref_t<typename OperatorType::value_type>;
-  using value_type = typename std::conditional_t<ConvertType,
+  using value_type = typename cuda::std::conditional_t<ConvertType,
                        detail::convert_matx_type_t<const_strip_type>,
                        const_strip_type>;
-  // using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
+  // using stride_type = cuda::std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
   //                     index_t>;
   using stride_type = index_t;
   using pointer = cuda::std::remove_const_t<value_type>*;
@@ -359,11 +359,11 @@ struct RandomOperatorThrustIterator {
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}

   template<typename T = OperatorType>
-  requires (!std::is_same_v<T, OperatorBaseType>)
+  requires (!cuda::std::is_same_v<T, OperatorBaseType>)
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}

   template<typename T = OperatorType>
-  requires (!std::is_same_v<T, OperatorBaseType>)
+  requires (!cuda::std::is_same_v<T, OperatorBaseType>)
   __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}

   [[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*() const
@@ -463,7 +463,7 @@ template <typename OperatorType>
 struct BeginOffset {
   using self_type = BeginOffset<OperatorType>;
   using value_type = index_t;
-  // using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
+  // using stride_type = cuda::std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
   //                     index_t>;
   using stride_type = index_t;
   using pointer = value_type*;
@@ -522,7 +522,7 @@ template <typename OperatorType>
 struct EndOffset {
   using self_type = EndOffset<OperatorType>;
   using value_type = index_t;
-  // using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
+  // using stride_type = cuda::std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
   //                     index_t>;
   using stride_type = index_t;
   using pointer = value_type*;
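
The mechanical std:: → cuda::std:: swap matters because these iterators must now compile under NVRTC, where the host standard library headers are generally unavailable; libcu++ provides heterogeneous equivalents that work in host, device, and JIT-compiled code alike. A minimal sketch with illustrative types (not the MatX iterators themselves):

    #include <cuda/std/type_traits>

    // cuda::std traits behave like their std counterparts but are usable
    // in device code and under NVRTC.
    template <bool Convert, typename T>
    using converted_t = cuda::std::conditional_t<Convert, float, T>;

    static_assert(cuda::std::is_same_v<converted_t<true, double>, float>, "");
    static_assert(cuda::std::is_same_v<converted_t<false, double>, double>, "");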

include/matx/core/jit_includes.h

Lines changed: 5 additions & 2 deletions
@@ -34,10 +34,13 @@

 // This file is used for jitify/NVRTC preprocessing. Do NOT include any files in here that can't be
 // parsed on the device, and try to keep this minimal to avoid unnecessary dependencies.
+#include <cuda/barrier>
+#include <cuda/std/__algorithm/min.h>
+#include <cuda/std/__algorithm/max.h>
 #include "matx/core/defines.h"
 #include "matx/core/type_utils_both.h"
 #include "matx/core/vector.h"
-#include "matx/operators/scalar_internal.h"
+//#include "matx/operators/scalar_internal.h"
+#include "matx/operators/scalar_ops.h"
 #include "matx/core/operator_utils.h"
-#include <cuda/barrier>
 #include <cub/block/block_load_to_shared.cuh>

include/matx/core/operator_options.h

Lines changed: 10 additions & 0 deletions
@@ -127,6 +127,16 @@ enum class SVDHostAlgo {
   DC /**< Divide and Conquer method (corresponds to `gesdd`) */
 };

+/**
+ * @brief Padding mode
+ *
+ * Specifies the padding mode to use for the pad operator.
+ */
+enum PadMode {
+  MATX_PAD_MODE_CONSTANT, ///< Constant padding mode. All padding elements will be set to the user-provided pad_value.
+  MATX_PAD_MODE_EDGE      ///< Edge padding mode. All padding elements will be set to the edge values of the original operator.
+};
+

 namespace detail {
 static constexpr int MAX_FFT_RANK = 2;
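
To make the two modes concrete, here is a self-contained 1-D toy of the semantics the enum comments describe (illustrative only; it is not the MatX pad operator, whose signature does not appear in this diff):

    #include <cstdio>

    // Stand-in enum mirroring the one added above.
    enum PadMode { MATX_PAD_MODE_CONSTANT, MATX_PAD_MODE_EDGE };

    // Value of a length-(n + 2*pad) padded view at index i, for 1-D input d.
    float padded_at(const float *d, int n, int i, int pad, PadMode mode,
                    float pad_value) {
      const int src = i - pad;                 // index into the original data
      if (src >= 0 && src < n) return d[src];  // interior: pass through
      if (mode == MATX_PAD_MODE_CONSTANT) return pad_value;
      return src < 0 ? d[0] : d[n - 1];        // edge mode: replicate boundary
    }

    int main() {
      const float d[] = {1.f, 2.f, 3.f};
      for (int i = 0; i < 7; i++)
        std::printf("%g ", padded_at(d, 3, i, 2, MATX_PAD_MODE_EDGE, 0.f));
      std::printf("\n");  // prints: 1 1 1 2 3 3 3
      return 0;
    }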
