diff --git a/include/matx/core/capabilities.h b/include/matx/core/capabilities.h
index 5e5b7334..9961206d 100644
--- a/include/matx/core/capabilities.h
+++ b/include/matx/core/capabilities.h
@@ -70,6 +70,7 @@ namespace detail {
     MAX_EPT_VEC_LOAD, // The maximum EPT for a vector load.
     ELEMENT_WISE,     // Whether the operator is element-wise (safe with aliasing)
     ALIASED_MEMORY,   // Whether the operator's input and output pointers alias
+    GLOBAL_KERNEL,    // Whether the kernel operates entirely at the global (grid) level on each chunk of data. False when at least one operator works at the block level
     // Add more capabilities as needed
   };

@@ -123,7 +124,7 @@ namespace detail {
   struct capability_attributes {
     using type = bool;
     using input_type = VoidCapabilityType;
-    static constexpr bool default_value = true;
+    static constexpr bool default_value = false;
     static constexpr bool or_identity = false;
     static constexpr bool and_identity = true;
   };

@@ -144,7 +145,16 @@ namespace detail {
     static constexpr bool default_value = false;
     static constexpr bool or_identity = false;
     static constexpr bool and_identity = true;
-  };
+  };
+
+  template <>
+  struct capability_attributes<OperatorCapability::GLOBAL_KERNEL> {
+    using type = bool;
+    using input_type = VoidCapabilityType;
+    static constexpr bool default_value = true;
+    static constexpr bool or_identity = false;
+    static constexpr bool and_identity = true;
+  };

   template <>
   struct capability_attributes {
@@ -250,6 +260,10 @@ namespace detail {
     if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) {
       return detail::type_to_string();
     }
+    else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) {
+      // If this is not a MatX operator (like a constant or a lambda), we assume it supports JIT.
+      return true;
+    }
     else {
       return capability_attributes::default_value;
     }
@@ -274,6 +288,8 @@ namespace detail {
       return CapabilityQueryType::AND_QUERY; // The expression is JIT-able only if every sub-operator supports JIT.
     case OperatorCapability::ASYNC_LOADS_REQUESTED:
       return CapabilityQueryType::OR_QUERY; // If any sub-operator requires asynchronous loads, the expression requires asynchronous loads.
+    case OperatorCapability::GLOBAL_KERNEL:
+      return CapabilityQueryType::AND_QUERY; // All sub-operators must operate at the global level for the expression to use a global kernel.
     case OperatorCapability::ELEMENTS_PER_THREAD:
       return CapabilityQueryType::RANGE_QUERY; // The expression should use the range of elements per thread of its children.
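    // Illustrative sketch of how an AND query combines for GLOBAL_KERNEL: a
    // binary expression is a global kernel only if both children are, e.g.
    //   bool global =
    //     get_operator_capability<OperatorCapability::GLOBAL_KERNEL>(lhs_) &&
    //     get_operator_capability<OperatorCapability::GLOBAL_KERNEL>(rhs_);
    // (lhs_/rhs_ are hypothetical sub-operator members; one block-level
    // sub-operator demotes the whole expression to a block-level kernel)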
    case OperatorCapability::SET_ELEMENTS_PER_THREAD:
diff --git a/include/matx/core/get_grid_dims.h b/include/matx/core/get_grid_dims.h
index 71cfce89..afecb785 100644
--- a/include/matx/core/get_grid_dims.h
+++ b/include/matx/core/get_grid_dims.h
@@ -173,7 +173,7 @@ inline bool get_grid_dims(dim3 &blocks, dim3 &threads, const cuda::std::array
-inline bool get_grid_dims_jit(dim3 &blocks, dim3 &threads, const cuda::std::array &sizes, index_t ept, int groups_per_block,
+inline bool get_grid_dims_block(dim3 &blocks, dim3 &threads, const cuda::std::array &sizes, index_t ept, int groups_per_block,
                                int max_cta_size = 1024, bool force_size = false) {
   bool stride = false;
diff --git a/include/matx/core/half.h b/include/matx/core/half.h
index 3f349c39..623648ae 100644
--- a/include/matx/core/half.h
+++ b/include/matx/core/half.h
@@ -32,6 +32,8 @@

 #pragma once

+#include <cuda/std/bit>
+#include <cuda/std/cstdint>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>

@@ -41,6 +43,83 @@

 namespace matx {

+// Constexpr helper functions for float to half conversion
+namespace detail {
+
+/**
+ * @brief Constexpr conversion from float to FP16 bits
+ *
+ * @param f Input float value
+ * @return uint16_t FP16 bit representation
+ */
+constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ uint16_t float_to_fp16_bits(float f) {
+  // Use bit_cast for constexpr context
+  uint32_t bits = cuda::std::bit_cast<uint32_t>(f);
+
+  uint32_t sign = (bits >> 16) & 0x8000;
+  int32_t exponent = static_cast<int32_t>((bits >> 23) & 0xff) - 127 + 15;
+  uint32_t mantissa = (bits >> 13) & 0x3ff;
+
+  // Handle special cases
+  if (exponent <= 0) {
+    // Subnormal or zero
+    if (exponent < -10) {
+      // Too small, flush to zero
+      return static_cast<uint16_t>(sign);
+    }
+    // Subnormal
+    mantissa = (mantissa | 0x400) >> (1 - exponent);
+    return static_cast<uint16_t>(sign | mantissa);
+  } else if (exponent >= 0x1f) {
+    // Overflow to infinity or NaN. Check the full 23-bit float mantissa so
+    // NaNs whose payload sits entirely in the low bits are not mistaken for
+    // infinity, and always set the quiet bit so the result stays a NaN.
+    if (exponent == 0x1f + (127 - 15) && (bits & 0x7fffff) != 0) {
+      // NaN
+      return static_cast<uint16_t>(sign | 0x7e00 | mantissa);
+    }
+    // Infinity
+    return static_cast<uint16_t>(sign | 0x7c00);
+  }
+
+  return static_cast<uint16_t>(sign | (static_cast<uint32_t>(exponent) << 10) | mantissa);
+}
+
+/**
+ * @brief Constexpr conversion from float to BF16 bits
+ *
+ * @param f Input float value
+ * @return uint16_t BF16 bit representation (top 16 bits of float)
+ */
+constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ uint16_t float_to_bf16_bits(float f) {
+  // BF16 is just the top 16 bits of a float32,
+  // with rounding to nearest even
+  uint32_t bits = cuda::std::bit_cast<uint32_t>(f);
+
+  // Round to nearest even
+  uint32_t rounding_bias = 0x00007FFF + ((bits >> 16) & 1);
+  bits += rounding_bias;
+  uint16_t result = static_cast<uint16_t>(bits >> 16);
+
+  return result;
+}
+
+/**
+ * @brief Helper to convert float to half type at compile time
+ *
+ * @tparam T The target half type (__half or __nv_bfloat16)
+ * @param f Input float value
+ * @return T Half-precision value
+ */
+template <typename T>
+constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ T float_to_half_constexpr(float f) {
+  if constexpr (cuda::std::is_same_v<T, __half>) {
+    return cuda::std::bit_cast<__half>(float_to_fp16_bits(f));
+  } else {
+    return cuda::std::bit_cast<__nv_bfloat16>(float_to_bf16_bits(f));
+  }
+}
+
+} // namespace detail
+
 /**
  * Template class for half precision numbers (__half and __nv_bfloat16).
CUDA * does not have standardized classes/operators available on both host and @@ -64,12 +143,49 @@ template struct alignas(sizeof(T)) matxHalf { __MATX_INLINE__ matxHalf(const matxHalf &x_) noexcept = default; /** - * @brief Copy constructor from arbitrary type + * @brief Constexpr constructor from float + * + * @param f Float value to convert + */ + constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(float f) noexcept + : x(detail::float_to_half_constexpr(f)) + { + } + + /** + * @brief Constexpr constructor from double + * + * @param d Double value to convert + */ + constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(double d) noexcept + : x(detail::float_to_half_constexpr(static_cast(d))) + { + } + + /** + * @brief Constructor from integral types (constexpr) + * + * @tparam T2 Integral type to copy from + * @param x_ Value to copy + */ + template , int> = 0> + constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(T2 x_) noexcept + : x(detail::float_to_half_constexpr(static_cast(x_))) + { + } + + /** + * @brief Copy constructor from arbitrary type (non-constexpr for non-arithmetic types) * * @tparam T2 Type to copy from * @param x_ Value to copy */ - template + template , float> && + !cuda::std::is_same_v, double> && + !cuda::std::is_integral_v, int> = 0> __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(const T2 &x_) noexcept : x(static_cast(x_)) { @@ -1316,3 +1432,38 @@ using matxFp16 = matxHalf<__half>; ///< Alias for fp16 using matxBf16 = matxHalf<__nv_bfloat16>; ///< Alias for bf16 }; // namespace matx + +#ifndef __CUDACC_RTC__ +// Add std::formatter specializations for matxFp16 and matxBf16 +#include + +namespace std { + +/** + * @brief std::formatter specialization for matxFp16 + * + * Enables matxFp16 to work with std::format by converting to float + */ +template <> +struct formatter : formatter { + template + auto format(const matx::matxFp16& val, FormatContext& ctx) const { + return formatter::format(static_cast(val), ctx); + } +}; + +/** + * @brief std::formatter specialization for matxBf16 + * + * Enables matxBf16 to work with std::format by converting to float + */ +template <> +struct formatter : formatter { + template + auto format(const matx::matxBf16& val, FormatContext& ctx) const { + return formatter::format(static_cast(val), ctx); + } +}; + +} // namespace std +#endif \ No newline at end of file diff --git a/include/matx/core/half_complex.h b/include/matx/core/half_complex.h index 16555a8e..7f0d7157 100644 --- a/include/matx/core/half_complex.h +++ b/include/matx/core/half_complex.h @@ -60,7 +60,7 @@ template struct alignas(sizeof(T) * 2) matxHalfComplex { * * @param x_ Object to copy from */ - __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ + constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(const cuda::std::complex &x_) noexcept : x(x_.real()), y(x_.imag()) { @@ -73,7 +73,7 @@ template struct alignas(sizeof(T) * 2) matxHalfComplex { * @param x_ Value of scalar */ template - __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(const T2 &x_) noexcept + constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(const T2 &x_) noexcept : x(static_cast(x_)), y(0.0f) { } @@ -87,7 +87,7 @@ template struct alignas(sizeof(T) * 2) matxHalfComplex { * @param y_ Imaginary value */ template - __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(const T2 &x_, + constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(const T2 &x_, const T3 &y_) noexcept : x(static_cast(x_)), 
y(static_cast(y_)) { @@ -103,7 +103,7 @@ template struct alignas(sizeof(T) * 2) matxHalfComplex { template requires (cuda::std::is_same_v, matxFp16> || cuda::std::is_same_v, matxBf16>) - __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(T2 &&x_, T2 &&y_) noexcept + constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(T2 &&x_, T2 &&y_) noexcept : x(static_cast(x_)), y(static_cast(y_)) { @@ -1056,3 +1056,62 @@ using matxFp16Complex = matxHalfComplex; ///< Alias for a MatX fp16 co using matxBf16Complex = matxHalfComplex; ///< Alias for a MatXbf16 complex wrapper }; // namespace matx + +#ifndef __CUDACC_RTC__ +// Add std::formatter specializations for matxFp16Complex and matxBf16Complex +#include + +namespace std { + +/** + * @brief std::formatter specialization for matxFp16Complex + * + * Enables matxFp16Complex to work with std::format by converting to complex + */ +template <> +struct formatter { + template + constexpr auto parse(ParseContext& ctx) { + return ctx.begin(); + } + + template + auto format(const matx::matxFp16Complex& val, FormatContext& ctx) const { + float real_val = static_cast(val.real()); + float imag_val = static_cast(val.imag()); + + if (imag_val >= 0) { + return std::format_to(ctx.out(), "({}+{}i)", real_val, imag_val); + } else { + return std::format_to(ctx.out(), "({}{}i)", real_val, imag_val); + } + } +}; + +/** + * @brief std::formatter specialization for matxBf16Complex + * + * Enables matxBf16Complex to work with std::format by converting to complex + */ +template <> +struct formatter { + template + constexpr auto parse(ParseContext& ctx) { + return ctx.begin(); + } + + template + auto format(const matx::matxBf16Complex& val, FormatContext& ctx) const { + float real_val = static_cast(val.real()); + float imag_val = static_cast(val.imag()); + + if (imag_val >= 0) { + return std::format_to(ctx.out(), "({}+{}i)", real_val, imag_val); + } else { + return std::format_to(ctx.out(), "({}{}i)", real_val, imag_val); + } + } +}; + +} // namespace std +#endif // __CUDACC_RTC__ diff --git a/include/matx/core/jit_includes.h b/include/matx/core/jit_includes.h index 08e63063..71dfd273 100644 --- a/include/matx/core/jit_includes.h +++ b/include/matx/core/jit_includes.h @@ -34,9 +34,9 @@ // This file is used for jitify/NVRTC preprocessing. Do NOT include any files in here that can't be // parsed on the device, and try to keep this minimal to avoid unnecessary dependencies. 
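// Usage sketch for the std::formatter specializations added above in
// half_complex.h (host-side only; the values are hypothetical):
//   matx::matxFp16Complex c{1.5f, -2.25f};
//   auto s = std::format("{}", c);  // yields "(1.5-2.25i)"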
-#include #include #include +#include #include "matx/core/defines.h" #include "matx/core/type_utils_both.h" #include "matx/core/vector.h" diff --git a/include/matx/core/log.h b/include/matx/core/log.h index 5bd93841..c3b62b27 100644 --- a/include/matx/core/log.h +++ b/include/matx/core/log.h @@ -89,38 +89,12 @@ namespace std { } }; - // Formatter for matxHalfComplex (fp16/bf16 complex) - template - struct formatter> { - constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } - - template - auto format(const matx::matxHalfComplex& c, FormatContext& ctx) const { - return format_to(ctx.out(), "{}", matx::detail::format_complex(c)); - } - }; - - // Formatter for matxFp16 (half-precision float) - template<> - struct formatter { - constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } - - template - auto format(const matx::matxFp16& val, FormatContext& ctx) const { - return format_to(ctx.out(), "{:g}", static_cast(val)); - } - }; + // Formatter for matxHalfComplex (fp16/bf16 complex) - moved to half_complex.h + // Formatter for matxFp16 (half-precision float) - moved to half.h + // Formatter for matxBf16 (bfloat16) - moved to half.h - // Formatter for matxBf16 (bfloat16) - template<> - struct formatter { - constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); } - - template - auto format(const matx::matxBf16& val, FormatContext& ctx) const { - return format_to(ctx.out(), "{:g}", static_cast(val)); - } - }; + // Note: The formatters for matxHalfComplex, matxFp16, and matxBf16 are now defined + // in their respective header files (half_complex.h and half.h) with proper guards. } namespace matx { diff --git a/include/matx/core/nvrtc_helper.h b/include/matx/core/nvrtc_helper.h index 884eab61..9245c666 100644 --- a/include/matx/core/nvrtc_helper.h +++ b/include/matx/core/nvrtc_helper.h @@ -151,7 +151,7 @@ std::vector __MATX_HOST__ __MATX_INLINE__ get_preprocessor_options( // Read file contents into a string -std::string read_file_contents(const std::string& filepath) { +inline std::string read_file_contents(const std::string& filepath) { std::ifstream file(filepath); if (!file.is_open()) { MATX_LOG_ERROR("Failed to open file: {}", filepath); @@ -163,41 +163,45 @@ std::string read_file_contents(const std::string& filepath) { } // Get the full path to jit_includes.h -std::string get_jit_includes_path() { +inline std::string get_jit_includes_path() { const auto source_path = std::filesystem::path(std::source_location::current().file_name()); const auto matx_root = source_path.parent_path().parent_path().parent_path().parent_path(); return (matx_root / "include" / "matx" / "core" / "jit_includes.h").string(); } template -std::string get_kernel_name([[maybe_unused]] const Op &op, bool stride) { +std::string get_kernel_name([[maybe_unused]] const Op &op, bool stride, bool global_kernel) { if constexpr (Op::Rank() == 0) { return "matx::detail::matxOpT0Kernel"; } else if constexpr (Op::Rank() == 1) { - return "matx::detail::matxOpT1Kernel"; + return global_kernel ? "matx::detail::matxOpT1Kernel" : "matx::detail::matxOpT1KernelBlock"; } else if constexpr (Op::Rank() == 2) { if (stride) { - return "matx::detail::matxOpT2StrideKernel"; + return global_kernel ? "matx::detail::matxOpT2StrideKernel" : "matx::detail::matxOpT2StrideKernelBlock"; } else { - return "matx::detail::matxOpT2Kernel"; + return global_kernel ? 
"matx::detail::matxOpT2Kernel" : "matx::detail::matxOpT2KernelBlock"; } } else if constexpr (Op::Rank() == 3) { if (stride) { - return "matx::detail::matxOpT3StrideKernel"; + return global_kernel ? "matx::detail::matxOpT3StrideKernel" : "matx::detail::matxOpT3StrideKernelBlock"; } else { - return "matx::detail::matxOpT3Kernel"; + return global_kernel ? "matx::detail::matxOpT3Kernel" : "matx::detail::matxOpT3KernelBlock"; } } else if constexpr (Op::Rank() == 4) { if (stride) { - return "matx::detail::matxOpT4StrideKernel"; + return global_kernel ? "matx::detail::matxOpT4StrideKernel" : "matx::detail::matxOpT4StrideKernelBlock"; } else { - return "matx::detail::matxOpT4Kernel"; + return global_kernel ? "matx::detail::matxOpT4Kernel" : "matx::detail::matxOpT4KernelBlock"; } } + else { + // For ranks > 4, use the TD (Tensor Dynamic) kernel + return "matx::detail::matxOpTDKernel"; + } return "MatXInvalidKernel"; } @@ -296,7 +300,7 @@ inline std::string qualify_jit_type_names(const std::string& type_str) { } template -auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, const SizeArray &sa, dim3 &blocks, dim3 &threads, ElementsPerThread ept, bool stride, int dynamic_shmem_size, int osize) { +auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, const SizeArray &sa, dim3 &blocks, dim3 &threads, ElementsPerThread ept, bool stride, int dynamic_shmem_size, int osize, bool global_kernel) { // Pure NVRTC implementation // Cache both module and function to prevent resource leaks // CUmodule must remain loaded for CUfunction to be valid @@ -311,7 +315,7 @@ auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, cons auto capstr = generate_capability_params_string(op, ept, false, osize, threads.x); const auto kernel_op_type = detail::get_operator_capability(op); - std::string kernel_name = get_kernel_name(op, stride); + std::string kernel_name = get_kernel_name(op, stride, global_kernel); std::string cache_key = kernel_name + "_" + kernel_op_type; MATX_LOG_DEBUG("nvrtc_compile_and_run called with operator type: {}", typeid(op).name()); @@ -541,31 +545,55 @@ auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, cons auto storage = op.ToJITStorage(); // Prepare kernel arguments - void* args[Op::Rank() + 1]; - args[0] = &storage; - if constexpr (Op::Rank() >= 1) { - args[1] = const_cast(reinterpret_cast(&sa[0])); - } - if constexpr (Op::Rank() >= 2) { - args[2] = const_cast(reinterpret_cast(&sa[1])); - } - if constexpr (Op::Rank() >= 3) { - args[3] = const_cast(reinterpret_cast(&sa[2])); + if constexpr (Op::Rank() > 4) { + // ND kernel: matxOpTDKernel(Op op, const cuda::std::array sizes, matx::index_t mult) + // mult is the product of all sizes except the first + index_t mult = cuda::std::accumulate(cuda::std::begin(sa) + 1, cuda::std::end(sa), 1, cuda::std::multiplies()); + + void* args[3]; + args[0] = &storage; + args[1] = const_cast(reinterpret_cast(&sa)); + args[2] = &mult; + + MATX_LOG_DEBUG("Launching kernel with grid=({}, {}, {}), block=({}, {}, {}), dynamic_shmem_size={} bytes", + blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, dynamic_shmem_size); + // Launch kernel + CUDA_CHECK(cuLaunchKernel(kernel_func, + blocks.x, blocks.y, blocks.z, + threads.x, threads.y, threads.z, + dynamic_shmem_size, + nullptr, // stream + args, + nullptr)); } - if constexpr (Op::Rank() >= 4) { - args[4] = const_cast(reinterpret_cast(&sa[3])); + else { + // Rank 0-4 kernels: Pass individual size parameters + void* 
args[Op::Rank() + 1]; + args[0] = &storage; + if constexpr (Op::Rank() >= 1) { + args[1] = const_cast(reinterpret_cast(&sa[0])); + } + if constexpr (Op::Rank() >= 2) { + args[2] = const_cast(reinterpret_cast(&sa[1])); + } + if constexpr (Op::Rank() >= 3) { + args[3] = const_cast(reinterpret_cast(&sa[2])); + } + if constexpr (Op::Rank() == 4) { + args[4] = const_cast(reinterpret_cast(&sa[3])); + } + + MATX_LOG_DEBUG("Launching kernel with grid=({}, {}, {}), block=({}, {}, {}), dynamic_shmem_size={} bytes", + blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, dynamic_shmem_size); + // Launch kernel + CUDA_CHECK(cuLaunchKernel(kernel_func, + blocks.x, blocks.y, blocks.z, + threads.x, threads.y, threads.z, + dynamic_shmem_size, + nullptr, // stream + args, + nullptr)); } - - MATX_LOG_DEBUG("Launching kernel with grid=({}, {}, {}), block=({}, {}, {}), dynamic_shmem_size={} bytes", - blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, dynamic_shmem_size); - // Launch kernel - CUDA_CHECK(cuLaunchKernel(kernel_func, - blocks.x, blocks.y, blocks.z, - threads.x, threads.y, threads.z, - dynamic_shmem_size, - nullptr, // stream - args, - nullptr)); } } diff --git a/include/matx/core/tensor_impl.h b/include/matx/core/tensor_impl.h index 15863fc8..c226df4a 100644 --- a/include/matx/core/tensor_impl.h +++ b/include/matx/core/tensor_impl.h @@ -248,15 +248,29 @@ class tensor_impl_t { " return ldata_[offset];\n" + " } else {\n" + " return *reinterpret_cast*>(ldata_ + offset);\n" + - " }\n" + - " }\n" + - " template \n" + - " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* data_ptr(index_t block_idx, index_t ttl_threads) const noexcept\n" + - " {\n" - " //const index_t offset = GetOffsetOptimized(indices...);\n" + - " //return ldata_ + offset;\n" + - " return ldata_ + block_idx * ttl_threads * static_cast(CapType::ept);\n" + - " }\n" + + " }\n" + + " }\n" + + " template \n" + + " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array &idx) const noexcept\n" + + " {\n" + + " return cuda::std::apply([&](auto &&...args) {\n" + + " return this->operator()(args...);\n" + + " }, idx);\n" + + " }\n" + + " template \n" + + " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array &idx) noexcept\n" + + " {\n" + + " return cuda::std::apply([&](auto &&...args) -> T& {\n" + + " return this->operator()(args...);\n" + + " }, idx);\n" + + " }\n" + + " template \n" + + " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ T* data_ptr(index_t block_idx, index_t ttl_threads) const noexcept\n" + + " {\n" + " //const index_t offset = GetOffsetOptimized(indices...);\n" + + " //return ldata_ + offset;\n" + + " return ldata_ + block_idx * ttl_threads * static_cast(CapType::ept);\n" + + " }\n" + " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank()\n" + " {\n" + " return " + std::to_string(Rank()) + ";\n" + @@ -1442,6 +1456,13 @@ MATX_IGNORE_WARNING_POP_GCC return power; } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; +#endif + } else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { #ifdef MATX_EN_JIT // No need to use combine_capabilities here since we're just returning a string. 
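// Illustrative use of the SUPPORTS_JIT capability added above (the tensor and
// its shape are hypothetical):
//   auto t = matx::make_tensor<float>({16, 16});
//   bool can_jit = matx::detail::get_operator_capability<
//       matx::detail::OperatorCapability::SUPPORTS_JIT>(t);
//   // -> true when built with MATX_EN_JIT, false otherwise, steering
//   //    dispatch toward the precompiled cudaExecutor path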
diff --git a/include/matx/core/type_utils_both.h b/include/matx/core/type_utils_both.h index 27cec067..cf4c8f95 100644 --- a/include/matx/core/type_utils_both.h +++ b/include/matx/core/type_utils_both.h @@ -382,6 +382,31 @@ concept is_cuda_executor = requires { template inline constexpr bool is_cuda_executor_v = requires { typename remove_cvref_t::cuda_executor; }; +/** + * @brief Determine if a type is a CUDA executor but NOT a JIT CUDA executor + * + * @tparam T Type to test + */ +template +concept is_cuda_non_jit_executor = requires { typename remove_cvref_t::cuda_executor; } + && !(requires { typename remove_cvref_t::jit_cuda_executor; }); + +// Legacy variable for backwards compatibility +template +inline constexpr bool is_cuda_non_jit_executor_v = requires { typename remove_cvref_t::cuda_executor; } + && !(requires { typename remove_cvref_t::jit_cuda_executor; }); + +/** + * @brief Determine if a type is a CUDA JIT executor + * + * @tparam T Type to test + */ +template +concept is_cuda_jit_executor = requires { typename remove_cvref_t::jit_cuda_executor; }; + +// Legacy variable for backwards compatibility +template +inline constexpr bool is_cuda_jit_executor_v = requires { typename remove_cvref_t::jit_cuda_executor; }; /** * @brief Determine if a type is a complex type (any type supported) @@ -766,7 +791,7 @@ struct complex_type_of template using matx_convert_complex_type = typename cuda::std::conditional_t, identity, - complex_type_of>::type; + complex_type_of>::type; #endif template struct value_type { diff --git a/include/matx/core/utils.h b/include/matx/core/utils.h index a97ddc7f..1e4415c2 100644 --- a/include/matx/core/utils.h +++ b/include/matx/core/utils.h @@ -34,6 +34,7 @@ #include #include +#include #include #include "matx/core/defines.h" @@ -241,12 +242,17 @@ __MATX_INLINE__ std::string array_to_string(const Container& container) { template __MATX_INLINE__ std::string array_to_string(const cuda::std::array& arr) { - std::string s; - for (size_t i = 0; i < N; ++i) { - if (i != 0) s += ", "; - s += std::to_string(arr[i]); + if constexpr (N == 0) { + return std::string(""); + } + else { + std::string s; + for (size_t i = 0; i < N; ++i) { + if (i != 0) s += ", "; + s += std::to_string(arr[i]); + } + return s; } - return s; } @@ -300,16 +306,16 @@ __MATX_INLINE__ __MATX_HOST__ std::string type_to_string() return "__nv_bfloat16"; } else if constexpr (std::is_same_v) { - return "matxFp16"; + return "matx::matxFp16"; } else if constexpr (std::is_same_v) { - return "matxBf16"; + return "matx::matxBf16"; } else if constexpr (std::is_same_v) { - return "matxFp16Complex"; + return "matx::matxFp16Complex"; } else if constexpr (std::is_same_v) { - return "matxBf16Complex"; + return "matx::matxBf16Complex"; } // CCCL complex types else if constexpr (std::is_same_v>) { @@ -355,6 +361,18 @@ __MATX_INLINE__ __MATX_HOST__ std::string type_to_string_c_name() else if constexpr (std::is_same_v) { return "long_long"; } + else if constexpr (std::is_same_v) { + return "matx_matxFp16"; + } + else if constexpr (std::is_same_v) { + return "matx_matxBf16"; + } + else if constexpr (std::is_same_v) { + return "matx_matxFp16Complex"; + } + else if constexpr (std::is_same_v) { + return "matx_matxBf16Complex"; + } else { return type_to_string(); } @@ -372,6 +390,48 @@ auto get_jit_class_or_pod_name(const T& op) } } +/** + * @brief Convert a number to a valid C++ symbol/identifier string + * + * Formats a numeric value as a string that can be used in C++ variable names. 
+ * For complex numbers, the format is "r{real}_i{imag}". + * For non-complex numbers, the format is the string representation of the value. + * Special characters like '.' and '-' are replaced with 'p' (for point) and + * 'n' (for negative) respectively. + * + * @tparam T Numeric type (can be complex or non-complex) + * @param val Numeric value to convert + * @return String representation safe for use in C++ identifiers + * + * @example + * number_to_symbol(cuda::std::complex{1.5, -2.3}) returns "r1p5_in2p3" + * number_to_symbol(3.14f) returns "3p14" + * number_to_symbol(-5) returns "n5" + */ +template +__MATX_INLINE__ std::string number_to_symbol(const T& val) +{ + // Helper lambda to sanitize floating point values for variable names + auto sanitize_float = [](auto v) -> std::string { + std::string str = std::format("{}", v); + // Replace '.' with 'p' (for point), '-' with 'n' (for negative) + for (auto& c : str) { + if (c == '.') c = 'p'; + else if (c == '-') c = 'n'; + } + return str; + }; + + if constexpr (is_complex_v) { + // Format complex numbers as r{real}_i{imag} + auto real_val = val.real(); + auto imag_val = val.imag(); + return std::format("r{}_i{}", sanitize_float(real_val), sanitize_float(imag_val)); + } else { + // Format non-complex numbers directly + return sanitize_float(val); + } +} } diff --git a/include/matx/executors/cuda.h b/include/matx/executors/cuda.h index 5cb15b6b..f0776288 100644 --- a/include/matx/executors/cuda.h +++ b/include/matx/executors/cuda.h @@ -100,115 +100,8 @@ namespace matx } if constexpr (Op::Rank() <= 4) { - // Create kernel provider for non-JIT - auto kernel_provider = [&](detail::ElementsPerThread ept) { - dim3 local_blocks = 1; - dim3 local_threads = 1; - bool stride = detail::get_grid_dims(local_blocks, local_threads, sizes, static_cast(ept), 256); - - // Return appropriate kernel function pointer based on EPT, rank, and stride - switch (ept) { - case detail::ElementsPerThread::THIRTY_TWO: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? (const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::SIXTEEN: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? 
(const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::EIGHT: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? (const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::FOUR: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - }else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? (const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::TWO: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? (const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::ONE: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? 
(const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - default: - return (const void*)nullptr; - } - return (const void*)nullptr; - }; + // Create kernel provider for non-JIT using consolidated function + auto kernel_provider = detail::create_kernel_provider(sizes, false, false); // Find the best launch parameters auto [best_ept, shm_size, block_size, groups_per_block] = detail::find_best_launch_params(op, kernel_provider, 256, false); diff --git a/include/matx/executors/cuda_executor_common.h b/include/matx/executors/cuda_executor_common.h index 5fcbdef7..ef3f3402 100644 --- a/include/matx/executors/cuda_executor_common.h +++ b/include/matx/executors/cuda_executor_common.h @@ -34,6 +34,9 @@ #include #include "matx/core/capabilities.h" #include "matx/core/defines.h" +#include "matx/core/get_grid_dims.h" +#include "matx/executors/kernel.h" +#include "matx/core/log.h" #include namespace matx @@ -137,6 +140,138 @@ namespace detail cudaEvent_t stop_; }; + /** + * @brief Create a kernel provider that returns appropriate kernel function pointers + * + * This function creates a lambda that provides kernel function pointers based on EPT, + * rank, and stride parameters. Used by both JIT and non-JIT CUDA executors. + * + * @tparam Op Operator type + * @param sizes Array of dimension sizes + * @param is_jit Whether this is for JIT compilation + * @param global_kernel Whether this is a global kernel (only for JIT) + * @return Lambda function that returns kernel pointer for a given EPT + */ + template + auto create_kernel_provider(const cuda::std::array& sizes, bool is_jit = false, bool global_kernel = false) { + return [&, is_jit, global_kernel](ElementsPerThread ept) -> const void* { + dim3 local_blocks = 1; + dim3 local_threads = 1; + bool stride; + + if (is_jit && !global_kernel) { + stride = get_grid_dims_block(local_blocks, local_threads, sizes, static_cast(ept), 1, 1024, true); + } else { + stride = get_grid_dims(local_blocks, local_threads, sizes, static_cast(ept), is_jit ? 1024 : 256); + } + +#ifdef __CUDACC__ + // Return appropriate kernel function pointer based on EPT, rank, and stride + switch (ept) { + case ElementsPerThread::THIRTY_TWO: + if constexpr (Op::Rank() == 0) { + return (const void*)matxOpT0Kernel, Op>; + } else if constexpr (Op::Rank() == 1) { + return (const void*)matxOpT1Kernel, Op>; + } else if constexpr (Op::Rank() == 2) { + return stride ? (const void*)matxOpT2StrideKernel, Op> + : (const void*)matxOpT2Kernel, Op>; + } else if constexpr (Op::Rank() == 3) { + return stride ? (const void*)matxOpT3StrideKernel, Op> + : (const void*)matxOpT3Kernel, Op>; + } else if constexpr (Op::Rank() == 4) { + return stride ? (const void*)matxOpT4StrideKernel, Op> + : (const void*)matxOpT4Kernel, Op>; + } + break; + case ElementsPerThread::SIXTEEN: + if constexpr (Op::Rank() == 0) { + return (const void*)matxOpT0Kernel, Op>; + } else if constexpr (Op::Rank() == 1) { + return (const void*)matxOpT1Kernel, Op>; + } else if constexpr (Op::Rank() == 2) { + return stride ? (const void*)matxOpT2StrideKernel, Op> + : (const void*)matxOpT2Kernel, Op>; + } else if constexpr (Op::Rank() == 3) { + return stride ? (const void*)matxOpT3StrideKernel, Op> + : (const void*)matxOpT3Kernel, Op>; + } else if constexpr (Op::Rank() == 4) { + return stride ? 
(const void*)matxOpT4StrideKernel, Op> + : (const void*)matxOpT4Kernel, Op>; + } + break; + case ElementsPerThread::EIGHT: + if constexpr (Op::Rank() == 0) { + return (const void*)matxOpT0Kernel, Op>; + } else if constexpr (Op::Rank() == 1) { + return (const void*)matxOpT1Kernel, Op>; + } else if constexpr (Op::Rank() == 2) { + return stride ? (const void*)matxOpT2StrideKernel, Op> + : (const void*)matxOpT2Kernel, Op>; + } else if constexpr (Op::Rank() == 3) { + return stride ? (const void*)matxOpT3StrideKernel, Op> + : (const void*)matxOpT3Kernel, Op>; + } else if constexpr (Op::Rank() == 4) { + return stride ? (const void*)matxOpT4StrideKernel, Op> + : (const void*)matxOpT4Kernel, Op>; + } + break; + case ElementsPerThread::FOUR: + if constexpr (Op::Rank() == 0) { + return (const void*)matxOpT0Kernel, Op>; + } else if constexpr (Op::Rank() == 1) { + return (const void*)matxOpT1Kernel, Op>; + } else if constexpr (Op::Rank() == 2) { + return stride ? (const void*)matxOpT2StrideKernel, Op> + : (const void*)matxOpT2Kernel, Op>; + } else if constexpr (Op::Rank() == 3) { + return stride ? (const void*)matxOpT3StrideKernel, Op> + : (const void*)matxOpT3Kernel, Op>; + } else if constexpr (Op::Rank() == 4) { + return stride ? (const void*)matxOpT4StrideKernel, Op> + : (const void*)matxOpT4Kernel, Op>; + } + break; + case ElementsPerThread::TWO: + if constexpr (Op::Rank() == 0) { + return (const void*)matxOpT0Kernel, Op>; + } else if constexpr (Op::Rank() == 1) { + return (const void*)matxOpT1Kernel, Op>; + } else if constexpr (Op::Rank() == 2) { + return stride ? (const void*)matxOpT2StrideKernel, Op> + : (const void*)matxOpT2Kernel, Op>; + } else if constexpr (Op::Rank() == 3) { + return stride ? (const void*)matxOpT3StrideKernel, Op> + : (const void*)matxOpT3Kernel, Op>; + } else if constexpr (Op::Rank() == 4) { + return stride ? (const void*)matxOpT4StrideKernel, Op> + : (const void*)matxOpT4Kernel, Op>; + } + break; + case ElementsPerThread::ONE: + if constexpr (Op::Rank() == 0) { + return (const void*)matxOpT0Kernel, Op>; + } else if constexpr (Op::Rank() == 1) { + return (const void*)matxOpT1Kernel, Op>; + } else if constexpr (Op::Rank() == 2) { + return stride ? (const void*)matxOpT2StrideKernel, Op> + : (const void*)matxOpT2Kernel, Op>; + } else if constexpr (Op::Rank() == 3) { + return stride ? (const void*)matxOpT3StrideKernel, Op> + : (const void*)matxOpT3Kernel, Op>; + } else if constexpr (Op::Rank() == 4) { + return stride ? 
(const void*)matxOpT4StrideKernel, Op> + : (const void*)matxOpT4Kernel, Op>; + } + break; + default: + return (const void*)nullptr; + } +#endif + return (const void*)nullptr; + }; + } + /** * Find the best launch parameters by testing EPT values for optimal occupancy * @@ -191,31 +326,70 @@ namespace detail // Determine block size for register calculation if (use_jit) { const auto group_range = detail::get_operator_capability(op); - groups_per_block = group_range[0]; - const int total_batches = static_cast(TotalSize(op) / op.Size(0)); - // If we don't have enough batches then fix this to the smaller amount - groups_per_block = cuda::std::min(groups_per_block, total_batches); - const auto set_groups_per_block_query = detail::SetGroupsPerBlockQueryInput{groups_per_block}; - const auto set_groups_per_block = detail::get_operator_capability(op, set_groups_per_block_query); - // Use the max block size for now - block_size = detail::get_operator_capability(op)[1]; - shm_size = detail::get_operator_capability(op); - } - - // Check register pressure constraint - register_viable = (attr.numRegs * block_size * min_occupancy) <= regs_per_multiprocessor; - - // Check dynamic shared memory constraint - bool shm_viable = (shm_size * 2) < max_dynamic_shm; - - if (shm_viable && register_viable) { - MATX_LOG_DEBUG("Selected EPT {}: jits {}, registers {}, shm_size {}, block_size {}, groups_per_block {}", - static_cast(current_ept), use_jit, attr.numRegs, shm_size, block_size, groups_per_block); - return cuda::std::make_tuple(current_ept, shm_size, block_size, groups_per_block); + int min_groups_per_block = group_range[0]; + int max_groups_per_block = group_range[1]; + if (max_groups_per_block == 32) { + max_groups_per_block = 1024; + } + + int total_batches = 1; + if constexpr (op.Rank() > 0) { + total_batches = static_cast(TotalSize(op) / op.Size(op.Rank() - 1)); + } + + // Iterate through all possible groups_per_block values + for (int current_groups_per_block = max_groups_per_block; current_groups_per_block >= min_groups_per_block; current_groups_per_block /= 2) { + // If we don't have enough batches then skip this groups_per_block + if (current_groups_per_block > total_batches) { + continue; + } + + MATX_LOG_DEBUG("Trying groups_per_block {} with {} batches", current_groups_per_block, total_batches); + + groups_per_block = current_groups_per_block; + const auto set_groups_per_block_query = detail::SetGroupsPerBlockQueryInput{groups_per_block}; + const auto set_groups_per_block = detail::get_operator_capability(op, set_groups_per_block_query); + // Use the max block size for now + block_size = detail::get_operator_capability(op)[1]; + shm_size = detail::get_operator_capability(op); + + // Check register pressure constraint + register_viable = (attr.numRegs * block_size * min_occupancy) <= regs_per_multiprocessor; + + // Check dynamic shared memory constraint + bool shm_viable = (shm_size * 2) < max_dynamic_shm; + + if (shm_viable && register_viable) { + MATX_LOG_DEBUG("Selected EPT {}: jits {}, registers {}, shm_size {}, block_size {}, groups_per_block {}", + static_cast(current_ept), use_jit, attr.numRegs, shm_size, block_size, groups_per_block); + return cuda::std::make_tuple(current_ept, shm_size, block_size, groups_per_block); + } + else { + MATX_LOG_DEBUG("EPT {} with groups_per_block {} failed constraints: shm_viable {} ({} of {}), register_viable {} (regs={}) block size {}", + static_cast(current_ept), groups_per_block, shm_viable, shm_size, max_dynamic_shm, register_viable, attr.numRegs, 
block_size); + } + + // Break if we're at the minimum + if (current_groups_per_block == min_groups_per_block) break; + } } else { - MATX_LOG_DEBUG("EPT {} failed constraints: shm_viable {} ({} of {}), register_viable {} (regs={}) block size {}, groups_per_block {}", - static_cast(current_ept), shm_viable, shm_size, max_dynamic_shm, register_viable, attr.numRegs, block_size, groups_per_block); + // Non-JIT path - check constraints without groups_per_block loop + // Check register pressure constraint + register_viable = (attr.numRegs * block_size * min_occupancy) <= regs_per_multiprocessor; + + // Check dynamic shared memory constraint + bool shm_viable = (shm_size * 2) < max_dynamic_shm; + + if (shm_viable && register_viable) { + MATX_LOG_DEBUG("Selected EPT {}: jits {}, registers {}, shm_size {}, block_size {}, groups_per_block {}", + static_cast(current_ept), use_jit, attr.numRegs, shm_size, block_size, groups_per_block); + return cuda::std::make_tuple(current_ept, shm_size, block_size, groups_per_block); + } + else { + MATX_LOG_DEBUG("EPT {} failed constraints: shm_viable {} ({} of {}), register_viable {} (regs={}) block size {}, groups_per_block {}", + static_cast(current_ept), shm_viable, shm_size, max_dynamic_shm, register_viable, attr.numRegs, block_size, groups_per_block); + } } // Cut EPT in half diff --git a/include/matx/executors/jit_cuda.h b/include/matx/executors/jit_cuda.h index 3f4feb65..65b35af0 100644 --- a/include/matx/executors/jit_cuda.h +++ b/include/matx/executors/jit_cuda.h @@ -43,9 +43,51 @@ #include #include #include +#include +#include namespace matx { + namespace detail { + /** + * @brief Cached launch parameters for JIT kernels + * + * This structure stores all computed launch parameters so we can skip expensive + * computations when we've already compiled a kernel for this operator type. + * + * IMPORTANT: For JIT compilation, tensor sizes are encoded in the operator type + * string (from JIT_TYPE_QUERY). This means different sizes produce different cache + * keys, so grid dimensions ARE safe to cache - they won't vary for the same type. + * + * The cache works in conjunction with the nvrtc_compile_and_run kernel cache: + * 1. First execution: Computes all launch params -> caches everything + * 2. Subsequent executions: Uses cached params -> skips ALL computations + * + * This avoids: + * - find_best_launch_params: EPT selection, device queries, occupancy calculations + * - get_grid_dims/get_grid_dims_block: Grid dimension calculations + * - All CUDA device attribute queries + * - Register pressure and shared memory constraint analysis + * + * The cache is keyed by the operator type string from JIT_TYPE_QUERY which + * includes both the operator structure AND tensor sizes. 
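+ *
+ * Example (illustrative): a rank-2 expression of shape {1024, 512} and the
+ * same expression of shape {2048, 512} produce different JIT_TYPE_QUERY
+ * strings, so each gets its own cache entry with its own grid dimensions.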
+ */ + struct JITLaunchParams { + ElementsPerThread best_ept; // Optimal elements per thread for this operator + int shm_size; // Dynamic shared memory size in bytes + int block_size; // Block dimension size + int groups_per_block; // Groups per block (for block-level kernels) + bool stride; // Whether kernel uses grid-stride loops + dim3 blocks; // Grid dimensions (x, y, z blocks) + dim3 threads; // Block dimensions (x, y, z threads) + int osize; // Output size (last dimension) + bool global_kernel; // Whether this is a global or block-level kernel + }; + + // Global cache for JIT launch parameters, keyed by operator type string from JIT_TYPE_QUERY + static std::unordered_map jit_launch_params_cache; + static std::mutex jit_launch_params_mutex; + } // namespace detail /** * @brief Executes operators on a CUDA-enabled device using JIT compilation @@ -119,125 +161,158 @@ namespace matx if (jit_ept_bounds[0] == detail::ElementsPerThread::INVALID) { MATX_THROW(matxInvalidParameter, "Operator does not support JIT compilation. Use cudaExecutor instead."); } - - // Create kernel provider for JIT - auto kernel_provider = [&](detail::ElementsPerThread ept) { - dim3 local_blocks = 1; - dim3 local_threads = 1; - bool stride = detail::get_grid_dims_jit(local_blocks, local_threads, sizes, static_cast(ept), 1, 1024, true); + + + bool global_kernel = detail::get_operator_capability(op); + if (global_kernel) { + MATX_LOG_DEBUG("Operator operates on a global level"); + } else { + MATX_LOG_DEBUG("Operator operates on a block level"); + } + + if constexpr (Op::Rank() <= 4) { + // Get operator type string for cache lookup + const auto kernel_op_type = detail::get_operator_capability(op); - // Return appropriate kernel function pointer based on EPT, rank, and stride - switch (ept) { - case detail::ElementsPerThread::THIRTY_TWO: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? (const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::SIXTEEN: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? (const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::EIGHT: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? 
(const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? (const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::FOUR: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - }else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? (const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::TWO: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return stride ? (const void*)detail::matxOpT4StrideKernel, Op> - : (const void*)detail::matxOpT4Kernel, Op>; - } - break; - case detail::ElementsPerThread::ONE: - if constexpr (Op::Rank() == 0) { - return (const void*)detail::matxOpT0Kernel, Op>; - } else if constexpr (Op::Rank() == 1) { - return (const void*)detail::matxOpT1Kernel, Op>; - } else if constexpr (Op::Rank() == 2) { - return stride ? (const void*)detail::matxOpT2StrideKernel, Op> - : (const void*)detail::matxOpT2Kernel, Op>; - } else if constexpr (Op::Rank() == 3) { - return stride ? (const void*)detail::matxOpT3StrideKernel, Op> - : (const void*)detail::matxOpT3Kernel, Op>; - } else if constexpr (Op::Rank() == 4) { - return (const void*)detail::matxOpT4Kernel, Op>; - } - break; - default: - return (const void*)nullptr; + // Check if we have cached launch parameters for this operator type + detail::JITLaunchParams cached_params; + bool has_cached_params = false; + { + std::lock_guard lock(detail::jit_launch_params_mutex); + auto it = detail::jit_launch_params_cache.find(kernel_op_type); + if (it != detail::jit_launch_params_cache.end()) { + cached_params = it->second; + has_cached_params = true; + } } - return (const void*)nullptr; - }; - - MATX_LOG_DEBUG("Finding best launch parameters for JIT"); - // Find the best launch parameters - auto [best_ept, shm_size, block_size, groups_per_block] = detail::find_best_launch_params(op, kernel_provider, 0, true); - - bool stride = detail::get_grid_dims_jit(blocks, threads, sizes, static_cast(best_ept), groups_per_block, block_size, true); - MATX_LOG_DEBUG("Shm size {}, Stride {}, estimated EPT {}, blocks {}x{}x{} threads {}x{}x{}", - shm_size, stride, static_cast(best_ept), blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z); - const int osize = op.Rank() == 0 ? 
1 : static_cast(op.Size(op.Rank() - 1)); - detail::nvrtc_compile_and_run("output.cu", op, sizes, blocks, threads, best_ept, stride, shm_size, osize); + + detail::ElementsPerThread best_ept; + int shm_size, block_size, groups_per_block; + bool stride; + + if (has_cached_params) { + // Use cached parameters - skip ALL expensive computations! + MATX_LOG_DEBUG("Using cached launch parameters for operator type: {}", kernel_op_type); + best_ept = cached_params.best_ept; + shm_size = cached_params.shm_size; + block_size = cached_params.block_size; + groups_per_block = cached_params.groups_per_block; + stride = cached_params.stride; + blocks = cached_params.blocks; + threads = cached_params.threads; + + MATX_LOG_DEBUG("Cached EPT {}, Shm size {}, Block size {}, Groups per block {}", + static_cast(best_ept), shm_size, block_size, groups_per_block); + } else { + // No cached parameters - compute them + MATX_LOG_DEBUG("No cached parameters found, computing launch parameters for JIT"); + + // Create kernel provider for JIT using consolidated function + auto kernel_provider = detail::create_kernel_provider(sizes, true, global_kernel); + + // Find the best launch parameters + auto result = detail::find_best_launch_params(op, kernel_provider, 0, true); + best_ept = cuda::std::get<0>(result); + shm_size = cuda::std::get<1>(result); + block_size = cuda::std::get<2>(result); + groups_per_block = cuda::std::get<3>(result); + + MATX_LOG_DEBUG("Best EPT {}, Shm size {}, Block size {}, Groups per block {}", + static_cast(best_ept), shm_size, block_size, groups_per_block); + + if (global_kernel) { + stride = detail::get_grid_dims(blocks, threads, sizes, static_cast(best_ept), 256); + } else { + stride = detail::get_grid_dims_block(blocks, threads, sizes, static_cast(best_ept), groups_per_block, block_size, true); + } + + // Cache ALL parameters for future use (sizes are encoded in type string) + detail::JITLaunchParams params_to_cache; + params_to_cache.best_ept = best_ept; + params_to_cache.shm_size = shm_size; + params_to_cache.block_size = block_size; + params_to_cache.groups_per_block = groups_per_block; + params_to_cache.stride = stride; + params_to_cache.blocks = blocks; + params_to_cache.threads = threads; + params_to_cache.osize = op.Rank() == 0 ? 1 : static_cast(op.Size(op.Rank() - 1)); + params_to_cache.global_kernel = global_kernel; + + { + std::lock_guard lock(detail::jit_launch_params_mutex); + detail::jit_launch_params_cache[kernel_op_type] = params_to_cache; + } + } + + MATX_LOG_DEBUG("Shm size {}, Stride {}, estimated EPT {}, blocks {}x{}x{} threads {}x{}x{}", + shm_size, stride, static_cast(best_ept), blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z); + const int osize = op.Rank() == 0 ? 
1 : static_cast(op.Size(op.Rank() - 1)); + detail::nvrtc_compile_and_run("output.cu", op, sizes, blocks, threads, best_ept, stride, shm_size, osize, global_kernel); + } + else { + // ND kernel support for ranks > 4 (JIT path) + // Get operator type string for cache lookup + const auto kernel_op_type = detail::get_operator_capability(op); + + // Check if we have cached launch parameters for this operator type + detail::JITLaunchParams cached_params; + bool has_cached_params = false; + { + std::lock_guard lock(detail::jit_launch_params_mutex); + auto it = detail::jit_launch_params_cache.find(kernel_op_type); + if (it != detail::jit_launch_params_cache.end()) { + cached_params = it->second; + has_cached_params = true; + } + } + + detail::ElementsPerThread best_ept; + bool stride; + + if (has_cached_params) { + // Use cached parameters - skip ALL computations! + MATX_LOG_DEBUG("Using cached launch parameters for ND kernel: {}", kernel_op_type); + best_ept = cached_params.best_ept; + stride = cached_params.stride; + blocks = cached_params.blocks; + threads = cached_params.threads; + } else { + // No cached parameters - compute them + MATX_LOG_DEBUG("No cached parameters found, computing launch parameters for ND kernel"); + + // Reuse the ept_type and jit_ept_bounds from above + const auto ept_bounds = jit_ept_bounds; + best_ept = ept_bounds[1]; + stride = detail::get_grid_dims(blocks, threads, sizes, static_cast(best_ept), 1024); + + // Cache ALL parameters for future use (sizes are encoded in type string) + detail::JITLaunchParams params_to_cache; + params_to_cache.best_ept = best_ept; + params_to_cache.shm_size = 0; + params_to_cache.block_size = threads.x; + params_to_cache.groups_per_block = 1; + params_to_cache.stride = stride; + params_to_cache.blocks = blocks; + params_to_cache.threads = threads; + params_to_cache.osize = op.Rank() == 0 ? 1 : static_cast(op.Size(op.Rank() - 1)); + params_to_cache.global_kernel = true; + + { + std::lock_guard lock(detail::jit_launch_params_mutex); + detail::jit_launch_params_cache[kernel_op_type] = params_to_cache; + } + } + + MATX_LOG_DEBUG("Using ND kernel for rank > 4 with JIT and EPT {}", static_cast(best_ept)); + index_t dims = cuda::std::accumulate(cuda::std::begin(sizes) + 1, cuda::std::end(sizes), 1, cuda::std::multiplies()); + const int osize = op.Rank() == 0 ? 
1 : static_cast(op.Size(op.Rank() - 1)); + + MATX_LOG_DEBUG("ND kernel: stride {}, blocks {}x{}x{} threads {}x{}x{}, dims {}", + stride, blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, dims); + + // Use ND kernel through JIT compilation + detail::nvrtc_compile_and_run("output.cu", op, sizes, blocks, threads, best_ept, stride, 0, osize, true); + } #else MATX_ASSERT_STR(false, matxInvalidParameter, "Cannot call device executor using host compiler"); diff --git a/include/matx/executors/jit_kernel.h b/include/matx/executors/jit_kernel.h index e0009191..621f0eb5 100644 --- a/include/matx/executors/jit_kernel.h +++ b/include/matx/executors/jit_kernel.h @@ -40,16 +40,16 @@ static const char *matxKernelStr = "\n\ namespace matx {\n\ namespace detail {\n\ template \n\ - __global__ void matxOpT0Kernel(Op op) {\n\ + __global__ void matxOpT0KernelBlock(Op op) {\n\ if constexpr (cuda::std::is_pointer_v) {\n\ - (*op)();\n\ + (*op).template operator()();\n\ } else {\n\ - op();\n\ + op.template operator()();\n\ }\n\ }\n\ \n\ template \n\ - __global__ void matxOpT1Kernel(Op op, matx::index_t size0) {\n\ + __global__ void matxOpT1KernelBlock(Op op, matx::index_t size0) {\n\ matx::index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;\n\ if (idx * static_cast(CurrentCapabilities::ept) < size0) {\n\ if constexpr (cuda::std::is_pointer_v) {\n\ @@ -61,7 +61,7 @@ namespace matx {\n\ }\n\ \n\ template \n\ - __global__ void matxOpT2Kernel(Op op, matx::index_t size0, matx::index_t size1) {\n\ + __global__ void matxOpT2KernelBlock(Op op, matx::index_t size0, matx::index_t size1) {\n\ matx::index_t idx = threadIdx.x;\n\ matx::index_t idy = static_cast(blockIdx.x)*blockDim.y + threadIdx.y;\n\ if (idx * static_cast(CurrentCapabilities::ept) < size1 && idy < size0) {\n\ @@ -74,7 +74,7 @@ namespace matx {\n\ }\n\ \n\ template \n\ - __global__ void matxOpT2StrideKernel(Op op, matx::index_t size0, matx::index_t size1) {\n\ + __global__ void matxOpT2StrideKernelBlock(Op op, matx::index_t size0, matx::index_t size1) {\n\ matx::index_t idx = threadIdx.x;\n\ for(matx::index_t idy = static_cast(blockIdx.x);\n\ idy < size0;\n\ @@ -88,7 +88,7 @@ namespace matx {\n\ }\n\ \n\ template \n\ - __global__ void matxOpT3Kernel(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\ + __global__ void matxOpT3KernelBlock(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\ matx::index_t idx = threadIdx.x;\n\ matx::index_t idy = static_cast(blockIdx.x) * blockDim.y + threadIdx.y;\n\ matx::index_t idz = static_cast(blockIdx.y) * blockDim.z + threadIdx.z;\n\ @@ -102,7 +102,7 @@ namespace matx {\n\ }\n\ \n\ template \n\ - __global__ void matxOpT3StrideKernel(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\ + __global__ void matxOpT3StrideKernelBlock(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\ matx::index_t idx = threadIdx.x;\n\ for(matx::index_t idz = static_cast(blockIdx.y) * blockDim.z + threadIdx.z;\n\ idz < size0;\n\ @@ -122,12 +122,11 @@ namespace matx {\n\ }\n\ }\n\ template \n\ - __global__ void matxOpT4Kernel(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\ + __global__ void matxOpT4KernelBlock(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\ matx::index_t idx = threadIdx.x;\n\ - matx::index_t nmy = static_cast(blockIdx.x) * blockDim.y + threadIdx.y;\n\ - matx::index_t idy = nmy % size2;\n\ - matx::index_t idz = nmy / 
size2;\n\ - matx::index_t idw = static_cast(blockIdx.y) * blockDim.z + threadIdx.z;\n\ + matx::index_t idy = blockIdx.x;\n\ + matx::index_t idz = blockIdx.y;\n\ + matx::index_t idw = blockIdx.z;\n\ if (idx * static_cast(CurrentCapabilities::ept) < size3 && idy < size2 && idz < size1 && idw < size0) {\n\ if constexpr (cuda::std::is_pointer_v) {\n\ (*op).template operator()(idw, idz, idy, idx);\n\ @@ -138,7 +137,7 @@ namespace matx {\n\ }\n\ \n\ template \n\ - __global__ void matxOpT4StrideKernel(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\ + __global__ void matxOpT4StrideKernelBlock(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\ matx::index_t idx = threadIdx.x;\n\ for(matx::index_t nmy = static_cast(blockIdx.x) * blockDim.y + threadIdx.y;\n\ nmy < size1 * size2;\n\ @@ -160,6 +159,168 @@ namespace matx {\n\ }\n\ }\n\ }\n\ + \n\ + template \n\ + __global__ void matxOpT0Kernel(Op op) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()();\n\ + }\n\ + else {\n\ + op.template operator()();\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpT1Kernel(Op op, matx::index_t size0) {\n\ + matx::index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;\n\ + if (idx * static_cast(CurrentCapabilities::ept) < size0) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idx);\n\ + }\n\ + else {\n\ + op.template operator()(idx);\n\ + }\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpT2Kernel(Op op, matx::index_t size0, matx::index_t size1) {\n\ + matx::index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;\n\ + matx::index_t idy = static_cast(blockIdx.y) * blockDim.y + threadIdx.y;\n\ + if (idx * static_cast(CurrentCapabilities::ept) < size1 && idy < size0) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idy, idx);\n\ + }\n\ + else {\n\ + op.template operator()(idy, idx);\n\ + }\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpT2StrideKernel(Op op, matx::index_t size0, matx::index_t size1) {\n\ + for(matx::index_t idy = static_cast(blockIdx.y) * blockDim.y + threadIdx.y;\n\ + idy < size0;\n\ + idy += blockDim.y * gridDim.y) {\n\ + for(matx::index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;\n\ + idx * static_cast(CurrentCapabilities::ept) < size1;\n\ + idx += blockDim.x * gridDim.x) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idy, idx);\n\ + }\n\ + else {\n\ + op.template operator()(idy, idx);\n\ + }\n\ + }\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpT3Kernel(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\ + matx::index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;\n\ + matx::index_t idy = static_cast(blockIdx.y) * blockDim.y + threadIdx.y;\n\ + matx::index_t idz = static_cast(blockIdx.z) * blockDim.z + threadIdx.z;\n\ + if (idx * static_cast(CurrentCapabilities::ept) < size2 && idy < size1 && idz < size0) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idz, idy, idx);\n\ + }\n\ + else {\n\ + op.template operator()(idz, idy, idx);\n\ + }\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpT3StrideKernel(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\ + for(matx::index_t idz = static_cast(blockIdx.z) * blockDim.z + threadIdx.z;\n\ + idz < size0;\n\ + idz += blockDim.z * gridDim.z) {\n\ + for 
(matx::index_t idy = static_cast(blockIdx.y) * blockDim.y + threadIdx.y;\n\ + idy < size1;\n\ + idy += blockDim.y * gridDim.y) {\n\ + for(matx::index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;\n\ + idx * static_cast(CurrentCapabilities::ept) < size2;\n\ + idx += blockDim.x * gridDim.x) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idz, idy, idx);\n\ + }\n\ + else {\n\ + op.template operator()(idz, idy, idx);\n\ + }\n\ + }\n\ + }\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpT4Kernel(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\ + matx::index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;\n\ + matx::index_t nmy = static_cast(blockIdx.y) * blockDim.y + threadIdx.y;\n\ + matx::index_t idy = nmy % size2;\n\ + matx::index_t idz = nmy / size2;\n\ + matx::index_t idw = static_cast(blockIdx.z) * blockDim.z + threadIdx.z;\n\ + if (idx * static_cast(CurrentCapabilities::ept) < size3 && idy < size2 && idz < size1 && idw < size0) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idw, idz, idy, idx);\n\ + }\n\ + else {\n\ + op.template operator()(idw, idz, idy, idx);\n\ + }\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpT4StrideKernel(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\ + for(matx::index_t nmy = static_cast(blockIdx.y) * blockDim.y + threadIdx.y;\n\ + nmy < size1 * size2;\n\ + nmy += blockDim.y * gridDim.y) {\n\ + matx::index_t idy = nmy % size2;\n\ + matx::index_t idz = nmy / size2;\n\ + if(idy < size2 && idz < size1) {\n\ + for(matx::index_t idw = static_cast(blockIdx.z) * blockDim.z + threadIdx.z;\n\ + idw < size0;\n\ + idw += blockDim.z * gridDim.z) {\n\ + for(matx::index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;\n\ + idx * static_cast(CurrentCapabilities::ept) < size3;\n\ + idx += blockDim.x * gridDim.x) {\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idw, idz, idy, idx);\n\ + }\n\ + else {\n\ + op.template operator()(idw, idz, idy, idx);\n\ + }\n\ + }\n\ + }\n\ + }\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpTDKernel(Op op, const cuda::std::array sizes, matx::index_t mult) {\n\ + cuda::std::array indices;\n\ + static_assert(Op::Rank() >= 1, \"rank must exceed zero\");\n\ + matx::index_t x_abs = static_cast(blockIdx.x) * blockDim.x + threadIdx.x;\n\ + const bool valid = x_abs < mult*sizes[0];\n\ + if (valid) {\n\ + MATX_LOOP_UNROLL\n\ + for (int r = 0; r < Op::Rank()-1; r++) {\n\ + indices[r] = x_abs / mult;\n\ + x_abs -= indices[r] * mult;\n\ + mult /= sizes[r+1];\n\ + }\n\ + indices[Op::Rank()-1] = x_abs / mult;\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + cuda::std::apply([&](auto... args){\n\ + (*op).template operator()(args...);\n\ + }, indices);\n\ + }\n\ + else {\n\ + cuda::std::apply([&](auto... 
args){\n\ + op.template operator()(args...);\n\ + }, indices);\n\ + }\n\ + }\n\ + }\n\ }\n\ }"; #else diff --git a/include/matx/executors/kernel.h b/include/matx/executors/kernel.h index 86e6cc9f..4eee1edb 100644 --- a/include/matx/executors/kernel.h +++ b/include/matx/executors/kernel.h @@ -154,7 +154,7 @@ template __global__ void matxOpT4StrideKernel(Op op, index_t size0, index_t size1, index_t size2, index_t size3) { for(index_t nmy = static_cast(blockIdx.y) * blockDim.y + threadIdx.y; - nmy < size2 * size3; + nmy < size1 * size2; nmy += blockDim.y * gridDim.y) { index_t idy = nmy % size2; index_t idz = nmy / size2; diff --git a/include/matx/generators/alternate.h b/include/matx/generators/alternate.h index 1afac9f4..ba2d9305 100644 --- a/include/matx/generators/alternate.h +++ b/include/matx/generators/alternate.h @@ -95,6 +95,13 @@ namespace matx return get_jit_class_name() + "<" + type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/bartlett.h b/include/matx/generators/bartlett.h index d7f15cea..b5db82a1 100644 --- a/include/matx/generators/bartlett.h +++ b/include/matx/generators/bartlett.h @@ -107,6 +107,13 @@ namespace matx return get_jit_class_name() + "<" + type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/blackman.h b/include/matx/generators/blackman.h index 839542d8..142a8d28 100644 --- a/include/matx/generators/blackman.h +++ b/include/matx/generators/blackman.h @@ -108,6 +108,13 @@ namespace matx return get_jit_class_name() + "<" + type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/chirp.h b/include/matx/generators/chirp.h index 043295fc..91c353cd 100644 --- a/include/matx/generators/chirp.h +++ b/include/matx/generators/chirp.h @@ -130,6 +130,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), sop_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { @@ -265,6 +272,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), sop_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/diag.h b/include/matx/generators/diag.h index b9d1a890..978d6399 100644 --- a/include/matx/generators/diag.h +++ b/include/matx/generators/diag.h @@ -117,7 +117,14 @@ namespace matx return ""; #endif } - else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { #ifdef 
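// Editor's sketch of the index delinearization performed by the matxOpTDKernel
// added above: a flat thread index is peeled into N-D coordinates by repeated
// division, where `mult` starts as the product of all dimensions except the
// outermost. This is a hypothetical standalone host-side form for illustration,
// not the kernel itself.
#include <array>
#include <cstdint>

template <int Rank>
std::array<int64_t, Rank> delinearize(int64_t x_abs, const std::array<int64_t, Rank>& sizes) {
  std::array<int64_t, Rank> idx{};
  int64_t mult = 1;
  for (int r = 1; r < Rank; r++) mult *= sizes[r];  // product of sizes[1..Rank-1]
  for (int r = 0; r < Rank - 1; r++) {
    idx[r] = x_abs / mult;       // coordinate along dimension r
    x_abs -= idx[r] * mult;      // remainder within the r-th slab
    mult /= sizes[r + 1];        // shrink the stride for the next dimension
  }
  idx[Rank - 1] = x_abs;         // innermost coordinate (mult is 1 here)
  return idx;
}
// e.g. sizes {2,3,4}: x_abs = 17, mult = 12 -> idx = {1,1,1}, since 17 = 12 + 4 + 1.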
MATX_EN_JIT const auto [key, value] = get_jit_op_str(); if (in.find(key) == in.end()) { diff --git a/include/matx/generators/fftfreq.h b/include/matx/generators/fftfreq.h index 5db2f0b6..f2da8883 100644 --- a/include/matx/generators/fftfreq.h +++ b/include/matx/generators/fftfreq.h @@ -114,6 +114,13 @@ namespace matx return get_jit_class_name() + "<" + type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/flattop.h b/include/matx/generators/flattop.h index 726b107c..d02fcaa3 100644 --- a/include/matx/generators/flattop.h +++ b/include/matx/generators/flattop.h @@ -127,6 +127,13 @@ namespace matx return get_jit_class_name() + "<" + type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/generator1d.h b/include/matx/generators/generator1d.h index 05e60db8..01d7b7db 100644 --- a/include/matx/generators/generator1d.h +++ b/include/matx/generators/generator1d.h @@ -60,7 +60,8 @@ namespace matx } } else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY || - Cap == OperatorCapability::JIT_CLASS_QUERY) { + Cap == OperatorCapability::JIT_CLASS_QUERY || + Cap == OperatorCapability::SUPPORTS_JIT) { // Forward JIT-related capabilities to the generator return f_.template get_capability(in); } diff --git a/include/matx/generators/hamming.h b/include/matx/generators/hamming.h index 5e779d3c..bb2542ff 100644 --- a/include/matx/generators/hamming.h +++ b/include/matx/generators/hamming.h @@ -108,6 +108,13 @@ namespace matx return get_jit_class_name() + "<" + type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/hanning.h b/include/matx/generators/hanning.h index 13ccbea7..37172a77 100644 --- a/include/matx/generators/hanning.h +++ b/include/matx/generators/hanning.h @@ -108,6 +108,13 @@ namespace matx return get_jit_class_name() + "<" + type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/linspace.h b/include/matx/generators/linspace.h index 67123c6d..1d1dc268 100644 --- a/include/matx/generators/linspace.h +++ b/include/matx/generators/linspace.h @@ -142,6 +142,13 @@ namespace matx return get_jit_class_name() + "<" + type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/logspace.h b/include/matx/generators/logspace.h index 7c3829c0..bfadfb04 100644 --- a/include/matx/generators/logspace.h +++ b/include/matx/generators/logspace.h @@ -136,6 +136,13 @@ namespace matx return "JITLogspace<" + range_jit_name + ">"; #else return ""; +#endif + } + else if constexpr (Cap 
== OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/generators/meshgrid.h b/include/matx/generators/meshgrid.h index 1c69dff0..510beeeb 100644 --- a/include/matx/generators/meshgrid.h +++ b/include/matx/generators/meshgrid.h @@ -127,7 +127,14 @@ namespace matx return ""; #endif } - else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { #ifdef MATX_EN_JIT const auto [key, value] = get_jit_op_str(); if (in.find(key) == in.end()) { diff --git a/include/matx/generators/range.h b/include/matx/generators/range.h index 019f314c..7182e4d0 100644 --- a/include/matx/generators/range.h +++ b/include/matx/generators/range.h @@ -128,6 +128,13 @@ MATX_IGNORE_WARNING_POP_GCC return get_jit_class_name() + "<" + type_to_string() + " >"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/apply.h b/include/matx/operators/apply.h index 05e8555c..79ab3a04 100644 --- a/include/matx/operators/apply.h +++ b/include/matx/operators/apply.h @@ -34,6 +34,7 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" +#include namespace matx { @@ -83,8 +84,18 @@ namespace matx template __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType &in) const { - auto self_has_cap = capability_attributes::default_value; - return combine_capabilities(self_has_cap, get_combined_ops_capability(in, ops_)); + if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { + // Cannot JIT compile user-defined lambdas/functors - no way to get source code at runtime + return ""; + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + // Cannot JIT compile user-defined lambdas/functors - no way to get source code at runtime + return false; + } + else { + auto self_has_cap = capability_attributes::default_value; + return combine_capabilities(self_has_cap, get_combined_ops_capability(in, ops_)); + } } template diff --git a/include/matx/operators/apply_idx.h b/include/matx/operators/apply_idx.h index 9a8d09c3..cc911b20 100644 --- a/include/matx/operators/apply_idx.h +++ b/include/matx/operators/apply_idx.h @@ -34,6 +34,7 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" +#include namespace matx { @@ -87,7 +88,15 @@ namespace matx template __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType& in) const { - if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { + if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { + // Cannot JIT compile user-defined lambdas/functors - no way to get source code at runtime + return ""; + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + // Cannot JIT compile user-defined lambdas/functors - no way to get source code at runtime + return false; + } + else if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { const auto my_cap = cuda::std::array{ElementsPerThread::ONE, ElementsPerThread::ONE}; return combine_capabilities(my_cap, get_combined_ops_capability(in, ops_)); diff --git a/include/matx/operators/at.h 
b/include/matx/operators/at.h index 236099db..9adc0d42 100644 --- a/include/matx/operators/at.h +++ b/include/matx/operators/at.h @@ -63,9 +63,9 @@ namespace matx __MATX_INLINE__ std::string get_jit_class_name() const { std::string idx_str; - for (size_t i = 0; i < sizeof...(Is); i++) { + for (int32_t i = 0; i < static_cast(sizeof...(Is)); i++) { idx_str += std::to_string(idx_[i]); - if (i < sizeof...(Is) - 1) idx_str += "_"; + if (i < static_cast(sizeof...(Is)) - 1) idx_str += "_"; } return std::format("JITAt_idx{}", idx_str); } @@ -128,6 +128,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/binary_operators.h b/include/matx/operators/binary_operators.h index 9475a945..02f632c7 100644 --- a/include/matx/operators/binary_operators.h +++ b/include/matx/operators/binary_operators.h @@ -163,7 +163,7 @@ namespace matx } __MATX_INLINE__ auto get_jit_op_str() const { - cuda::std::array out_dims_; + cuda::std::array(Rank())> out_dims_; for (int i = 0; i < Rank(); ++i) { out_dims_[i] = Size(i); } @@ -213,6 +213,15 @@ namespace matx return get_jit_class_name() + "<" + lhs_jit_name + "," + rhs_jit_name + "," + op_jit_name + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(in1_, in), + detail::get_operator_capability(in2_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/cart2sph.h b/include/matx/operators/cart2sph.h index 0e20c795..e87fd978 100644 --- a/include/matx/operators/cart2sph.h +++ b/include/matx/operators/cart2sph.h @@ -159,6 +159,16 @@ namespace matx return std::format("{}<{},{},{}>", get_jit_class_name(), x_jit_name, y_jit_name, z_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(x_, in), + detail::get_operator_capability(y_, in), + detail::get_operator_capability(z_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/cast.h b/include/matx/operators/cast.h index 8fb0de22..a5cc37f2 100644 --- a/include/matx/operators/cast.h +++ b/include/matx/operators/cast.h @@ -149,6 +149,13 @@ namespace matx return get_jit_class_name() + "<" + op_jit_name + "," + detail::type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/clone.h b/include/matx/operators/clone.h index 6d1228d7..1e11f290 100644 --- a/include/matx/operators/clone.h +++ b/include/matx/operators/clone.h @@ -67,9 +67,9 @@ namespace matx sizes_str += std::to_string(sizes_[i]); if (i < CRank - 1) sizes_str += "_"; } - for (size_t i = 0; i < T::Rank(); i++) { + for (int32_t i = 0; i < T::Rank(); i++) { dims_str += std::to_string(dims_[i]); - if 
(i < T::Rank() - 1) dims_str += "_"; + if (i < static_cast(T::Rank()) - 1) dims_str += "_"; } return std::format("JITClone_sizes{}_dims{}", sizes_str, dims_str); } @@ -229,6 +229,13 @@ MATX_IGNORE_WARNING_POP_GCC return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/collapse.h b/include/matx/operators/collapse.h index 79615b17..e9a4df44 100644 --- a/include/matx/operators/collapse.h +++ b/include/matx/operators/collapse.h @@ -241,6 +241,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { @@ -499,6 +506,13 @@ MATX_LOOP_UNROLL return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/comma.h b/include/matx/operators/comma.h index 3b965e20..14966522 100644 --- a/include/matx/operators/comma.h +++ b/include/matx/operators/comma.h @@ -160,7 +160,16 @@ namespace matx return ""; #endif } - else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(op1_, in), + detail::get_operator_capability(op2_, in)); +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { #ifdef MATX_EN_JIT const auto [key, value] = get_jit_op_str(); if (in.find(key) == in.end()) { diff --git a/include/matx/operators/concat.h b/include/matx/operators/concat.h index 3f7d3c79..a6bfead5 100644 --- a/include/matx/operators/concat.h +++ b/include/matx/operators/concat.h @@ -34,7 +34,9 @@ #include "matx/core/type_utils.h" +#include "matx/core/utils.h" #include "matx/operators/base_operator.h" +#include namespace matx { @@ -74,6 +76,161 @@ namespace matx return get_str<-1>(); } +#ifdef MATX_EN_JIT + struct JIT_Storage { + cuda::std::tuple>...> ops_; + }; + + JIT_Storage ToJITStorage() const { + return JIT_Storage{cuda::std::apply([](const auto&... 
ops) { + return cuda::std::make_tuple(detail::to_jit_storage(ops)...); + }, ops_)}; + } + + template + __MATX_INLINE__ std::string get_sizes_str() const { + if constexpr (I < sizeof...(Ts)) { + const auto& op = cuda::std::get(ops_); + std::string sizes = "op" + std::to_string(I) + "_"; + for (int d = 0; d < RANK; d++) { + sizes += std::to_string(op.Size(d)); + if (d < RANK - 1) sizes += "x"; + } + if constexpr (I < sizeof...(Ts) - 1) { + return sizes + "_" + get_sizes_str(); + } else { + return sizes; + } + } else { + return ""; + } + } + + __MATX_INLINE__ std::string get_jit_class_name() const { + return std::format("JITConcat_axis{}_num{}_{}", axis_, sizeof...(Ts), get_sizes_str<0>()); + } + + template + __MATX_INLINE__ std::string get_jit_type_list() const { + if constexpr (I < sizeof...(Ts) - 1) { + return "typename T" + std::to_string(I) + ", " + get_jit_type_list(); + } else if constexpr (I == sizeof...(Ts) - 1) { + return "typename T" + std::to_string(I); + } else { + return ""; + } + } + + template + __MATX_INLINE__ std::string get_jit_storage_tuple_types() const { + if constexpr (I < sizeof...(Ts) - 1) { + return "typename detail::inner_storage_or_self_t>, " + get_jit_storage_tuple_types(); + } else if constexpr (I == sizeof...(Ts) - 1) { + return "typename detail::inner_storage_or_self_t>"; + } else { + return ""; + } + } + + __MATX_INLINE__ std::string get_jit_storage_tuple() const { + return "cuda::std::tuple<" + get_jit_storage_tuple_types<0>() + "> ops_;\n"; + } + + __MATX_INLINE__ auto get_jit_op_str() const { + std::string func_name = get_jit_class_name(); + cuda::std::array out_dims_; + for (int i = 0; i < RANK; i++) { + out_dims_[i] = Size(i); + } + + return cuda::std::make_tuple( + func_name, + std::format("template <{}> struct {} {{\n" + " using value_type = typename T0::value_type;\n" + " using matxop = bool;\n" + " constexpr static int RANK_ = {};\n" + " constexpr static cuda::std::array sizes_ = {{ {} }};\n" + " constexpr static int axis_ = {};\n" + " constexpr static index_t size_ = {};\n" + " {}" + " // Non-const get_impl for lvalue assignments\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) get_impl(cuda::std::array& indices) {{\n" + " if constexpr ( I == N ) {{\n" + " auto &op = cuda::std::get<0>(ops_);\n" + " return cuda::std::apply([&](auto &&...call_args) -> decltype(auto) {{\n" + " return op.template operator()(call_args...);\n" + " }}, indices);\n" + " }} else {{\n" + " auto &op = cuda::std::get(ops_);\n" + " auto idx = indices[axis_];\n" + " auto size = op.Size(axis_);\n" + " if(idx < size) {{\n" + " return cuda::std::apply([&](auto &&...call_args) -> decltype(auto) {{\n" + " return op.template operator()(call_args...);\n" + " }}, indices);\n" + " }} else {{\n" + " indices[axis_] -= size;\n" + " return get_impl(indices);\n" + " }}\n" + " }}\n" + " }}\n" + " // Const get_impl\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ auto get_impl(cuda::std::array& indices) const {{\n" + " using return_t = cuda::std::conditional_t<\n" + " (CapType::ept == ElementsPerThread::ONE),\n" + " value_type,\n" + " Vector(CapType::ept)>>;\n" + " if constexpr ( I == N ) {{\n" + " const auto &op = cuda::std::get<0>(ops_);\n" + " return cuda::std::apply([&](auto &&...call_args) -> return_t {{\n" + " return op.template operator()(call_args...);\n" + " }}, indices);\n" + " }} else {{\n" + " const auto &op = cuda::std::get(ops_);\n" + " auto idx = indices[axis_];\n" + " auto size = op.Size(axis_);\n" + " if(idx < size) {{\n" + " return 
cuda::std::apply([&](auto &&...call_args) -> return_t {{\n" + " return op.template operator()(call_args...);\n" + " }}, indices);\n" + " }} else {{\n" + " indices[axis_] -= size;\n" + " return get_impl(indices);\n" + " }}\n" + " }}\n" + " }}\n" + " // Const operator()\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " cuda::std::array idx{{indices...}};\n" + " return get_impl(idx);\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" + " // Non-const operator()\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " cuda::std::array idx{{indices...}};\n" + " return get_impl(idx);\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" + " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return RANK_; }}\n" + " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{\n" + " return sizes_[dim];\n" + " }}\n" + "}};\n", + get_jit_type_list<0>(), func_name, RANK, detail::array_to_string(out_dims_), axis_, size_, get_jit_storage_tuple(), sizeof...(Ts), sizeof...(Ts)) + ); + } +#endif + __MATX_INLINE__ ConcatOp(int axis, const Ts&... ts) : ops_(ts...), axis_(axis) { static_assert(RANK > 0, "Cannot concatenate rank-0 tensors"); @@ -98,7 +255,8 @@ namespace matx if constexpr ( I == N ) { // This should never happen, but we return a fake value from the first tuple element anyways auto &op = cuda::std::get<0>(ops_); - return cuda::std::apply([&](auto &&...call_args) -> decltype(auto) { return op.template operator()(call_args...); }, indices); + return cuda::std::apply([&](auto &&...call_args) -> decltype(auto) { + return op.template operator()(call_args...); }, indices); } else { auto &op = cuda::std::get(ops_); auto idx = indices[axis_]; @@ -106,7 +264,8 @@ namespace matx // If in range of this operator if(idx < size) { // evaluate operator - return cuda::std::apply([&](auto &&...call_args) -> decltype(auto) { return op.template operator()(call_args...); }, indices); + return cuda::std::apply([&](auto &&...call_args) -> decltype(auto) { + return op.template operator()(call_args...); }, indices); } else { // otherwise remove this operator and recurse indices[axis_] -= size; @@ -183,7 +342,41 @@ namespace matx template __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType & in) const { - if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { + if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { +#ifdef MATX_EN_JIT + return get_jit_class_name() + "<" + get_jit_type_params<0>() + ">"; +#else + return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, get_combined_ops_capability(in, ops_)); +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { +#ifdef MATX_EN_JIT + // Get the key/value pair from get_jit_op_str() + const auto [key, value] = get_jit_op_str(); + + // Insert into the map if the key doesn't exist + if (in.find(key) == in.end()) { + in[key] = value; + } + + // Also handle child operators + cuda::std::apply([&in](const auto&... 
ops) { + (detail::get_operator_capability(ops, in), ...); + }, ops_); + + return true; +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { const auto my_cap = cuda::std::array{ElementsPerThread::ONE, ElementsPerThread::ONE}; return combine_capabilities(my_cap, get_combined_ops_capability(in, ops_)); } else { @@ -193,6 +386,23 @@ namespace matx } } +#ifdef MATX_EN_JIT + template + __MATX_INLINE__ std::string get_jit_type_params() const { + if constexpr (I < sizeof...(Ts)) { + VoidCapabilityType void_type{}; + auto type_name = detail::get_operator_capability(cuda::std::get(ops_), void_type); + if constexpr (I < sizeof...(Ts) - 1) { + return type_name + "," + get_jit_type_params(); + } else { + return type_name; + } + } else { + return ""; + } + } +#endif + static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() noexcept { return RANK; diff --git a/include/matx/operators/constval.h b/include/matx/operators/constval.h index 59e8b431..73329cc0 100644 --- a/include/matx/operators/constval.h +++ b/include/matx/operators/constval.h @@ -32,6 +32,7 @@ #pragma once +#include "matx/core/utils.h" namespace matx { @@ -58,20 +59,36 @@ namespace matx } __MATX_INLINE__ std::string get_jit_class_name() const { - std::string val_str; - if constexpr (std::is_floating_point_v) { - val_str = std::format("{}", v_); - } else { - val_str = std::to_string(v_); - } - return std::format("JITConstVal_val{}_rank{}", val_str, Rank()); + // Convert the numeric value to a valid C++ symbol name + std::string val_str = detail::number_to_symbol(v_); + return std::format("JITConstVal_val{}_rank{}", val_str, Rank() == matxNoRank ? "No" : std::to_string(Rank())); } __MATX_INLINE__ auto get_jit_op_str() const { std::string func_name = get_jit_class_name(); - cuda::std::array out_dims_; - for (int i = 0; i < Rank(); ++i) { - out_dims_[i] = Size(i); + std::string dims_array_str; + std::string size_func_str; + + if constexpr (!is_noshape_v) { + cuda::std::array(RANK)> out_dims_; + for (int i = 0; i < RANK; ++i) { + out_dims_[i] = Size(i); + } + dims_array_str = std::format("constexpr static cuda::std::array out_dims_ = {{ {} }};\n ", + RANK, detail::array_to_string(out_dims_)); + size_func_str = std::format("constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{ return out_dims_[dim]; }}\n "); + } else { + dims_array_str = ""; + size_func_str = std::format("constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{ return 0; }}\n "); + } + + // Format the value for code generation + std::string val_init_str; + if constexpr (is_complex_v) { + // For complex numbers, use constructor syntax: T{real, imag} + val_init_str = std::format("T{{{}, {}}}", v_.real(), v_.imag()); + } else { + val_init_str = std::format("{}", v_); } return cuda::std::make_tuple( @@ -79,8 +96,8 @@ namespace matx std::format("template struct {} {{\n" " using value_type = T;\n" " using matxop = bool;\n" - " constexpr static cuda::std::array out_dims_ = {{ {} }};\n" - " constexpr static T v_ = static_cast({});\n" + " {}" + " constexpr static T v_ = {};\n" " template \n" " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is...) 
const\n" " {{\n" @@ -91,9 +108,9 @@ namespace matx " }}\n" " }}\n" " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return {}; }}\n" - " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{ return out_dims_[dim]; }}\n" + " {}" "}};\n", - func_name, Rank(), detail::array_to_string(out_dims_), v_, Rank()) + func_name, dims_array_str, val_init_str, Rank(), size_func_str) ); } #endif @@ -139,6 +156,13 @@ namespace matx return get_jit_class_name() + "<" + type_to_string() + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/cross.h b/include/matx/operators/cross.h index 300c25db..2342b162 100644 --- a/include/matx/operators/cross.h +++ b/include/matx/operators/cross.h @@ -93,7 +93,49 @@ namespace matx " typename detail::inner_storage_or_self_t> a_;\n" " typename detail::inner_storage_or_self_t> b_;\n" " template \n" - " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const {{ /* cross product logic */ }}\n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " cuda::std::array idx{{indices...}};\n" + " auto idxOut = idx[idx.size() - 1];\n" + " cuda::std::array idx0{{idx}};\n" + " cuda::std::array idx1{{idx}};\n" + " cuda::std::array idx2{{idx}};\n" + " idx0[idx0.size() - 1] = 0LL;\n" + " idx1[idx1.size() - 1] = 1LL;\n" + " idx2[idx2.size() - 1] = 2LL;\n" + " auto a0 = get_value(a_, idx0);\n" + " auto a1 = get_value(a_, idx1);\n" + " auto b0 = get_value(b_, idx0);\n" + " auto b1 = get_value(b_, idx1);\n" + " if (idxOut == 2 || (isA2D_ && isB2D_)) {{\n" + " return a0 * b1 - a1 * b0;\n" + " }}\n" + " if (!isA2D_ && !isB2D_) {{\n" + " auto a2 = get_value(a_, idx2);\n" + " auto b2 = get_value(b_, idx2);\n" + " if (idxOut == 0) {{\n" + " return a1 * b2 - a2 * b1;\n" + " }}\n" + " return a2 * b0 - a0 * b2;\n" + " }}\n" + " else if (isA2D_ && !isB2D_) {{\n" + " auto b2 = get_value(b_, idx2);\n" + " if (idxOut == 0) {{\n" + " return a1 * b2;\n" + " }}\n" + " return -a0 * b2;\n" + " }}\n" + " else {{\n" + " auto a2 = get_value(a_, idx2);\n" + " if (idxOut == 0) {{\n" + " return -a2 * b1;\n" + " }}\n" + " return a2 * b0;\n" + " }}\n" + " }} else {{\n" + " return Vector(CapType::ept)>();\n" + " }}\n" + " }}\n" " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return Rank_; }}\n" " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{ return out_dims_[dim]; }}\n" "}};\n", @@ -209,6 +251,15 @@ namespace matx return std::format("{}<{},{}>", get_jit_class_name(), a_jit_name, b_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(a_, in), + detail::get_operator_capability(b_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { @@ -240,7 +291,7 @@ namespace matx return combine_capabilities( self_has_cap, detail::get_operator_capability(a_, in), - detail::get_operator_capability(b_, in) + detail::get_operator_capability(b_, in) ); } } diff --git a/include/matx/operators/diag.h b/include/matx/operators/diag.h index 28bb5807..0e633d6c 100644 --- a/include/matx/operators/diag.h +++ b/include/matx/operators/diag.h @@ -90,7 
+90,37 @@ namespace matx " constexpr static cuda::std::array out_dims_ = {{ {} }};\n" " typename detail::inner_storage_or_self_t> op_;\n" " template \n" - " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const {{ /* diag logic */ }}\n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " if constexpr (RANK_ == 1) {{\n" + " cuda::std::array idx{{indices...}};\n" + " if (idx[0] == idx[1]) {{\n" + " return get_value(op_, cuda::std::array{{idx[0]}});\n" + " }}\n" + " else {{\n" + " return static_cast(0);\n" + " }}\n" + " }}\n" + " else {{\n" + " cuda::std::array idx{{indices...}};\n" + " cuda::std::array tmp;\n" + " for (int i = 0; i < RANK_ - 2; i++) {{\n" + " tmp[i] = idx[i];\n" + " }}\n" + " if (k_ < 0) {{\n" + " tmp[RANK_ - 1] = idx[RANK_ - 2];\n" + " tmp[RANK_ - 2] = idx[RANK_ - 2] - k_;\n" + " }}\n" + " else {{\n" + " tmp[RANK_ - 2] = idx[RANK_ - 2];\n" + " tmp[RANK_ - 1] = idx[RANK_ - 2] + k_;\n" + " }}\n" + " return get_value(op_, tmp);\n" + " }}\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return OutRank_; }}\n" " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{ return out_dims_[dim]; }}\n" "}};\n", @@ -165,6 +195,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/fft.h b/include/matx/operators/fft.h index a245385c..a825bdf3 100644 --- a/include/matx/operators/fft.h +++ b/include/matx/operators/fft.h @@ -186,6 +186,11 @@ namespace matx dx_fft_helper_.set_fft_type(DeduceFFTTransformType::ctype, value_type>()); dx_fft_helper_.set_direction(Direction); dx_fft_helper_.set_cc(cc); + // if (fft_size_ <= 32) { + // dx_fft_helper_.set_method(cuFFTDxMethod::REGISTER); + // } else { + dx_fft_helper_.set_method(cuFFTDxMethod::SHARED); + //} bool contiguous = false; if constexpr (is_tensor_view_v) { @@ -274,10 +279,7 @@ namespace matx } else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { bool supported = true; - if (((fft_size_ & (fft_size_ - 1)) != 0 || fft_size_ == 0) // Only support power-of-2 FFT sizes for JIT support - || is_complex_half_v // No half support in MatX for fusion yet - || !is_complex_v) // Only support C2C for JIT support - { + if (!dx_fft_helper_.template CheckJITSizeAndTypeRequirements()) { supported = false; } else { @@ -309,7 +311,10 @@ namespace matx // Currently MatX only attempts to use the "best" EPT as returned by cuFFTDx. In the future we may // try other EPT values that yield different SHM values. 
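// Editor's sketch (not the MatX implementation) of how a RANGE_QUERY combine
// can narrow an EPT interval: each operator reports a [min, max] pair and the
// launcher may only choose from the intersection of all pairs. Plain ints
// stand in for the ElementsPerThread enum here.
#include <algorithm>
#include <array>

using EptRange = std::array<int, 2>;  // {min_ept, max_ept}

EptRange combine_ranges(EptRange a, EptRange b) {
  return { std::max(a[0], b[0]), std::min(a[1], b[1]) };  // empty if min > max
}
// e.g. an FFT reporting {4, 8} combined with a child reporting {1, 16} yields {4, 8}.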
if (dx_fft_helper_.IsSupported()) { - auto result = combine_capabilities(dx_fft_helper_.GetEPTs(), detail::get_operator_capability(a_, in)); + auto epts = dx_fft_helper_.GetEPTs(); + // epts[0] = ElementsPerThread::EIGHT; + // epts[1] = ElementsPerThread::EIGHT; + auto result = combine_capabilities(epts, detail::get_operator_capability(a_, in)); MATX_LOG_DEBUG("ELEMENTS_PER_THREAD (JIT supported): [{},{}]", static_cast(result[0]), static_cast(result[1])); return result; } @@ -330,10 +335,10 @@ namespace matx else if constexpr (Cap == OperatorCapability::GROUPS_PER_BLOCK) { int ffts_per_block_candidate; - if constexpr (RANK > 1) { + if constexpr (Rank() > 1) { const int ffts_per_block = dx_fft_helper_.GetFFTsPerBlock(); const auto last_dim = a_.Size(a_.Rank() - 2); - int ffts_per_block_candidate = ffts_per_block; + ffts_per_block_candidate = ffts_per_block; // Try to find an ffts_per_block that evenly divides into last dimension size // Decrease ffts_per_block until it divides evenly or until 1 while (ffts_per_block_candidate > 1 && (last_dim % ffts_per_block_candidate != 0)) { @@ -353,9 +358,14 @@ namespace matx else if constexpr (Cap == OperatorCapability::SET_ELEMENTS_PER_THREAD) { dx_fft_helper_.set_current_elements_per_thread(in.ept); auto result = combine_capabilities(capability_attributes::default_value, detail::get_operator_capability(a_, in)); - MATX_LOG_DEBUG("SET_ELEMENTS_PER_THREAD: {}", result); + MATX_LOG_DEBUG("SET_ELEMENTS_PER_THREAD: {}", static_cast(in.ept)); return result; } + else if constexpr (Cap == OperatorCapability::GLOBAL_KERNEL) { + // If MathDx is enabled we always return false. Other checks on size and type may prevent JIT compilation. + MATX_LOG_DEBUG("GLOBAL_KERNEL: false"); + return false; + } else if constexpr (Cap == OperatorCapability::SET_GROUPS_PER_BLOCK) { dx_fft_helper_.set_ffts_per_block(in.groups_per_block); auto result = combine_capabilities(capability_attributes::default_value, detail::get_operator_capability(a_, in)); diff --git a/include/matx/operators/fftshift.h b/include/matx/operators/fftshift.h index c6d544e6..bddee9d1 100644 --- a/include/matx/operators/fftshift.h +++ b/include/matx/operators/fftshift.h @@ -147,6 +147,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { @@ -337,6 +344,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { @@ -522,6 +536,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { @@ -709,6 +730,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { 
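// Editor's sketch of the GROUPS_PER_BLOCK divisor search in the fft.h hunk above
// (where the shadowing `int` on ffts_per_block_candidate was removed so the outer
// variable actually receives the result): starting from the library-suggested
// value, step down until it evenly divides the batch dimension or reaches 1, so
// no block straddles a partial batch. Standalone illustration only.
int pick_groups_per_block(int suggested, long long batch_dim) {
  int candidate = suggested;
  while (candidate > 1 && (batch_dim % candidate != 0)) {
    candidate--;
  }
  return candidate;  // e.g. suggested = 8, batch_dim = 12 -> returns 6
}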
+#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/flatten.h b/include/matx/operators/flatten.h index 169ba175..c5700a34 100644 --- a/include/matx/operators/flatten.h +++ b/include/matx/operators/flatten.h @@ -173,6 +173,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op1_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/frexp.h b/include/matx/operators/frexp.h index 7f1d125f..0b2f387d 100644 --- a/include/matx/operators/frexp.h +++ b/include/matx/operators/frexp.h @@ -79,7 +79,57 @@ namespace detail { " constexpr static cuda::std::array out_dims_ = {{ {} }};\n" " typename detail::inner_storage_or_self_t> a_;\n" " template \n" - " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const {{ /* Complex frexp logic */ }}\n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " const auto val = get_value(a_, indices...);\n" + " int rexp;\n" + " if constexpr (is_cuda_complex_v) {{\n" + " if constexpr (cuda::std::is_same_v) {{\n" + " if constexpr (WHICH_ == 0) {{\n" + " return cuda::std::frexpf(val.real(), &rexp);\n" + " }} else if constexpr (WHICH_ == 1) {{\n" + " cuda::std::frexpf(val.real(), &rexp);\n" + " return rexp;\n" + " }} else if constexpr (WHICH_ == 2) {{\n" + " return cuda::std::frexpf(val.imag(), &rexp);\n" + " }} else {{\n" + " cuda::std::frexpf(val.imag(), &rexp);\n" + " return rexp;\n" + " }}\n" + " }} else {{\n" + " if constexpr (WHICH_ == 0) {{\n" + " return cuda::std::frexp(val.real(), &rexp);\n" + " }} else if constexpr (WHICH_ == 1) {{\n" + " cuda::std::frexp(val.real(), &rexp);\n" + " return rexp;\n" + " }} else if constexpr (WHICH_ == 2) {{\n" + " return cuda::std::frexp(val.imag(), &rexp);\n" + " }} else {{\n" + " cuda::std::frexp(val.imag(), &rexp);\n" + " return rexp;\n" + " }}\n" + " }}\n" + " }} else {{\n" + " if constexpr (cuda::std::is_same_v) {{\n" + " const float frac = cuda::std::frexpf(val, &rexp);\n" + " if constexpr (WHICH_ == 0) {{\n" + " return frac;\n" + " }} else {{\n" + " return rexp;\n" + " }}\n" + " }} else {{\n" + " const double frac = cuda::std::frexp(val, &rexp);\n" + " if constexpr (WHICH_ == 0) {{\n" + " return frac;\n" + " }} else {{\n" + " return rexp;\n" + " }}\n" + " }}\n" + " }}\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return Rank_; }}\n" " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{ return out_dims_[dim]; }}\n" "}};\n", @@ -207,7 +257,14 @@ namespace detail { return ""; #endif } - else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(a_, in)); +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { #ifdef MATX_EN_JIT const auto [key, value] = get_jit_op_str(); if (in.find(key) == in.end()) { diff --git 
a/include/matx/operators/hermitian.h b/include/matx/operators/hermitian.h index 6800d7e6..716a7c13 100644 --- a/include/matx/operators/hermitian.h +++ b/include/matx/operators/hermitian.h @@ -163,6 +163,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/if.h b/include/matx/operators/if.h index 44113fff..ba37b711 100644 --- a/include/matx/operators/if.h +++ b/include/matx/operators/if.h @@ -221,6 +221,15 @@ namespace matx return std::format("{}<{},{}>", get_jit_class_name(), cond_jit_name, op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(cond_, in), + detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/ifelse.h b/include/matx/operators/ifelse.h index 7241d6b2..deae48a2 100644 --- a/include/matx/operators/ifelse.h +++ b/include/matx/operators/ifelse.h @@ -98,10 +98,10 @@ namespace matx " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const\n" " {{\n" " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" - " if (get_value(cond_, indices...)) {{\n" - " return get_value(op1_, indices...);\n" + " if (get_value(cond_, indices...)) {{\n" + " return get_value(op1_, indices...);\n" " }} else {{\n" - " return get_value(op2_, indices...);\n" + " return get_value(op2_, indices...);\n" " }}\n" " }} else {{\n" " return Vector(CapType::ept)>{{}};\n" @@ -253,7 +253,17 @@ namespace matx return ""; #endif } - else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(cond_, in), + detail::get_operator_capability(op1_, in), + detail::get_operator_capability(op2_, in)); +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { #ifdef MATX_EN_JIT const auto [key, value] = get_jit_op_str(); if (in.find(key) == in.end()) { diff --git a/include/matx/operators/index.h b/include/matx/operators/index.h index a4aca903..03749b69 100644 --- a/include/matx/operators/index.h +++ b/include/matx/operators/index.h @@ -130,6 +130,13 @@ namespace matx return get_jit_class_name(); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return true; +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/interleaved.h b/include/matx/operators/interleaved.h index d7773384..ba475289 100644 --- a/include/matx/operators/interleaved.h +++ b/include/matx/operators/interleaved.h @@ -180,6 +180,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) 
{ diff --git a/include/matx/operators/isclose.h b/include/matx/operators/isclose.h index 71614df6..97662ac3 100644 --- a/include/matx/operators/isclose.h +++ b/include/matx/operators/isclose.h @@ -143,6 +143,15 @@ namespace matx return std::format("{}<{},{}>", get_jit_class_name(), op1_jit_name, op2_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(op1_, in), + detail::get_operator_capability(op2_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/kronecker.h b/include/matx/operators/kronecker.h index cfb05810..ec3f45b2 100644 --- a/include/matx/operators/kronecker.h +++ b/include/matx/operators/kronecker.h @@ -186,6 +186,15 @@ namespace matx return std::format("{}<{},{}>", get_jit_class_name(), op1_jit_name, op2_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(op1_, in), + detail::get_operator_capability(op2_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/legendre.h b/include/matx/operators/legendre.h index 69bfd657..e5d61cb4 100644 --- a/include/matx/operators/legendre.h +++ b/include/matx/operators/legendre.h @@ -88,12 +88,60 @@ namespace matx " typename detail::inner_storage_or_self_t> n_;\n" " typename detail::inner_storage_or_self_t> m_;\n" " typename detail::inner_storage_or_self_t> in_;\n" + " template \n" + " static __MATX_INLINE__ __MATX_DEVICE__ TypeParam legendre_calc(int n, int m, TypeParam x) {{\n" + " if (m > n) return 0;\n" + " TypeParam a = cuda::std::sqrt(TypeParam(1)-x*x);\n" + " TypeParam d1 = 1, d0;\n" + " for(int i=0; i < m; i++) {{\n" + " d0 = d1;\n" + " d1 = -TypeParam(2*i+1)*a*d0;\n" + " }}\n" + " TypeParam p0, p1 = 0, p2 = d1;\n" + " for(int l=m; l\n" - " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const {{ /* legendre logic */ }}\n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... 
indices) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " cuda::std::array inds{{indices...}};\n" + " cuda::std::array xinds;\n" + " int axis1 = axis_[0];\n" + " int axis2 = axis_[1];\n" + " index_t nind = inds[axis1];\n" + " int n = get_value(n_, nind);\n" + " index_t mind = inds[axis2];\n" + " int m = get_value(m_, mind);\n" + " if(axis1>axis2) {{\n" + " int tmp = axis1; axis1 = axis2; axis2 = tmp;\n" + " }}\n" + " int idx = 0;\n" + " for(int i = 0; i < Rank_; i++) {{\n" + " index_t ind = inds[i];\n" + " if(i != axis_[0] && i != axis_[1]) {{\n" + " xinds[idx++] = ind;\n" + " }}\n" + " }}\n" + " auto x = get_value(in_, xinds);\n" + " if constexpr (is_complex_half_v) {{\n" + " return static_cast(legendre_calc(n, m, cuda::std::complex(x)));\n" + " }} else if constexpr (is_matx_half_v) {{\n" + " return static_cast(legendre_calc(n, m, float(x)));\n" + " }} else {{\n" + " return legendre_calc(n, m, x);\n" + " }}\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return Rank_; }}\n" " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{ return out_dims_[dim]; }}\n" "}};\n", - func_name, Rank(), axis_[0], axis_[1], detail::array_to_string(out_dims_)) + func_name, Rank(), axis_[0], axis_[1], detail::array_to_string(out_dims_), T3::Rank()) ); } #endif @@ -217,6 +265,16 @@ namespace matx return std::format("{}<{},{},{}>", get_jit_class_name(), n_jit_name, m_jit_name, in_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(n_, in), + detail::get_operator_capability(m_, in), + detail::get_operator_capability(in_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/overlap.h b/include/matx/operators/overlap.h index fe8584ea..33a5c9ea 100644 --- a/include/matx/operators/overlap.h +++ b/include/matx/operators/overlap.h @@ -204,6 +204,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/pad.h b/include/matx/operators/pad.h index aa0a2e6b..6b90fa59 100644 --- a/include/matx/operators/pad.h +++ b/include/matx/operators/pad.h @@ -208,7 +208,14 @@ namespace matx return ""; #endif } - else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { #ifdef MATX_EN_JIT const auto [key, value] = get_jit_op_str(); if (in.find(key) == in.end()) { diff --git a/include/matx/operators/permute.h b/include/matx/operators/permute.h index 07d1da5d..e6ae04ba 100644 --- a/include/matx/operators/permute.h +++ b/include/matx/operators/permute.h @@ -232,6 +232,13 @@ MATX_LOOP_UNROLL return get_jit_class_name() + "<" + op_jit_name + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef 
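// Editor's note on the legendre_calc body added in the legendre.h hunk above: it
// implements the standard associated Legendre recurrences, first building
// P_m^m via d1 = -(2i+1) * sqrt(1 - x^2) * d0, then applying the three-term
// recursion (l+1-m) P_{l+1}^m = (2l+1) x P_l^m - (l+m) P_{l-1}^m. A minimal
// host-side mirror for spot-checking values, assuming a real double input:
#include <cmath>

double assoc_legendre(int n, int m, double x) {
  if (m > n) return 0.0;
  double a = std::sqrt(1.0 - x * x);
  double pmm = 1.0;
  for (int i = 0; i < m; i++) pmm = -(2.0 * i + 1.0) * a * pmm;    // P_m^m
  double p0 = 0.0, p1 = 0.0, p2 = pmm;                             // P_{l-1}, P_l
  for (int l = m; l < n; l++) {
    p0 = p1; p1 = p2;
    p2 = ((2.0 * l + 1.0) * x * p1 - (l + m) * p0) / (l + 1 - m);  // P_{l+1}^m
  }
  return p2;
}
// e.g. assoc_legendre(2, 0, 0.5) = (3 * 0.25 - 1) / 2 = -0.125.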
MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/planar.h b/include/matx/operators/planar.h index 44059c05..cb3e1932 100644 --- a/include/matx/operators/planar.h +++ b/include/matx/operators/planar.h @@ -177,6 +177,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/polyval.h b/include/matx/operators/polyval.h index a7dc9e02..781ae035 100644 --- a/include/matx/operators/polyval.h +++ b/include/matx/operators/polyval.h @@ -164,6 +164,13 @@ namespace matx return std::format("{}<{},{}>", get_jit_class_name(), op_jit_name, coeffs_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/r2c.h b/include/matx/operators/r2c.h index f22fcfdb..b64c050d 100644 --- a/include/matx/operators/r2c.h +++ b/include/matx/operators/r2c.h @@ -158,6 +158,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/remap.h b/include/matx/operators/remap.h index 93a8ec44..a54d5590 100644 --- a/include/matx/operators/remap.h +++ b/include/matx/operators/remap.h @@ -210,6 +210,13 @@ namespace matx return std::format("{}<{},{}>", get_jit_class_name(), op_jit_name, idx_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/repmat.h b/include/matx/operators/repmat.h index 8bc20e98..7f934e65 100644 --- a/include/matx/operators/repmat.h +++ b/include/matx/operators/repmat.h @@ -202,6 +202,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/reshape.h b/include/matx/operators/reshape.h index 1a0c67a6..f9bb2509 100644 --- a/include/matx/operators/reshape.h +++ b/include/matx/operators/reshape.h @@ -215,6 +215,13 @@ MATX_LOOP_UNROLL return get_jit_class_name() + "<" + op_jit_name + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, 
detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/reverse.h b/include/matx/operators/reverse.h index 54a84354..fca2b71a 100644 --- a/include/matx/operators/reverse.h +++ b/include/matx/operators/reverse.h @@ -203,6 +203,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/scalar_internal.h b/include/matx/operators/scalar_internal.h index 316834e5..2737183d 100644 --- a/include/matx/operators/scalar_internal.h +++ b/include/matx/operators/scalar_internal.h @@ -221,11 +221,11 @@ static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto scalar_internal_not(T template static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto scalar_internal_isnan(T v1) { using conversionType = typename matx::detail::value_promote_t; - if constexpr(!std::is_floating_point_v) { + if constexpr(!cuda::std::is_floating_point_v) { return false; } - using castType = matx::detail::matx_convert_complex_type; + using castType = matx::detail::matx_convert_cuda_complex_type; if constexpr(is_complex_v) { return cuda::std::isnan(static_cast(v1.real())) || cuda::std::isnan(static_cast(v1.imag())); } else { @@ -236,11 +236,11 @@ static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto scalar_internal_isnan( template static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto scalar_internal_isinf(T v1) { using conversionType = typename matx::detail::value_promote_t; - if constexpr(!std::is_floating_point_v) { + if constexpr(!cuda::std::is_floating_point_v) { return false; } - using castType = matx::detail::matx_convert_complex_type; + using castType = matx::detail::matx_convert_cuda_complex_type; if constexpr(is_complex_v) { return cuda::std::isinf(static_cast(v1.real())) || cuda::std::isinf(static_cast(v1.imag())); } else { diff --git a/include/matx/operators/scalar_ops.h b/include/matx/operators/scalar_ops.h index 553d6957..0f123733 100644 --- a/include/matx/operators/scalar_ops.h +++ b/include/matx/operators/scalar_ops.h @@ -121,6 +121,9 @@ namespace detail { else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { \ return get_jit_class_name() + "<" + detail::type_to_string() + ">"; \ } \ + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { \ + return true; \ + } \ else { \ return capability_attributes::default_value; \ } \ @@ -190,6 +193,9 @@ namespace detail { else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { \ return get_jit_class_name() + "<" + detail::type_to_string() + ">"; \ } \ + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { \ + return true; \ + } \ else { \ return capability_attributes::default_value; \ } \ @@ -267,6 +273,9 @@ namespace detail { else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { \ return get_jit_class_name() + "<" + detail::type_to_string() + "," + detail::type_to_string() + ">"; \ } \ + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { \ + return true; \ + } \ else { \ return capability_attributes::default_value; \ } \ @@ -338,6 +347,9 @@ namespace detail { else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { \ return get_jit_class_name() 
+ "<" + detail::type_to_string() + "," + detail::type_to_string() + ">"; \ } \ + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { \ + return true; \ + } \ else { \ return capability_attributes::default_value; \ } \ @@ -407,6 +419,9 @@ namespace detail { else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { \ return get_jit_class_name() + "<" + detail::type_to_string() + "," + detail::type_to_string() + ">"; \ } \ + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { \ + return true; \ + } \ else { \ return capability_attributes::default_value; \ } \ diff --git a/include/matx/operators/select.h b/include/matx/operators/select.h index 2e566358..5dccd6fc 100644 --- a/include/matx/operators/select.h +++ b/include/matx/operators/select.h @@ -190,6 +190,15 @@ namespace matx return std::format("{}<{},{}>", get_jit_class_name(), op_jit_name, idx_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(op_, in), + detail::get_operator_capability(idx_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/self.h b/include/matx/operators/self.h index 016feb8f..aa02f82f 100644 --- a/include/matx/operators/self.h +++ b/include/matx/operators/self.h @@ -160,6 +160,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/set.h b/include/matx/operators/set.h index 9d096ed4..acc52df9 100644 --- a/include/matx/operators/set.h +++ b/include/matx/operators/set.h @@ -171,14 +171,15 @@ class set : public BaseOp> { " mutable typename detail::inner_storage_or_self_t> out_;\n" + " mutable typename detail::inner_storage_or_self_t> op_;\n" + " template \n" + - " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const\n" + + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... 
indices) const -> \n" + + "remove_cvref_t(op_, indices...))>\n" + " {\n" + - " auto in_val = detail::get_value(op_, indices...);\n" + - " using out_type = decltype(out_.template operator()(indices...));\n" + - " using in_val_type = decltype(in_val);\n" + + " using in_val_type = remove_cvref_t(op_, indices...))>;\n" + " if ((threadIdx.x * static_cast(CapType::ept)) >= Size(Rank() - 1)) {\n" + - " return detail::GetJitSentinelValue();\n" + + " return in_val_type{};\n" + " }\n" + + " auto in_val = detail::get_value(op_, indices...);\n" + + " using out_type = decltype(out_.template operator()(indices...));\n" + " if (out_.Rank() == 0 || threadIdx.x < out_.Size(out_.Rank() - 1)) {\n" + " if constexpr (!is_vector_v && is_vector_v) {\n" + " Vector, static_cast(CapType::ept)> vec{in_val};\n" + @@ -249,7 +250,14 @@ class set : public BaseOp> { return ""; #endif } - else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in), detail::get_operator_capability(out_, in)); +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { #ifdef MATX_EN_JIT // Get the key/value pair from get_jit_op_str() const auto [key, value] = get_jit_op_str(); diff --git a/include/matx/operators/shift.h b/include/matx/operators/shift.h index 09444815..66285686 100644 --- a/include/matx/operators/shift.h +++ b/include/matx/operators/shift.h @@ -188,6 +188,13 @@ namespace matx return std::format("{}<{},{}>", get_jit_class_name(), op_jit_name, shift_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/sign.h b/include/matx/operators/sign.h index 10c218fe..52877034 100644 --- a/include/matx/operators/sign.h +++ b/include/matx/operators/sign.h @@ -179,6 +179,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/slice.h b/include/matx/operators/slice.h index 625110a8..1a4da2d7 100644 --- a/include/matx/operators/slice.h +++ b/include/matx/operators/slice.h @@ -99,7 +99,29 @@ namespace matx " typename detail::inner_storage_or_self_t> op_;\n" " StrideType strides_;\n" " template \n" - " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const {{ /* slice logic */ }}\n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... 
indices) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " cuda::std::array ind = starts_;\n" + " cuda::std::array inds{{indices...}};\n" + " MATX_LOOP_UNROLL\n" + " for (int32_t i = 0; i < OpRank_; i++) {{\n" + " MATX_LOOP_UNROLL\n" + " for(int32_t j = 0; j < DIM_; j++) {{\n" + " if(dims_[j] == i) {{\n" + " if constexpr (!cuda::std::is_same_v) {{\n" + " ind[i] = starts_[j] + inds[j] * strides_[i];\n" + " }}\n" + " else {{\n" + " ind[i] = starts_[j] + inds[j];\n" + " }}\n" + " }}\n" + " }}\n" + " }}\n" + " return get_value(op_, ind);\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return DIM_; }}\n" " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int32_t dim) const {{ return sizes_[dim]; }}\n" "}};\n", @@ -198,6 +220,13 @@ namespace matx return std::format("{}<{}>", get_jit_class_name(), op_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/sph2cart.h b/include/matx/operators/sph2cart.h index f6a60818..fce39a90 100644 --- a/include/matx/operators/sph2cart.h +++ b/include/matx/operators/sph2cart.h @@ -88,7 +88,22 @@ namespace matx " typename detail::inner_storage_or_self_t> phi_;\n" " typename detail::inner_storage_or_self_t> r_;\n" " template \n" - " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const {{ /* sph2cart logic */ }}\n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " auto theta = get_value(theta_, indices...);\n" + " auto phi = get_value(phi_, indices...);\n" + " auto r = get_value(r_, indices...);\n" + " if constexpr (WHICH_ == 0) {{\n" + " return r * (scalar_internal_cos(phi) * scalar_internal_cos(theta));\n" + " }} else if constexpr (WHICH_ == 1) {{\n" + " return r * (scalar_internal_cos(phi) * scalar_internal_sin(theta));\n" + " }} else {{\n" + " return r * scalar_internal_sin(phi);\n" + " }}\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return Rank_; }}\n" " constexpr __MATX_INLINE__ __MATX_DEVICE__ auto Size(int dim) const {{ return out_dims_[dim]; }}\n" "}};\n", @@ -189,6 +204,16 @@ namespace matx return std::format("{}<{},{},{}>", get_jit_class_name(), theta_jit_name, phi_jit_name, r_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(theta_, in), + detail::get_operator_capability(phi_, in), + detail::get_operator_capability(r_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/stack.h b/include/matx/operators/stack.h index 90b392fe..8e9d092c 100644 --- a/include/matx/operators/stack.h +++ b/include/matx/operators/stack.h @@ -34,7 +34,9 @@ #include "matx/core/type_utils.h" +#include "matx/core/utils.h" #include "matx/operators/base_operator.h" +#include namespace matx { @@ -73,6 +75,158 @@ namespace matx return get_str<-1>(); } +#ifdef MATX_EN_JIT + struct JIT_Storage { + cuda::std::tuple>...> ops_; + }; + + 
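
The ToJITStorage() helper below converts every stacked child operator into its JIT storage form in a single pass over the ops_ tuple. A minimal host-side sketch of that unpack/transform/re-pack pattern, using plain std:: types (to_storage here is a hypothetical stand-in for detail::to_jit_storage):

#include <string>
#include <tuple>

// Hypothetical per-element transform standing in for detail::to_jit_storage.
template <typename T>
std::string to_storage(const T& op) { return std::to_string(op); }

template <typename... Ts>
auto to_storage_all(const std::tuple<Ts...>& ops) {
  // std::apply unpacks the tuple; make_tuple re-packs the transformed results.
  return std::apply([](const auto&... unpacked) {
    return std::make_tuple(to_storage(unpacked)...);
  }, ops);
}

// to_storage_all(std::make_tuple(1, 2.5)) yields ("1", "2.500000").
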
JIT_Storage ToJITStorage() const { + return JIT_Storage{cuda::std::apply([](const auto&... ops) { + return cuda::std::make_tuple(detail::to_jit_storage(ops)...); + }, ops_)}; + } + + template + __MATX_INLINE__ std::string get_sizes_str() const { + if constexpr (I < sizeof...(Ts)) { + const auto& op = cuda::std::get(ops_); + std::string sizes = "op" + std::to_string(I) + "_"; + for (int d = 0; d < RANK; d++) { + sizes += std::to_string(op.Size(d)); + if (d < RANK - 1) sizes += "x"; + } + if constexpr (I < sizeof...(Ts) - 1) { + return sizes + "_" + get_sizes_str(); + } else { + return sizes; + } + } else { + return ""; + } + } + + __MATX_INLINE__ std::string get_jit_class_name() const { + return std::format("JITStack_axis{}_num{}_{}", axis_, sizeof...(Ts), get_sizes_str<0>()); + } + + template + __MATX_INLINE__ std::string get_jit_type_list() const { + if constexpr (I < sizeof...(Ts) - 1) { + return "typename T" + std::to_string(I) + ", " + get_jit_type_list(); + } else if constexpr (I == sizeof...(Ts) - 1) { + return "typename T" + std::to_string(I); + } else { + return ""; + } + } + + template + __MATX_INLINE__ std::string get_jit_storage_tuple_types() const { + if constexpr (I < sizeof...(Ts) - 1) { + return "typename detail::inner_storage_or_self_t>, " + get_jit_storage_tuple_types(); + } else if constexpr (I == sizeof...(Ts) - 1) { + return "typename detail::inner_storage_or_self_t>"; + } else { + return ""; + } + } + + __MATX_INLINE__ std::string get_jit_storage_tuple() const { + return "cuda::std::tuple<" + get_jit_storage_tuple_types<0>() + "> ops_;\n"; + } + + __MATX_INLINE__ auto get_jit_op_str() const { + std::string func_name = get_jit_class_name(); + cuda::std::array out_dims_; + for (int i = 0; i < RANK + 1; i++) { + out_dims_[i] = Size(i); + } + + return cuda::std::make_tuple( + func_name, + std::format("template <{}> struct {} {{\n" + " using value_type = typename T0::value_type;\n" + " using matxop = bool;\n" + " constexpr static int RANK_ = {};\n" + " constexpr static cuda::std::array sizes_ = {{ {} }};\n" + " constexpr static int axis_ = {};\n" + " {}" + " // Const GetVal\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ auto GetVal(index_t oidx, cuda::std::array& indices) const {{\n" + " if constexpr ( I == N ) {{\n" + " const auto &op = cuda::std::get<0>(ops_);\n" + " return get_value(op, indices);\n" + " }} else {{\n" + " if ( I < oidx ) {{\n" + " return GetVal(oidx, indices);\n" + " }} else {{\n" + " const auto &op = cuda::std::get(ops_);\n" + " return get_value(op, indices);\n" + " }}\n" + " }}\n" + " }}\n" + " // Non-const GetVal for lvalue assignments\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) GetVal(index_t oidx, cuda::std::array& indices) {{\n" + " if constexpr ( I == N ) {{\n" + " auto &op = cuda::std::get<0>(ops_);\n" + " return get_value(op, indices);\n" + " }} else {{\n" + " if ( I < oidx ) {{\n" + " return GetVal(oidx, indices);\n" + " }} else {{\n" + " auto &op = cuda::std::get(ops_);\n" + " return get_value(op, indices);\n" + " }}\n" + " }}\n" + " }}\n" + " // Const operator()\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... 
is) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " cuda::std::array indices{{is...}};\n" + " cuda::std::array indices_o;\n" + " index_t oidx = indices[axis_];\n" + " for(int i = 0; i < axis_; i++) {{\n" + " indices_o[i] = indices[i];\n" + " }}\n" + " for(int i = axis_; i < (int)indices_o.size(); i++) {{\n" + " indices_o[i] = indices[i+1];\n" + " }}\n" + " return GetVal(oidx, indices_o);\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" + " // Non-const operator()\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... is) {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " cuda::std::array indices{{is...}};\n" + " cuda::std::array indices_o;\n" + " index_t oidx = indices[axis_];\n" + " for(int i = 0; i < axis_; i++) {{\n" + " indices_o[i] = indices[i];\n" + " }}\n" + " for(int i = axis_; i < (int)indices_o.size(); i++) {{\n" + " indices_o[i] = indices[i+1];\n" + " }}\n" + " return GetVal(oidx, indices_o);\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" + " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return RANK_+1; }}\n" + " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{\n" + " return sizes_[dim];\n" + " }}\n" + "}};\n", + get_jit_type_list<0>(), func_name, RANK, detail::array_to_string(out_dims_), axis_, get_jit_storage_tuple(), sizeof...(Ts), sizeof...(Ts)) + ); + } +#endif + __MATX_INLINE__ StackOp(int axis, const Ts&... ts) : ops_(ts...), axis_(axis) { MATX_LOG_TRACE("{} constructor: axis={}, num_tensors={}", str(), axis, sizeof...(Ts)); @@ -196,7 +350,41 @@ namespace matx template __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType& in) const { - if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { + if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { +#ifdef MATX_EN_JIT + return get_jit_class_name() + "<" + get_jit_type_params<0>() + ">"; +#else + return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, get_combined_ops_capability(in, ops_)); +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { +#ifdef MATX_EN_JIT + // Get the key/value pair from get_jit_op_str() + const auto [key, value] = get_jit_op_str(); + + // Insert into the map if the key doesn't exist + if (in.find(key) == in.end()) { + in[key] = value; + } + + // Also handle child operators + cuda::std::apply([&in](const auto&... 
ops) { + (detail::get_operator_capability(ops, in), ...); + }, ops_); + + return true; +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { const auto my_cap = cuda::std::array{ElementsPerThread::ONE, ElementsPerThread::ONE}; return combine_capabilities(my_cap, get_combined_ops_capability(in, ops_)); } @@ -211,6 +399,23 @@ namespace matx } } +#ifdef MATX_EN_JIT + template + __MATX_INLINE__ std::string get_jit_type_params() const { + if constexpr (I < sizeof...(Ts)) { + VoidCapabilityType void_type{}; + auto type_name = detail::get_operator_capability(cuda::std::get(ops_), void_type); + if constexpr (I < sizeof...(Ts) - 1) { + return type_name + "," + get_jit_type_params(); + } else { + return type_name; + } + } else { + return ""; + } + } +#endif + private: cuda::std::tuple ...> ops_; index_t size_; diff --git a/include/matx/operators/sum.h b/include/matx/operators/sum.h index 6b0f7448..dc6aa37d 100644 --- a/include/matx/operators/sum.h +++ b/include/matx/operators/sum.h @@ -71,7 +71,7 @@ namespace detail { __MATX_HOST__ __MATX_INLINE__ auto Data() const noexcept { return ptr; } - __MATX_INLINE__ std::string get_capability_str(int EPT) const { + __MATX_INLINE__ std::string get_capability_str([[maybe_unused]] int EPT) const { return std::string("sum"); } diff --git a/include/matx/operators/toeplitz.h b/include/matx/operators/toeplitz.h index 6bbe3e85..7ea2877f 100644 --- a/include/matx/operators/toeplitz.h +++ b/include/matx/operators/toeplitz.h @@ -84,7 +84,18 @@ namespace matx " typename detail::inner_storage_or_self_t> op1_;\n" " typename detail::inner_storage_or_self_t> op2_;\n" " template \n" - " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(index_t i, index_t j) const {{ /* toeplitz logic */ }}\n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(index_t i, index_t j) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " if (j > i) {{\n" + " return get_value(op2_, j - i);\n" + " }}\n" + " else {{\n" + " return get_value(op1_, i - j);\n" + " }}\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return 2; }}\n" " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{ return (dim == 0) ? 
size1_ : size2_; }}\n" "}};\n", @@ -172,6 +183,15 @@ namespace matx return std::format("{}<{},{}>", get_jit_class_name(), op1_jit_name, op2_jit_name); #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, + detail::get_operator_capability(op1_, in), + detail::get_operator_capability(op2_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/unary_operators.h b/include/matx/operators/unary_operators.h index 0d1adc07..7b46e76b 100644 --- a/include/matx/operators/unary_operators.h +++ b/include/matx/operators/unary_operators.h @@ -167,6 +167,13 @@ namespace matx return get_jit_class_name() + "<" + lhs_jit_name + "," + op_jit_name + ">"; #else return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(in1_, in)); +#else + return false; #endif } else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { diff --git a/include/matx/operators/updownsample.h b/include/matx/operators/updownsample.h index fb910ff5..c7330b5f 100644 --- a/include/matx/operators/updownsample.h +++ b/include/matx/operators/updownsample.h @@ -157,7 +157,14 @@ namespace matx return ""; #endif } - else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, detail::get_operator_capability(op_, in)); +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { #ifdef MATX_EN_JIT const auto [key, value] = get_jit_op_str(); if (in.find(key) == in.end()) { diff --git a/include/matx/operators/zipvec.h b/include/matx/operators/zipvec.h index 9f31ac6c..d12a3d86 100644 --- a/include/matx/operators/zipvec.h +++ b/include/matx/operators/zipvec.h @@ -34,7 +34,9 @@ #include "matx/core/type_utils.h" +#include "matx/core/utils.h" #include "matx/operators/base_operator.h" +#include namespace matx { @@ -72,6 +74,142 @@ namespace matx return get_str<-1>(); } +#ifdef MATX_EN_JIT + struct JIT_Storage { + cuda::std::tuple>...> ops_; + }; + + JIT_Storage ToJITStorage() const { + return JIT_Storage{cuda::std::apply([](const auto&... 
ops) { + return cuda::std::make_tuple(detail::to_jit_storage(ops)...); + }, ops_)}; + } + + template + __MATX_INLINE__ std::string get_sizes_str() const { + if constexpr (I < sizeof...(Ts)) { + const auto& op = cuda::std::get(ops_); + std::string sizes = "op" + std::to_string(I) + "_"; + for (int d = 0; d < RANK; d++) { + sizes += std::to_string(op.Size(d)); + if (d < RANK - 1) sizes += "x"; + } + if constexpr (I < sizeof...(Ts) - 1) { + return sizes + "_" + get_sizes_str(); + } else { + return sizes; + } + } else { + return ""; + } + } + + __MATX_INLINE__ std::string get_jit_class_name() const { + return std::format("JITZipVec_num{}_{}", sizeof...(Ts), get_sizes_str<0>()); + } + + template + __MATX_INLINE__ std::string get_jit_type_list() const { + if constexpr (I < sizeof...(Ts) - 1) { + return "typename T" + std::to_string(I) + ", " + get_jit_type_list(); + } else if constexpr (I == sizeof...(Ts) - 1) { + return "typename T" + std::to_string(I); + } else { + return ""; + } + } + + template + __MATX_INLINE__ std::string get_jit_storage_tuple_types() const { + if constexpr (I < sizeof...(Ts) - 1) { + return "typename detail::inner_storage_or_self_t>, " + get_jit_storage_tuple_types(); + } else if constexpr (I == sizeof...(Ts) - 1) { + return "typename detail::inner_storage_or_self_t>"; + } else { + return ""; + } + } + + __MATX_INLINE__ std::string get_jit_storage_tuple() const { + return "cuda::std::tuple<" + get_jit_storage_tuple_types<0>() + "> ops_;\n"; + } + + template + __MATX_INLINE__ std::string get_jit_value_types() const { + if constexpr (I < sizeof...(Ts)) { + std::string type_str = "typename T" + std::to_string(I) + "::value_type"; + if constexpr (I < sizeof...(Ts) - 1) { + return type_str + ", " + get_jit_value_types(); + } else { + return type_str; + } + } else { + return ""; + } + } + + template + __MATX_INLINE__ std::string get_jit_operator_calls() const { + if constexpr (I < sizeof...(Ts)) { + std::string call = "static_cast(cuda::std::get<" + std::to_string(I) + ">(ops_).template operator()(cuda::std::forward(is)...))"; + if constexpr (I < sizeof...(Ts) - 1) { + return call + ", " + get_jit_operator_calls(); + } else { + return call; + } + } else { + return ""; + } + } + + __MATX_INLINE__ auto get_jit_op_str() const { + std::string func_name = get_jit_class_name(); + cuda::std::array out_dims_; + for (int i = 0; i < RANK; i++) { + out_dims_[i] = Size(i); + } + + std::string value_types = get_jit_value_types<0>(); + + return cuda::std::make_tuple( + func_name, + std::format("template <{}> struct {} {{\n" + " using value_type = AggregateToVecType<{}>;\n" + " using matxop = bool;\n" + " constexpr static int RANK_ = {};\n" + " constexpr static cuda::std::array sizes_ = {{ {} }};\n" + " {}" + " // Const operator()\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... is) const {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " using scalar_type = typename AggregateToVec<{}>::common_type;\n" + " return value_type{{ {} }};\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" + " // Non-const operator()\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... 
is) {{\n" + " if constexpr (CapType::ept == ElementsPerThread::ONE) {{\n" + " using scalar_type = typename AggregateToVec<{}>::common_type;\n" + " return value_type{{ {} }};\n" + " }} else {{\n" + " return Vector(CapType::ept)>{{}};\n" + " }}\n" + " }}\n" + " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank() {{ return RANK_; }}\n" + " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const {{\n" + " return sizes_[dim];\n" + " }}\n" + "}};\n", + get_jit_type_list<0>(), func_name, value_types, RANK, detail::array_to_string(out_dims_), get_jit_storage_tuple(), + value_types, get_jit_operator_calls<0>(), value_types, get_jit_operator_calls<0>()) + ); + } +#endif + __MATX_INLINE__ ZipVecOp(const Ts&... ts) : ops_(ts...) { MATX_LOG_TRACE("{} constructor: num_ops={}, rank={}", str(), sizeof...(Ts), Rank()); @@ -123,7 +261,41 @@ namespace matx template __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType &in) const { - if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { + if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { +#ifdef MATX_EN_JIT + return get_jit_class_name() + "<" + get_jit_type_params<0>() + ">"; +#else + return ""; +#endif + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { +#ifdef MATX_EN_JIT + return combine_capabilities(true, get_combined_ops_capability(in, ops_)); +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { +#ifdef MATX_EN_JIT + // Get the key/value pair from get_jit_op_str() + const auto [key, value] = get_jit_op_str(); + + // Insert into the map if the key doesn't exist + if (in.find(key) == in.end()) { + in[key] = value; + } + + // Also handle child operators + cuda::std::apply([&in](const auto&... ops) { + (detail::get_operator_capability(ops, in), ...); + }, ops_); + + return true; +#else + return false; +#endif + } + else if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { // For now, we do not support vectorization. We could support it, but it will require some // rework of the assumptions used in the matx::Vector class. 
const auto my_cap = cuda::std::array{ElementsPerThread::ONE, ElementsPerThread::ONE}; @@ -134,6 +306,23 @@ namespace matx } } +#ifdef MATX_EN_JIT + template + __MATX_INLINE__ std::string get_jit_type_params() const { + if constexpr (I < sizeof...(Ts)) { + VoidCapabilityType void_type{}; + auto type_name = detail::get_operator_capability(cuda::std::get(ops_), void_type); + if constexpr (I < sizeof...(Ts) - 1) { + return type_name + "," + get_jit_type_params(); + } else { + return type_name; + } + } else { + return ""; + } + } +#endif + static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() noexcept { return RANK; diff --git a/include/matx/transforms/fft/fft_cufftdx.h b/include/matx/transforms/fft/fft_cufftdx.h index 04909f99..e33fa9d7 100644 --- a/include/matx/transforms/fft/fft_cufftdx.h +++ b/include/matx/transforms/fft/fft_cufftdx.h @@ -50,6 +50,11 @@ namespace matx { namespace detail { + enum class cuFFTDxMethod { + REGISTER, + SHARED + }; + template class cuFFTDxHelper { @@ -61,6 +66,7 @@ namespace matx { int ffts_per_block_ = 1; int cc_; bool contiguous_input_; + cuFFTDxMethod method_; public: // Constructor cuFFTDxHelper() = default; @@ -73,7 +79,7 @@ namespace matx { int get_ffts_per_block() const { return ffts_per_block_; } int get_cc() const { return cc_; } bool get_contiguous_input() const { return contiguous_input_; } - + cuFFTDxMethod get_method() const { return method_; } // Setters void set_fft_size(index_t size) { fft_size_ = size; } @@ -83,13 +89,21 @@ namespace matx { void set_ffts_per_block(int ffts_per_block) { ffts_per_block_ = ffts_per_block; } void set_cc(int cc) { cc_ = cc; } void set_contiguous_input(bool contiguous_input) { contiguous_input_ = contiguous_input; } + void set_method(cuFFTDxMethod method) { method_ = method; } #if defined(MATX_EN_MATHDX) && defined(__CUDACC__) cufftdxDescriptor GeneratePlan() const { cufftdxDescriptor h_; LIBMATHDX_CHECK(cufftdxCreateDescriptor(&h_)); - LIBMATHDX_CHECK(cufftdxSetOperatorInt64(h_, CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_SMEM)); + + // if (fft_size_ <= 32) { + // LIBMATHDX_CHECK(cufftdxSetOperatorInt64(h_, CUFFTDX_OPERATOR_API, cufftdxApi::CUFFTDX_API_LMEM)); + // method_ = cuFFTDxMethod::REGISTER; + // } else { + LIBMATHDX_CHECK(cufftdxSetOperatorInt64(h_, CUFFTDX_OPERATOR_API, method_ == cuFFTDxMethod::REGISTER ? 
cufftdxApi::CUFFTDX_API_LMEM : cufftdxApi::CUFFTDX_API_SMEM)); + //} + + LIBMATHDX_CHECK( - cufftdxSetOperatorInt64(h_, CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK)); + cufftdxSetOperatorInt64(h_, CUFFTDX_OPERATOR_EXECUTION, commondxExecution::COMMONDX_EXECUTION_BLOCK)); LIBMATHDX_CHECK(cufftdxSetOperatorInt64(h_, CUFFTDX_OPERATOR_SIZE, fft_size_)); @@ -174,13 +188,39 @@ namespace matx { return static_cast(valid); } + template + bool CheckJITSizeAndTypeRequirements() const { + using OpInputType = typename OpType::value_type; + + // Only support power-of-2 FFT sizes for JIT support + if ((fft_size_ & (fft_size_ - 1)) != 0 || fft_size_ == 0) { + return false; + } + + // No half support in MatX for fusion yet + if constexpr (is_complex_half_v) { + return false; + } + + // Only support C2C for JIT support + if constexpr (!is_complex_v) { + return false; + } + + return true; + } + int GetShmRequired() const { auto handle = GeneratePlan(); long long int shared_memory_size = 0; LIBMATHDX_CHECK(cufftdxGetTraitInt64(handle, CUFFTDX_TRAIT_SHARED_MEMORY_SIZE, &shared_memory_size)); MATX_LOG_DEBUG("Shared memory size from cuFFTDx: {}", shared_memory_size); - shared_memory_size = static_cast(current_elements_per_thread_) * sizeof(InputType) * static_cast(ffts_per_block_) + shared_memory_size; + if (method_ == cuFFTDxMethod::SHARED) { + // Add in the input/output shm + shared_memory_size = static_cast(fft_size_) * sizeof(InputType) * static_cast(ffts_per_block_) + shared_memory_size; + } + MATX_LOG_DEBUG("Shared memory size computed: {}", shared_memory_size); return static_cast(shared_memory_size); } @@ -201,6 +241,9 @@ namespace matx { return my_cap; } + for (size_t i = 0; i < epts.size(); ++i) { + MATX_LOG_DEBUG("cuFFTDx EPT[{}]: {}", i, epts[i]); + } return cuda::std::array{static_cast(*std::min_element(epts.begin(), epts.end())), static_cast(*std::max_element(epts.begin(), epts.end()))}; } @@ -223,7 +266,7 @@ namespace matx { LIBMATHDX_CHECK(cufftdxGetTraitInt64(handle, CUFFTDX_TRAIT_SUGGESTED_FFTS_PER_BLOCK, &sfpb)); - MATX_LOG_DEBUG("Getting FFTs per block {} elements per thread {}", sfpb, static_cast(current_elements_per_thread_)); + MATX_LOG_DEBUG("Getting FFTs per block {} elements per thread {} and fft_size {}", sfpb, static_cast(current_elements_per_thread_), fft_size_); return static_cast(sfpb); } @@ -321,6 +364,9 @@ namespace matx { result += R"(; [[maybe_unused]] static constexpr int contiguous_input = )"; result += std::to_string(contiguous_input_); + // result += R"(; + // [[maybe_unused]] static constexpr bool register_api = )"; + // result += method_ == cuFFTDxMethod::REGISTER ? 
"true" : "false"; result += R"(; const int local_fft_id = threadIdx.y; @@ -333,50 +379,51 @@ namespace matx { // using BlockLoadToShm = cub::detail::BlockLoadToShared; // using TempStorage = BlockLoadToShm::TempStorage; using VecType = Vector(CapType::ept)>; - constexpr size_t to_copy = sizeof(input_type) * static_cast(CapType::ept) * static_cast(total_threads_per_block); + //constexpr size_t to_copy = sizeof(input_type) * static_cast(CapType::ept) * static_cast(total_threads_per_block); // constexpr int buff_align = BlockLoadToShm::template SharedBufferAlignBytes(); // constexpr int buff_size = BlockLoadToShm::template SharedBufferSizeBytes(total_threads_per_block); extern __shared__ VecType thread_data[]; - if constexpr (contiguous_input) { - // __shared__ TempStorage temp_storage; - // BlockLoadToShm load_to_shared(temp_storage); - // cuda::std::span gmem_src(reinterpret_cast(a_.template data_ptr(blockIdx.x, blockDim.x * blockDim.y)), total_threads_per_block); - // cuda::std::span smem_dst_buff(thread_data, total_threads_per_block); - // auto smem_dst = load_to_shared.CopyAsync(smem_dst_buff, gmem_src); - // load_to_shared.Commit(); - // load_to_shared.Wait(); - cuda::barrier bar; - init(&bar, 1); - cuda::memcpy_async(thread_data, reinterpret_cast(a_.template data_ptr(blockIdx.x, blockDim.x * blockDim.y)), to_copy, bar); - bar.arrive_and_wait(); - } - else { + // if constexpr (contiguous_input) { + // // __shared__ TempStorage temp_storage; + // // BlockLoadToShm load_to_shared(temp_storage); + // // cuda::std::span gmem_src(reinterpret_cast(a_.template data_ptr(blockIdx.x, blockDim.x * blockDim.y)), total_threads_per_block); + // // cuda::std::span smem_dst_buff(thread_data, total_threads_per_block); + // // auto smem_dst = load_to_shared.CopyAsync(smem_dst_buff, gmem_src); + // // load_to_shared.Commit(); + // // load_to_shared.Wait(); + // cuda::barrier bar; + // init(&bar, 1); + // cuda::memcpy_async(thread_data, reinterpret_cast(a_.template data_ptr(blockIdx.x, blockDim.x * blockDim.y)), to_copy, bar); + // bar.arrive_and_wait(); + // } + // else { thread_data[local_fft_id * blockDim.x + threadIdx.x] = a_.template operator()(indices...); - __syncthreads(); - } + __syncthreads(); + //} )"; result += fft_func_name; - result += R"((reinterpret_cast(&thread_data[local_fft_id * blockDim.x])); - - if constexpr (fft_norm == 2) { // ORTHO - #pragma unroll - for (int i = 0; i < static_cast(CapType::ept); i++) { - thread_data[local_fft_id * blockDim.x + threadIdx.x].data[i] = thread_data[local_fft_id * blockDim.x + threadIdx.x].data[i] * static_cast(1.f) / static_cast(cuda::std::sqrt(fft_size)); - } + result += R"((reinterpret_cast(&thread_data[0])); + __syncthreads(); + + if constexpr (fft_norm == 2) { // ORTHO + #pragma unroll + for (int i = 0; i < static_cast(CapType::ept); i++) { + thread_data[local_fft_id * blockDim.x + threadIdx.x].data[i] = thread_data[local_fft_id * blockDim.x + threadIdx.x].data[i] * static_cast(1.f) / static_cast(cuda::std::sqrt(fft_size)); } - else if constexpr ((fft_norm == 1 && fft_forward) || (fft_norm == 0 && !fft_forward)) { - #pragma unroll - for (int i = 0; i < static_cast(CapType::ept); i++) { - thread_data[local_fft_id * blockDim.x + threadIdx.x].data[i] = thread_data[local_fft_id * blockDim.x + threadIdx.x].data[i] * static_cast(1.f) / static_cast(fft_size); - } + } + else if constexpr ((fft_norm == 1 && fft_forward) || (fft_norm == 0 && !fft_forward)) { + #pragma unroll + for (int i = 0; i < static_cast(CapType::ept); i++) { + thread_data[local_fft_id * 
blockDim.x + threadIdx.x].data[i] = thread_data[local_fft_id * blockDim.x + threadIdx.x].data[i] * static_cast(1.f) / static_cast(fft_size); } + } - return thread_data[local_fft_id * blockDim.x + threadIdx.x]; + return thread_data[local_fft_id * blockDim.x + threadIdx.x]; )"; return result; diff --git a/test/00_operators/CMakeLists.txt b/test/00_operators/CMakeLists.txt index 463c160d..8a995395 100644 --- a/test/00_operators/CMakeLists.txt +++ b/test/00_operators/CMakeLists.txt @@ -25,12 +25,21 @@ set(OPERATOR_TEST_FILES frexp_test.cu frexpc_test.cu get_string_test.cu + ifelse_test.cu interleaved_test.cu interp_test.cu isclose_test.cu isnaninf_test.cu legendre_test.cu - operator_func_test.cu + operator_func_boolean_test.cu + operator_func_complex_test.cu + operator_func_div_complex_test.cu + operator_func_float_noncomplex_test.cu + operator_func_integral_test.cu + operator_func_nd_test.cu + operator_func_numeric_noncomplex_test.cu + operator_func_numeric_test.cu + operator_func_r2c_test.cu overlap_test.cu pad_test.cu permute_test.cu diff --git a/test/00_operators/at_test.cu b/test/00_operators/at_test.cu index f04a608f..0e5ab6a4 100644 --- a/test/00_operators/at_test.cu +++ b/test/00_operators/at_test.cu @@ -36,7 +36,7 @@ TYPED_TEST(OperatorTestsNumericNonComplexAllExecs, AtOp) ASSERT_EQ(t0(), t2(1, 4)); - if constexpr (is_cuda_executor_v && (std::is_same_v || std::is_same_v)) { + if constexpr (is_cuda_non_jit_executor_v && (std::is_same_v || std::is_same_v)) { using ComplexType = detail::complex_from_scalar_t; auto c0 = make_tensor({}); (c0 = at(fft(t1), 0)).run(exec); diff --git a/test/00_operators/broadcast_test.cu b/test/00_operators/broadcast_test.cu index c6a63d66..7e074f58 100644 --- a/test/00_operators/broadcast_test.cu +++ b/test/00_operators/broadcast_test.cu @@ -17,12 +17,12 @@ TYPED_TEST(OperatorTestsNumericAllExecs, Broadcast) { auto t0 = make_tensor({}); + t0() = (TestType)2.0f; tensor_t t4i({10, 20, 30, 40}); tensor_t t4o({10, 20, 30, 40}); (t4o = t0).run(exec); exec.sync(); - t0() = (TestType)2.0f; for (index_t i = 0; i < t4i.Size(0); i++) { for (index_t j = 0; j < t4i.Size(1); j++) { for (index_t k = 0; k < t4i.Size(2); k++) { diff --git a/test/00_operators/clone_test.cu b/test/00_operators/clone_test.cu index 7324e031..81822e77 100644 --- a/test/00_operators/clone_test.cu +++ b/test/00_operators/clone_test.cu @@ -82,9 +82,6 @@ TYPED_TEST(OperatorTestsNumericAllExecs, CloneOp) (tov = op).run(exec); exec.sync(); - print(op); - print(tov); - print(tiv); for(int n = 0; n < N; n++) { for(int m = 0; m < M; m++) { @@ -220,14 +217,16 @@ TYPED_TEST(OperatorTestsNumericAllExecs, CloneOp) exec.sync(); - (tov = clone<3>(conv2d(tiv, delta, MATX_C_MODE_SAME), {N, matxKeepDim, matxKeepDim})).run(exec); + if (jit_supported(conv2d(tiv, delta, MATX_C_MODE_SAME))) { + (tov = clone<3>(conv2d(tiv, delta, MATX_C_MODE_SAME), {N, matxKeepDim, matxKeepDim})).run(exec); - exec.sync(); + exec.sync(); - for(int n = 0; n < N; n++) { - for(int m = 0; m < M; m++) { - for(int k = 0; k < K; k++) { - ASSERT_EQ(tov(n,m,k) , tiv(m,k)); + for(int n = 0; n < N; n++) { + for(int m = 0; m < M; m++) { + for(int k = 0; k < K; k++) { + ASSERT_EQ(tov(n,m,k) , tiv(m,k)); + } } } } diff --git a/test/00_operators/collapse_test.cu b/test/00_operators/collapse_test.cu index 92ac409f..4ba7e324 100644 --- a/test/00_operators/collapse_test.cu +++ b/test/00_operators/collapse_test.cu @@ -123,7 +123,7 @@ TYPED_TEST(OperatorTestsNumericAllExecs, CollapseOp) } } - if constexpr (is_cuda_executor_v && (std::is_same_v || 
std::is_same_v)) + if constexpr (is_cuda_non_jit_executor && (std::is_same_v || std::is_same_v)) { // rcollapse with nested transform operator auto tov = make_tensor({N,M*K}); auto delta = make_tensor({1,1}); @@ -148,7 +148,7 @@ TYPED_TEST(OperatorTestsNumericAllExecs, CollapseOp) } } - if constexpr (is_cuda_executor_v && (std::is_same_v || std::is_same_v)) + if constexpr (is_cuda_non_jit_executor && (std::is_same_v || std::is_same_v)) { // lcollapse with nested transform operator auto tov = make_tensor({N*M,K}); auto delta = make_tensor({1,1}); diff --git a/test/00_operators/concat_test.cu b/test/00_operators/concat_test.cu index 910d6f54..beb82e7b 100644 --- a/test/00_operators/concat_test.cu +++ b/test/00_operators/concat_test.cu @@ -40,7 +40,7 @@ TYPED_TEST(OperatorTestsFloatNonComplexAllExecs, Concatenate) } // Test concat with nested transforms - if constexpr (is_cuda_executor_v && (std::is_same_v || std::is_same_v)) { + if constexpr (is_cuda_non_jit_executor && (std::is_same_v || std::is_same_v)) { auto delta = make_tensor({1}); delta.SetVals({1.0}); diff --git a/test/00_operators/fftshift_test.cu b/test/00_operators/fftshift_test.cu index c537a4ea..c984ab3c 100644 --- a/test/00_operators/fftshift_test.cu +++ b/test/00_operators/fftshift_test.cu @@ -33,6 +33,15 @@ TYPED_TEST(OperatorTestsFloatNonHalf, FFTShiftWithTransform) { const int N1 = 3; const int N2 = 4; + + // Skip if using JIT executor but expression doesn't support JIT + if constexpr (is_cuda_jit_executor_v) { + auto test_t3 = make_tensor({N1}); + auto test_expr = fftshift1D(fft(test_t3)); + if (!jit_supported(test_expr)) { + GTEST_SKIP(); + } + } auto t3 = make_tensor({N1}); auto t4 = make_tensor({N2}); @@ -95,6 +104,15 @@ TYPED_TEST(OperatorTestsFloatNonHalf, FFTShiftWithTransform) if constexpr (is_complex_v) { [[maybe_unused]] const int N = 4; + // Skip if using JIT executor but expression doesn't support JIT + if constexpr (is_cuda_jit_executor_v) { + auto test_x = make_tensor({N,N}); + auto test_expr = fftshift2D(fft2(test_x)); + if (!jit_supported(test_expr)) { + GTEST_SKIP(); + } + } + auto x = make_tensor({N,N}); auto X = make_tensor({N,N}); diff --git a/test/00_operators/ifelse_test.cu b/test/00_operators/ifelse_test.cu new file mode 100644 index 00000000..0d3f3178 --- /dev/null +++ b/test/00_operators/ifelse_test.cu @@ -0,0 +1,47 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + +TYPED_TEST(OperatorTestsAllExecs, OperatorFuncs) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + auto tiv0 = make_tensor({}); + auto tov0 = make_tensor({}); + + TestType c = GenerateData(); + TestType d = c; + TestType z = 0; + tiv0() = c; + + auto tov00 = make_tensor({}); + + // example-begin IFELSE-test-1 + IFELSE(tiv0 == d, tov0 = z, tov0 = d).run(exec); + // example-end IFELSE-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), z)); + + IFELSE(tiv0 == d, tov0 = tiv0, tov0 = d).run(exec); + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), tiv0())); + + IFELSE(tiv0 != d, tov0 = d, tov0 = z).run(exec); + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), z)); + + (tov0 = c, tov00 = c).run(exec); + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c)); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov00(), c)); + + MATX_EXIT_HANDLER(); +} + diff --git 
a/test/00_operators/isclose_test.cu b/test/00_operators/isclose_test.cu index 22d747dc..88245c77 100644 --- a/test/00_operators/isclose_test.cu +++ b/test/00_operators/isclose_test.cu @@ -6,6 +6,7 @@ using namespace matx; using namespace matx::test; +// assigning ones() with a half precision type is not constexpr TYPED_TEST(OperatorTestsFloatAllExecs, IsClose) { MATX_ENTER_HANDLER(); diff --git a/test/00_operators/legendre_test.cu b/test/00_operators/legendre_test.cu index 28e23b8c..9cb8c561 100644 --- a/test/00_operators/legendre_test.cu +++ b/test/00_operators/legendre_test.cu @@ -40,6 +40,7 @@ TypeParam legendre_check(int n, int m, TypeParam x) { return p2; } +// No JIT until constexpr half is fixed TYPED_TEST(OperatorTestsFloatNonComplexAllExecs, Legendre) { MATX_ENTER_HANDLER(); diff --git a/test/00_operators/operator_func_boolean_test.cu b/test/00_operators/operator_func_boolean_test.cu new file mode 100644 index 00000000..9459065a --- /dev/null +++ b/test/00_operators/operator_func_boolean_test.cu @@ -0,0 +1,61 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + +TYPED_TEST(OperatorTestsBooleanAllExecs, OperatorFuncs) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + auto tiv0 = make_tensor({}); + auto tov0 = make_tensor({}); + + TestType c = GenerateData(); + TestType d = false; + tiv0() = c; + + // example-begin land-test-1 + (tov0 = tiv0 && d).run(exec); + // example-end land-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c && d)); + + // example-begin lor-test-1 + (tov0 = tiv0 || d).run(exec); + // example-end lor-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c || d)); + + // example-begin lnot-test-1 + (tov0 = !tiv0).run(exec); + // example-end lnot-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), !c)); + + // example-begin xor-test-1 + (tov0 = tiv0 ^ d).run(exec); + // example-end xor-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c ^ d)); + + // example-begin or-test-1 + (tov0 = tiv0 | d).run(exec); + // example-end or-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c | d)); + + // example-begin and-test-1 + (tov0 = tiv0 & d).run(exec); + // example-end and-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c & d)); + + MATX_EXIT_HANDLER(); +} + diff --git a/test/00_operators/operator_func_complex_test.cu b/test/00_operators/operator_func_complex_test.cu new file mode 100644 index 00000000..7a91ba39 --- /dev/null +++ b/test/00_operators/operator_func_complex_test.cu @@ -0,0 +1,46 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + +TYPED_TEST(OperatorTestsComplexTypesAllExecs, OperatorFuncs) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + + auto tiv0 = make_tensor({}); + auto tov0 = make_tensor({}); + + TestType c = GenerateData(); + tiv0() = c; + + // example-begin exp-test-1 + (tov0 = exp(tiv0)).run(exec); + // example-end exp-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_exp(c))); + + // example-begin conj-test-1 + (tov0 = 
conj(tiv0)).run(exec); + // example-end conj-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_conj(c))); + + // abs takes a complex and output a floating point value + auto tdd0 = make_tensor({}); + + // example-begin abs-test-1 + (tdd0 = abs(tiv0)).run(exec); + // example-end abs-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tdd0(), detail::scalar_internal_abs(c))); + + MATX_EXIT_HANDLER(); +} + diff --git a/test/00_operators/operator_func_div_complex_test.cu b/test/00_operators/operator_func_div_complex_test.cu new file mode 100644 index 00000000..3fd8a986 --- /dev/null +++ b/test/00_operators/operator_func_div_complex_test.cu @@ -0,0 +1,29 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + +TYPED_TEST(OperatorTestsComplexTypesAllExecs, OperatorFuncDivComplex) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + auto tiv0 = make_tensor({}); + auto tov0 = make_tensor({}); + typename TestType::value_type s = 5.0; + + TestType c = GenerateData(); + tiv0() = c; + + (tov0 = s / tiv0).run(exec); + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), s / tiv0())); + + MATX_EXIT_HANDLER(); +} + diff --git a/test/00_operators/operator_func_float_noncomplex_test.cu b/test/00_operators/operator_func_float_noncomplex_test.cu new file mode 100644 index 00000000..b1955d6a --- /dev/null +++ b/test/00_operators/operator_func_float_noncomplex_test.cu @@ -0,0 +1,72 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + +TYPED_TEST(OperatorTestsFloatNonComplexAllExecs, OperatorFuncs) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + auto tiv0 = make_tensor({}); + auto tov0 = make_tensor({}); + + TestType c = GenerateData(); + tiv0() = c; + + // example-begin log10-test-1 + (tov0 = log10(tiv0)).run(exec); + // example-end log10-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_log10(c))); + + // example-begin log-test-1 + (tov0 = log(tiv0)).run(exec); + // example-end log-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_log(c))); + + // example-begin log2-test-1 + (tov0 = log2(tiv0)).run(exec); + // example-end log2-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_log2(c))); + + // example-begin floor-test-1 + (tov0 = floor(tiv0)).run(exec); + // example-end floor-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_floor(c))); + + // example-begin ceil-test-1 + (tov0 = ceil(tiv0)).run(exec); + // example-end ceil-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_ceil(c))); + + // example-begin round-test-1 + (tov0 = round(tiv0)).run(exec); + // example-end round-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_round(c))); + + // example-begin sqrt-test-1 + (tov0 = sqrt(tiv0)).run(exec); + // example-end sqrt-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_sqrt(c))); + + // 
example-begin rsqrt-test-1 + (tov0 = rsqrt(tiv0)).run(exec); + // example-end rsqrt-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_rsqrt(c))); + + MATX_EXIT_HANDLER(); +} + diff --git a/test/00_operators/operator_func_integral_test.cu b/test/00_operators/operator_func_integral_test.cu new file mode 100644 index 00000000..53d90d4e --- /dev/null +++ b/test/00_operators/operator_func_integral_test.cu @@ -0,0 +1,31 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + +TYPED_TEST(OperatorTestsIntegralAllExecs, OperatorFuncs) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + auto tiv0 = make_tensor({}); + auto tov0 = make_tensor({}); + + TestType c = GenerateData(); + tiv0() = c; + TestType mod = 2; + + // example-begin mod-test-1 + (tov0 = tiv0 % mod).run(exec); + // example-end mod-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c % mod)); + + MATX_EXIT_HANDLER(); +} + diff --git a/test/00_operators/operator_func_nd_test.cu b/test/00_operators/operator_func_nd_test.cu new file mode 100644 index 00000000..cdf84795 --- /dev/null +++ b/test/00_operators/operator_func_nd_test.cu @@ -0,0 +1,35 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + +TYPED_TEST(OperatorTestsFloatNonComplexAllExecs, NDOperatorFuncs) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + + auto a = make_tensor({1,2,3,4,5}); + auto b = make_tensor({1,2,3,4,5}); + (a = ones(a.Shape())).run(exec); + exec.sync(); + (b = ones(b.Shape())).run(exec); + exec.sync(); + (a = a + b).run(exec); + + { + if constexpr (is_cuda_non_jit_executor) { + auto t0 = make_tensor({}); + (t0 = sum(a)).run(exec); + exec.sync(); + ASSERT_EQ(t0(), static_cast(2 * a.TotalSize())); + } + } + MATX_EXIT_HANDLER(); +} + diff --git a/test/00_operators/operator_func_numeric_noncomplex_test.cu b/test/00_operators/operator_func_numeric_noncomplex_test.cu new file mode 100644 index 00000000..cc8afc21 --- /dev/null +++ b/test/00_operators/operator_func_numeric_noncomplex_test.cu @@ -0,0 +1,76 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + +TYPED_TEST(OperatorTestsNumericNonComplexAllExecs, OperatorFuncs) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + auto tiv0 = make_tensor({}); + auto tov0 = make_tensor({}); + + TestType c = GenerateData(); + tiv0() = c; + TestType d = c + 1; + + // example-begin max-el-test-1 + (tov0 = max(tiv0, d)).run(exec); + // example-end max-el-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), std::max(c, d))); + + // example-begin min-el-test-1 + (tov0 = min(tiv0, d)).run(exec); + // example-end min-el-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), std::min(c, d))); + + // These operators convert type T into bool + auto tob = make_tensor({}); + + // example-begin lt-test-1 + (tob = tiv0 < d).run(exec); + // example-end lt-test-1 + 
exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tob(), c < d)); + + // example-begin gt-test-1 + (tob = tiv0 > d).run(exec); + // example-end gt-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tob(), c > d)); + + // example-begin lte-test-1 + (tob = tiv0 <= d).run(exec); + // example-end lte-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tob(), c <= d)); + + // example-begin gte-test-1 + (tob = tiv0 >= d).run(exec); + // example-end gte-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tob(), c >= d)); + + // example-begin eq-test-1 + (tob = tiv0 == d).run(exec); + // example-end eq-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tob(), c == d)); + + // example-begin neq-test-1 + (tob = tiv0 != d).run(exec); + // example-end neq-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tob(), c != d)); + + MATX_EXIT_HANDLER(); +} + diff --git a/test/00_operators/operator_func_numeric_test.cu b/test/00_operators/operator_func_numeric_test.cu new file mode 100644 index 00000000..b79ce157 --- /dev/null +++ b/test/00_operators/operator_func_numeric_test.cu @@ -0,0 +1,77 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + +TYPED_TEST(OperatorTestsNumericAllExecs, OperatorFuncs) +{ + MATX_ENTER_HANDLER(); + using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; + + ExecType exec{}; + auto tiv0 = make_tensor({}); + auto tov0 = make_tensor({}); + + TestType c = GenerateData(); + tiv0() = c; + + // example-begin add-test-1 + (tov0 = tiv0 + tiv0).run(exec); + // example-end add-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c + c)); + + // example-begin sub-test-1 + (tov0 = tiv0 - tiv0).run(exec); + // example-end sub-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c - c)); + + // example-begin mul-test-1 + (tov0 = tiv0 * tiv0).run(exec); + // example-end mul-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c * c)); + + // example-begin div-test-1 + (tov0 = tiv0 / tiv0).run(exec); + // example-end div-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c / c)); + + // example-begin neg-test-1 + (tov0 = -tiv0).run(exec); + // example-end neg-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), -c)); + + // example-begin IF-test-1 + IF(tiv0 == tiv0, tov0 = c).run(exec); + // example-end IF-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), c)); + + TestType p = 2.0f; + // example-begin pow-test-1 + (tov0 = as_type(pow(tiv0, p))).run(exec); + // example-end pow-test-1 + exec.sync(); + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::scalar_internal_pow(c, p))); + + TestType three = 3.0f; + + (tov0 = tiv0 * tiv0 * (tiv0 + tiv0) / tiv0 + three).run(exec); + exec.sync(); + + TestType res; + res = c * c * (c + c) / c + three; + EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), res, 0.07)); + + + MATX_EXIT_HANDLER(); +} + diff --git a/test/00_operators/operator_func_r2c_test.cu b/test/00_operators/operator_func_r2c_test.cu new file mode 100644 index 00000000..c40372b8 --- /dev/null +++ b/test/00_operators/operator_func_r2c_test.cu @@ -0,0 +1,49 @@ +#include "operator_test_types.hpp" +#include "matx.h" +#include "test_types.h" +#include "utilities.h" + +using namespace matx; +using namespace matx::test; + 
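
The R2C test below exercises expj(), which maps a real phase x to the unit-magnitude complex value cos(x) + j*sin(x) per Euler's formula. A standalone sketch of the identity the assertions rely on, written against std::complex rather than the MatX operator:

#include <cassert>
#include <cmath>
#include <complex>

int main() {
  const double x = M_PI / 2.0;  // M_PI assumes a POSIX-style <cmath>
  const std::complex<double> expj_x(std::cos(x), std::sin(x));
  // For x = pi/2 the value is approximately 0 + 1j, matching the first
  // assertion in the test that follows.
  assert(std::abs(expj_x.real()) < 1e-12);
  assert(std::abs(expj_x.imag() - 1.0) < 1e-12);
  return 0;
}
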
diff --git a/test/00_operators/operator_func_r2c_test.cu b/test/00_operators/operator_func_r2c_test.cu
new file mode 100644
index 00000000..c40372b8
--- /dev/null
+++ b/test/00_operators/operator_func_r2c_test.cu
@@ -0,0 +1,49 @@
+#include "operator_test_types.hpp"
+#include "matx.h"
+#include "test_types.h"
+#include "utilities.h"
+
+using namespace matx;
+using namespace matx::test;
+
+TYPED_TEST(OperatorTestsFloatNonComplexAllExecs, OperatorFuncsR2C)
+{
+  MATX_ENTER_HANDLER();
+  using TestType = cuda::std::tuple_element_t<0, TypeParam>;
+  using ExecType = cuda::std::tuple_element_t<1, TypeParam>;
+
+  ExecType exec{};
+
+  auto tiv0 = make_tensor<TestType>({});
+  auto tov0 = make_tensor<typename detail::complex_from_scalar_t<TestType>>({});
+  // example-begin expj-test-1
+  // TestType is float, double, bf16, etc.
+  tiv0() = static_cast<TestType>(M_PI/2.0);
+  (tov0 = expj(tiv0)).run(exec);
+  // tov0 is complex with value 0 + 1j
+  // example-end expj-test-1
+
+  exec.sync();
+  EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::complex_from_scalar_t<TestType>(0.0, 1.0)));
+
+  tiv0() = static_cast<TestType>(-1.0 * M_PI);
+  (tov0 = expj(tiv0)).run(exec);
+  exec.sync();
+  EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::complex_from_scalar_t<TestType>(-1.0, 0.0)));
+
+  tiv0() = 0;
+  (tov0 = expj(tiv0)).run(exec);
+  exec.sync();
+  EXPECT_TRUE(MatXUtils::MatXTypeCompare(tov0(), detail::complex_from_scalar_t<TestType>(1.0, 0.0)));
+
+  TestType c = GenerateData<TestType>();
+  tiv0() = c;
+  (tov0 = expj(tiv0)).run(exec);
+  exec.sync();
+
+  EXPECT_TRUE(MatXUtils::MatXTypeCompare(
+      tov0(),
+      typename detail::complex_from_scalar_t<TestType>(detail::scalar_internal_cos(tiv0()), detail::scalar_internal_sin(tiv0()))));
+  MATX_EXIT_HANDLER();
+}
+
diff --git a/test/00_operators/operator_test_types.hpp b/test/00_operators/operator_test_types.hpp
index d5bc57dc..02b6b0b9 100644
--- a/test/00_operators/operator_test_types.hpp
+++ b/test/00_operators/operator_test_types.hpp
@@ -82,27 +82,120 @@ class OperatorTestsBooleanAllExecs : public ::testing::Test {};
 
 template <typename TensorType>
 class OperatorTestsCastToFloatAllExecs : public ::testing::Test {};
 
+// Operator-specific type aliases using ExecutorTypesAllWithJIT instead of ExecutorTypesAll
+using MatXFloatNonHalfTypesAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXNumericNonComplexTypesAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXFloatNonComplexNonHalfTypesAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXTypesFloatNonComplexAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXTypesNumericAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXNumericNoHalfTypesAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXComplexNonHalfTypesAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXComplexTypesAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXAllTypesAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXTypesFloatAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXTypesIntegralAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXTypesBooleanAllExecsWithJIT = TupleToTypes::type>::type;
+using MatXTypesCastToFloatAllExecsWithJIT = TupleToTypes::type>::type;
+
 TYPED_TEST_SUITE(OperatorTestsFloatNonHalf,
-                 MatXFloatNonHalfTypesAllExecs);
+                 MatXFloatNonHalfTypesAllExecsWithJIT);
 TYPED_TEST_SUITE(OperatorTestsNumericNonComplexAllExecs,
-                 MatXNumericNonComplexTypesAllExecs);
+                 MatXNumericNonComplexTypesAllExecsWithJIT);
 TYPED_TEST_SUITE(OperatorTestsFloatNonComplexNonHalfAllExecs,
-                 MatXFloatNonComplexNonHalfTypesAllExecs);
+                 MatXFloatNonComplexNonHalfTypesAllExecsWithJIT);
 TYPED_TEST_SUITE(OperatorTestsFloatNonComplexAllExecs,
-                 MatXTypesFloatNonComplexAllExecs);
+                 MatXTypesFloatNonComplexAllExecsWithJIT);
 TYPED_TEST_SUITE(OperatorTestsFloatNonComplexSingleThreadedHostAllExecs,
                  MatXTypesFloatNonComplexSingleThreadedHostAllExecs);
 TYPED_TEST_SUITE(OperatorTestsNumericAllExecs,
-                 MatXTypesNumericAllExecs);
-TYPED_TEST_SUITE(OperatorTestsNumericNoHalfAllExecs, MatXNumericNoHalfTypesAllExecs);
-TYPED_TEST_SUITE(OperatorTestsComplexNonHalfTypesAllExecs, MatXComplexNonHalfTypesAllExecs);
-TYPED_TEST_SUITE(OperatorTestsComplexTypesAllExecs, MatXComplexTypesAllExecs);
-TYPED_TEST_SUITE(OperatorTestsAllExecs, MatXAllTypesAllExecs);
-TYPED_TEST_SUITE(OperatorTestsFloatAllExecs, MatXTypesFloatAllExecs);
-TYPED_TEST_SUITE(OperatorTestsIntegralAllExecs, MatXTypesIntegralAllExecs);
-TYPED_TEST_SUITE(OperatorTestsBooleanAllExecs, MatXTypesBooleanAllExecs);
-TYPED_TEST_SUITE(OperatorTestsCastToFloatAllExecs, MatXTypesCastToFloatAllExecs);
+                 MatXTypesNumericAllExecsWithJIT);
+TYPED_TEST_SUITE(OperatorTestsNumericNoHalfAllExecs, MatXNumericNoHalfTypesAllExecsWithJIT);
+TYPED_TEST_SUITE(OperatorTestsComplexNonHalfTypesAllExecs, MatXComplexNonHalfTypesAllExecsWithJIT);
+TYPED_TEST_SUITE(OperatorTestsComplexTypesAllExecs, MatXComplexTypesAllExecsWithJIT);
+TYPED_TEST_SUITE(OperatorTestsAllExecs, MatXAllTypesAllExecsWithJIT);
+TYPED_TEST_SUITE(OperatorTestsFloatAllExecs, MatXTypesFloatAllExecsWithJIT);
+TYPED_TEST_SUITE(OperatorTestsIntegralAllExecs, MatXTypesIntegralAllExecsWithJIT);
+TYPED_TEST_SUITE(OperatorTestsBooleanAllExecs, MatXTypesBooleanAllExecsWithJIT);
+TYPED_TEST_SUITE(OperatorTestsCastToFloatAllExecs, MatXTypesCastToFloatAllExecsWithJIT);
+
+// Operator-specific type aliases using ExecutorTypesAllWithoutJIT instead of ExecutorTypesAll
+using MatXFloatNonHalfTypesAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXNumericNonComplexTypesAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXFloatNonComplexNonHalfTypesAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXTypesFloatNonComplexAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXTypesNumericAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXNumericNoHalfTypesAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXComplexNonHalfTypesAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXComplexTypesAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXAllTypesAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXTypesFloatAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXTypesIntegralAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXTypesBooleanAllExecsWithoutJIT = TupleToTypes::type>::type;
+using MatXTypesCastToFloatAllExecsWithoutJIT = TupleToTypes::type>::type;
+
+TYPED_TEST_SUITE(OperatorTestsFloatNonHalfWithoutJIT,
+                 MatXFloatNonHalfTypesAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsNumericNonComplexAllExecsWithoutJIT,
+                 MatXNumericNonComplexTypesAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsFloatNonComplexNonHalfAllExecsWithoutJIT,
+                 MatXFloatNonComplexNonHalfTypesAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsFloatNonComplexAllExecsWithoutJIT,
+                 MatXTypesFloatNonComplexAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsFloatNonComplexSingleThreadedHostAllExecsWithoutJIT,
+                 MatXTypesFloatNonComplexSingleThreadedHostAllExecs);
+TYPED_TEST_SUITE(OperatorTestsNumericAllExecsWithoutJIT,
+                 MatXTypesNumericAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsNumericNoHalfAllExecsWithoutJIT, MatXNumericNoHalfTypesAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsComplexNonHalfTypesAllExecsWithoutJIT, MatXComplexNonHalfTypesAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsComplexTypesAllExecsWithoutJIT, MatXComplexTypesAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsAllExecsWithoutJIT, MatXAllTypesAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsFloatAllExecsWithoutJIT, MatXTypesFloatAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsIntegralAllExecsWithoutJIT, MatXTypesIntegralAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsBooleanAllExecsWithoutJIT, MatXTypesBooleanAllExecsWithoutJIT);
+TYPED_TEST_SUITE(OperatorTestsCastToFloatAllExecsWithoutJIT, MatXTypesCastToFloatAllExecsWithoutJIT);
+
+// Template class declarations for WithoutJIT types
+template <typename TensorType>
+class OperatorTestsFloatNonHalfWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsNumericNonComplexAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsFloatNonComplexNonHalfAllExecsWithoutJIT : public ::testing::Test {};
+template <typename TensorType>
+class OperatorTestsFloatNonComplexAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsFloatNonComplexSingleThreadedHostAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsNumericAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsNumericNoHalfAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsComplexNonHalfTypesAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsComplexTypesAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsFloatAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsIntegralAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsBooleanAllExecsWithoutJIT : public ::testing::Test {};
+
+template <typename TensorType>
+class OperatorTestsCastToFloatAllExecsWithoutJIT : public ::testing::Test {};
 
 } // namespace test
 } // namespace matx
\ No newline at end of file
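For readers unfamiliar with Google Test's typed suites: TYPED_TEST_SUITE binds a fixture template to a list of types, and each TYPED_TEST is then instantiated once per type. The WithJIT and WithoutJIT registrations above differ only in which executor tuple is crossed into that type list. A minimal sketch of the mechanism (hypothetical fixture and types, unrelated to MatX):

#include <gtest/gtest.h>

template <typename T>
class MyTypedTest : public ::testing::Test {};

using MyTypes = ::testing::Types<int, float>;
TYPED_TEST_SUITE(MyTypedTest, MyTypes);

TYPED_TEST(MyTypedTest, DefaultConstructs) {
  // Runs twice: once with TypeParam == int, once with TypeParam == float.
  TypeParam v{};
  (void)v;
}
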
diff --git a/test/00_operators/r2c_test.cu b/test/00_operators/r2c_test.cu
index 15cadda4..abe8031b 100644
--- a/test/00_operators/r2c_test.cu
+++ b/test/00_operators/r2c_test.cu
@@ -18,7 +18,13 @@ TYPED_TEST(OperatorTestsFloatNonComplexNonHalfAllExecs, R2COp)
   // r2c requires FFT support, so we need to check the executor here
   if constexpr (!detail::CheckFFTSupport()) {
     GTEST_SKIP();
-  }
+  }
+
+#ifndef MATX_EN_MATHDX
+  if constexpr (is_cuda_jit_executor_v<ExecType>) {
+    GTEST_SKIP();
+  }
+#endif
 
   const int N1 = 5;
   const int N2 = 6;
diff --git a/test/00_operators/simple_executor_accessor_test.cu b/test/00_operators/simple_executor_accessor_test.cu
index 922055c8..c3d157a0 100644
--- a/test/00_operators/simple_executor_accessor_test.cu
+++ b/test/00_operators/simple_executor_accessor_test.cu
@@ -6,7 +6,7 @@
 using namespace matx;
 using namespace matx::test;
 
-TYPED_TEST(OperatorTestsAllExecs, SimpleExecutorAccessorTests)
+TYPED_TEST(OperatorTestsAllExecsWithoutJIT, SimpleExecutorAccessorTests)
 {
   MATX_ENTER_HANDLER();
 
diff --git a/test/00_operators/square_copy_transpose_test.cu b/test/00_operators/square_copy_transpose_test.cu
index b94c860e..4e0231ba 100644
--- a/test/00_operators/square_copy_transpose_test.cu
+++ b/test/00_operators/square_copy_transpose_test.cu
@@ -8,7 +8,7 @@
 using namespace matx::test;
 
 
-TYPED_TEST(OperatorTestsNumericAllExecs, SquareCopyTranspose)
+TYPED_TEST(OperatorTestsNumericNoHalfAllExecs, SquareCopyTranspose)
 {
   MATX_ENTER_HANDLER();
   using TestType = cuda::std::tuple_element_t<0, TypeParam>;
@@ -32,8 +32,7 @@ TYPED_TEST(OperatorTestsNumericAllExecs, SquareCopyTranspose)
 
   for (index_t i = 0; i < count; i++) {
     for (index_t j = 0; j < count; j++) {
-      EXPECT_TRUE(MatXUtils::MatXTypeCompare(t2t(i, j),
-                                             TestType(i * count + (double)j)));
+      ASSERT_EQ(t2t(i, j), TestType(i * count + (double)j));
     }
   }
 
@@ -52,7 +51,7 @@ TYPED_TEST(OperatorTestsNumericAllExecs, SquareCopyTranspose)
   MATX_EXIT_HANDLER();
 }
 
-TYPED_TEST(OperatorTestsNumericAllExecs, NonSquareTranspose)
+TYPED_TEST(OperatorTestsNumericNoHalfAllExecs, NonSquareTranspose)
 {
   MATX_ENTER_HANDLER();
   using TestType = cuda::std::tuple_element_t<0, TypeParam>;
diff --git a/test/00_operators/zipvec_test.cu b/test/00_operators/zipvec_test.cu
index 7f001a00..d9e626ff 100644
--- a/test/00_operators/zipvec_test.cu
+++ b/test/00_operators/zipvec_test.cu
@@ -6,7 +6,8 @@
 using namespace matx;
 using namespace matx::test;
 
-TYPED_TEST(OperatorTestsFloatNonComplexNonHalfAllExecs, ZipVecOp)
+// No JIT since we use custom types that can't be stringified
+TYPED_TEST(OperatorTestsFloatNonComplexNonHalfAllExecsWithoutJIT, ZipVecOp)
 {
   MATX_ENTER_HANDLER();
 
diff --git a/test/include/test_types.h b/test/include/test_types.h
index ea0c08c4..45566356 100644
--- a/test/include/test_types.h
+++ b/test/include/test_types.h
@@ -75,6 +75,13 @@ template <> auto inline GenerateData>()
 using ExecutorTypesAll = cuda::std::tuple;
 using ExecutorTypesCUDAOnly = cuda::std::tuple;
 using ExecutorTypesAllSingleThreadedHost = cuda::std::tuple;
+#ifdef MATX_EN_JIT
+using ExecutorTypesAllWithJIT = cuda::std::tuple;
+using ExecutorTypesAllWithoutJIT = cuda::std::tuple;
+#else
+using ExecutorTypesAllWithJIT = ExecutorTypesAll;
+using ExecutorTypesAllWithoutJIT = ExecutorTypesAll;
+#endif
 
 /* Taken from https://stackoverflow.com/questions/70404549/cartesian-product-of-stdtuple */
 template
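The test_types.h hunk above ends at the cartesian-product helper referenced by the Stack Overflow link; the WithJIT/WithoutJIT aliases in operator_test_types.hpp are built by crossing a tuple of scalar types with one of these executor tuples. A minimal sketch of that tuple cartesian product (hypothetical helper names; the library's actual implementation follows the linked answer):

#include <tuple>
#include <utility>

// One "row": pair a single scalar type T with every executor type E.
template <typename T, typename... Es>
using row_t = std::tuple<std::tuple<T, Es>...>;

template <typename Ts, typename Es>
struct cart_prod;

// Concatenate the rows for all scalar types Ts.
template <typename... Ts, typename... Es>
struct cart_prod<std::tuple<Ts...>, std::tuple<Es...>> {
  using type = decltype(std::tuple_cat(std::declval<row_t<Ts, Es...>>()...));
};

// Example: {float,double} x {CUDAExec,HostExec} -> four (type, executor) pairs.
struct CUDAExec {};
struct HostExec {};
using Params = cart_prod<std::tuple<float, double>,
                         std::tuple<CUDAExec, HostExec>>::type;
static_assert(std::tuple_size_v<Params> == 4, "2 types x 2 executors");

A TupleToTypes-style adapter then converts this tuple of tuples into the ::testing::Types list that TYPED_TEST_SUITE consumes, which is why each test body unpacks TypeParam with cuda::std::tuple_element_t<0, ...> and <1, ...>.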