Skip to content

Commit ce9d5cb

Browse files
committed
Start of working through unit tests with JIT executor
1 parent e78cd18 commit ce9d5cb

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

102 files changed

+2743
-400
lines changed

include/matx/core/capabilities.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ namespace detail {
7070
MAX_EPT_VEC_LOAD, // The maximum EPT for a vector load.
7171
ELEMENT_WISE, // Whether the operator is element-wise (safe with aliasing)
7272
ALIASED_MEMORY, // Whether the operator's input and output pointers alias
73+
GLOBAL_KERNEL, // Kernel operates entirely on a global level per chunk of data. False when at least one operator works on a block level
7374
// Add more capabilities as needed
7475
};
7576

@@ -123,7 +124,7 @@ namespace detail {
123124
struct capability_attributes<OperatorCapability::SUPPORTS_JIT> {
124125
using type = bool;
125126
using input_type = VoidCapabilityType;
126-
static constexpr bool default_value = true;
127+
static constexpr bool default_value = false;
127128
static constexpr bool or_identity = false;
128129
static constexpr bool and_identity = true;
129130
};
@@ -144,7 +145,16 @@ namespace detail {
144145
static constexpr bool default_value = false;
145146
static constexpr bool or_identity = false;
146147
static constexpr bool and_identity = true;
147-
};
148+
};
149+
150+
template <>
151+
struct capability_attributes<OperatorCapability::GLOBAL_KERNEL> {
152+
using type = bool;
153+
using input_type = VoidCapabilityType;
154+
static constexpr bool default_value = true;
155+
static constexpr bool or_identity = false;
156+
static constexpr bool and_identity = true;
157+
};
148158

149159
template <>
150160
struct capability_attributes<OperatorCapability::ALIASED_MEMORY> {
@@ -250,6 +260,10 @@ namespace detail {
250260
if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) {
251261
return detail::type_to_string<OperatorType>();
252262
}
263+
else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) {
264+
// If this is not a matx operator (like a constant or a lambda), we assume it supports JIT.
265+
return true;
266+
}
253267
else {
254268
return capability_attributes<Cap>::default_value;
255269
}
@@ -274,6 +288,8 @@ namespace detail {
274288
return CapabilityQueryType::AND_QUERY; // If any sub-operator supports JIT, the expression might be JIT-able.
275289
case OperatorCapability::ASYNC_LOADS_REQUESTED:
276290
return CapabilityQueryType::OR_QUERY; // If any sub-operator requires asynchronous loads, the expression might require asynchronous loads.
291+
case OperatorCapability::GLOBAL_KERNEL:
292+
return CapabilityQueryType::AND_QUERY; // If any sub-operator operates on a global level, the expression might operate on a global level.
277293
case OperatorCapability::ELEMENTS_PER_THREAD:
278294
return CapabilityQueryType::RANGE_QUERY; // The expression should use the range of elements per thread of its children.
279295
case OperatorCapability::SET_ELEMENTS_PER_THREAD:

include/matx/core/half.h

Lines changed: 153 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232

3333
#pragma once
3434

35+
#include <cuda/std/cstdint>
36+
#include <cuda/std/bit>
3537
#include <cuda/std/cmath>
3638
#include <cuda/std/type_traits>
3739

@@ -41,6 +43,83 @@
4143

4244
namespace matx {
4345

46+
// Constexpr helper functions for float to half conversion
47+
namespace detail {
48+
49+
/**
50+
* @brief Constexpr conversion from float to FP16 bits
51+
*
52+
* @param f Input float value
53+
* @return uint16_t FP16 bit representation
54+
*/
55+
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ uint16_t float_to_fp16_bits(float f) {
56+
// Use bit_cast for constexpr context
57+
uint32_t bits = cuda::std::bit_cast<uint32_t>(f);
58+
59+
uint32_t sign = (bits >> 16) & 0x8000;
60+
int32_t exponent = static_cast<int32_t>(((bits >> 23) & 0xff)) - 127 + 15;
61+
uint32_t mantissa = (bits >> 13) & 0x3ff;
62+
63+
// Handle special cases
64+
if (exponent <= 0) {
65+
// Subnormal or zero
66+
if (exponent < -10) {
67+
// Too small, flush to zero
68+
return static_cast<uint16_t>(sign);
69+
}
70+
// Subnormal
71+
mantissa = (mantissa | 0x400) >> (1 - exponent);
72+
return static_cast<uint16_t>(sign | mantissa);
73+
} else if (exponent >= 0x1f) {
74+
// Overflow to infinity or NaN
75+
if (exponent == 0x1f + (127 - 15) && mantissa != 0) {
76+
// NaN
77+
return static_cast<uint16_t>(sign | 0x7e00 | (mantissa != 0 ? 0x200 : 0));
78+
}
79+
// Infinity
80+
return static_cast<uint16_t>(sign | 0x7c00);
81+
}
82+
83+
return static_cast<uint16_t>(sign | (static_cast<uint32_t>(exponent) << 10) | mantissa);
84+
}
85+
86+
/**
87+
* @brief Constexpr conversion from float to BF16 bits
88+
*
89+
* @param f Input float value
90+
* @return uint16_t BF16 bit representation (top 16 bits of float)
91+
*/
92+
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ uint16_t float_to_bf16_bits(float f) {
93+
// BF16 is just the top 16 bits of a float32
94+
// With rounding to nearest even
95+
uint32_t bits = cuda::std::bit_cast<uint32_t>(f);
96+
97+
// Round to nearest even
98+
uint32_t rounding_bias = 0x00007FFF + ((bits >> 16) & 1);
99+
bits += rounding_bias;
100+
uint16_t result = static_cast<uint16_t>(bits >> 16);
101+
102+
return result;
103+
}
104+
105+
/**
106+
* @brief Helper to convert float to half type at compile time
107+
*
108+
* @tparam T The target half type (__half or __nv_bfloat16)
109+
* @param f Input float value
110+
* @return T Half-precision value
111+
*/
112+
template <typename T>
113+
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ T float_to_half_constexpr(float f) {
114+
if constexpr (cuda::std::is_same_v<T, __half>) {
115+
return cuda::std::bit_cast<__half>(float_to_fp16_bits(f));
116+
} else {
117+
return cuda::std::bit_cast<__nv_bfloat16>(float_to_bf16_bits(f));
118+
}
119+
}
120+
121+
} // namespace detail
122+
44123
/**
45124
* Template class for half precison numbers (__half and __nv_bfloat16). CUDA
46125
* does not have standardized classes/operators available on both host and
@@ -64,12 +143,49 @@ template <typename T> struct alignas(sizeof(T)) matxHalf {
64143
__MATX_INLINE__ matxHalf(const matxHalf<T> &x_) noexcept = default;
65144

66145
/**
67-
* @brief Copy constructor from arbitrary type
146+
* @brief Constexpr constructor from float
147+
*
148+
* @param f Float value to convert
149+
*/
150+
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(float f) noexcept
151+
: x(detail::float_to_half_constexpr<T>(f))
152+
{
153+
}
154+
155+
/**
156+
* @brief Constexpr constructor from double
157+
*
158+
* @param d Double value to convert
159+
*/
160+
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(double d) noexcept
161+
: x(detail::float_to_half_constexpr<T>(static_cast<float>(d)))
162+
{
163+
}
164+
165+
/**
166+
* @brief Constructor from integral types (constexpr)
167+
*
168+
* @tparam T2 Integral type to copy from
169+
* @param x_ Value to copy
170+
*/
171+
template <typename T2,
172+
cuda::std::enable_if_t<cuda::std::is_integral_v<T2>, int> = 0>
173+
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(T2 x_) noexcept
174+
: x(detail::float_to_half_constexpr<T>(static_cast<float>(x_)))
175+
{
176+
}
177+
178+
/**
179+
* @brief Copy constructor from arbitrary type (non-constexpr for non-arithmetic types)
68180
*
69181
* @tparam T2 Type to copy from
70182
* @param x_ Value to copy
71183
*/
72-
template <typename T2>
184+
template <typename T2,
185+
cuda::std::enable_if_t<
186+
!cuda::std::is_same_v<cuda::std::decay_t<T2>, float> &&
187+
!cuda::std::is_same_v<cuda::std::decay_t<T2>, double> &&
188+
!cuda::std::is_integral_v<T2>, int> = 0>
73189
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(const T2 &x_) noexcept
74190
: x(static_cast<float>(x_))
75191
{
@@ -1316,3 +1432,38 @@ using matxFp16 = matxHalf<__half>; ///< Alias for fp16
13161432
using matxBf16 = matxHalf<__nv_bfloat16>; ///< Alias for bf16
13171433

13181434
}; // namespace matx
1435+
1436+
#ifndef __CUDACC_RTC__
1437+
// Add std::formatter specializations for matxFp16 and matxBf16
1438+
#include <format>
1439+
1440+
namespace std {
1441+
1442+
/**
1443+
* @brief std::formatter specialization for matxFp16
1444+
*
1445+
* Enables matxFp16 to work with std::format by converting to float
1446+
*/
1447+
template <>
1448+
struct formatter<matx::matxFp16> : formatter<float> {
1449+
template <typename FormatContext>
1450+
auto format(const matx::matxFp16& val, FormatContext& ctx) const {
1451+
return formatter<float>::format(static_cast<float>(val), ctx);
1452+
}
1453+
};
1454+
1455+
/**
1456+
* @brief std::formatter specialization for matxBf16
1457+
*
1458+
* Enables matxBf16 to work with std::format by converting to float
1459+
*/
1460+
template <>
1461+
struct formatter<matx::matxBf16> : formatter<float> {
1462+
template <typename FormatContext>
1463+
auto format(const matx::matxBf16& val, FormatContext& ctx) const {
1464+
return formatter<float>::format(static_cast<float>(val), ctx);
1465+
}
1466+
};
1467+
1468+
} // namespace std
1469+
#endif

include/matx/core/half_complex.h

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,3 +1056,62 @@ using matxFp16Complex = matxHalfComplex<matxFp16>; ///< Alias for a MatX fp16 co
10561056
using matxBf16Complex = matxHalfComplex<matxBf16>; ///< Alias for a MatXbf16 complex wrapper
10571057

10581058
}; // namespace matx
1059+
1060+
#ifndef __CUDACC_RTC__
1061+
// Add std::formatter specializations for matxFp16Complex and matxBf16Complex
1062+
#include <format>
1063+
1064+
namespace std {
1065+
1066+
/**
1067+
* @brief std::formatter specialization for matxFp16Complex
1068+
*
1069+
* Enables matxFp16Complex to work with std::format by converting to complex<float>
1070+
*/
1071+
template <>
1072+
struct formatter<matx::matxFp16Complex> {
1073+
template <typename ParseContext>
1074+
constexpr auto parse(ParseContext& ctx) {
1075+
return ctx.begin();
1076+
}
1077+
1078+
template <typename FormatContext>
1079+
auto format(const matx::matxFp16Complex& val, FormatContext& ctx) const {
1080+
float real_val = static_cast<float>(val.real());
1081+
float imag_val = static_cast<float>(val.imag());
1082+
1083+
if (imag_val >= 0) {
1084+
return std::format_to(ctx.out(), "({}+{}i)", real_val, imag_val);
1085+
} else {
1086+
return std::format_to(ctx.out(), "({}{}i)", real_val, imag_val);
1087+
}
1088+
}
1089+
};
1090+
1091+
/**
1092+
* @brief std::formatter specialization for matxBf16Complex
1093+
*
1094+
* Enables matxBf16Complex to work with std::format by converting to complex<float>
1095+
*/
1096+
template <>
1097+
struct formatter<matx::matxBf16Complex> {
1098+
template <typename ParseContext>
1099+
constexpr auto parse(ParseContext& ctx) {
1100+
return ctx.begin();
1101+
}
1102+
1103+
template <typename FormatContext>
1104+
auto format(const matx::matxBf16Complex& val, FormatContext& ctx) const {
1105+
float real_val = static_cast<float>(val.real());
1106+
float imag_val = static_cast<float>(val.imag());
1107+
1108+
if (imag_val >= 0) {
1109+
return std::format_to(ctx.out(), "({}+{}i)", real_val, imag_val);
1110+
} else {
1111+
return std::format_to(ctx.out(), "({}{}i)", real_val, imag_val);
1112+
}
1113+
}
1114+
};
1115+
1116+
} // namespace std
1117+
#endif // __CUDACC_RTC__

include/matx/core/jit_includes.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434

3535
// This file is used for jitify/NVRTC preprocessing. Do NOT include any files in here that can't be
3636
// parsed on the device, and try to keep this minimal to avoid unnecessary dependencies.
37-
#include <cuda/barrier>
3837
#include <cuda/std/__algorithm/min.h>
3938
#include <cuda/std/__algorithm/max.h>
4039
#include "matx/core/defines.h"

include/matx/core/log.h

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -89,38 +89,12 @@ namespace std {
8989
}
9090
};
9191

92-
// Formatter for matxHalfComplex (fp16/bf16 complex)
93-
template<typename T>
94-
struct formatter<matx::matxHalfComplex<T>> {
95-
constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); }
96-
97-
template<typename FormatContext>
98-
auto format(const matx::matxHalfComplex<T>& c, FormatContext& ctx) const {
99-
return format_to(ctx.out(), "{}", matx::detail::format_complex(c));
100-
}
101-
};
102-
103-
// Formatter for matxFp16 (half-precision float)
104-
template<>
105-
struct formatter<matx::matxFp16> {
106-
constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); }
107-
108-
template<typename FormatContext>
109-
auto format(const matx::matxFp16& val, FormatContext& ctx) const {
110-
return format_to(ctx.out(), "{:g}", static_cast<float>(val));
111-
}
112-
};
92+
// Formatter for matxHalfComplex (fp16/bf16 complex) - moved to half_complex.h
93+
// Formatter for matxFp16 (half-precision float) - moved to half.h
94+
// Formatter for matxBf16 (bfloat16) - moved to half.h
11395

114-
// Formatter for matxBf16 (bfloat16)
115-
template<>
116-
struct formatter<matx::matxBf16> {
117-
constexpr auto parse(format_parse_context& ctx) { return ctx.begin(); }
118-
119-
template<typename FormatContext>
120-
auto format(const matx::matxBf16& val, FormatContext& ctx) const {
121-
return format_to(ctx.out(), "{:g}", static_cast<float>(val));
122-
}
123-
};
96+
// Note: The formatters for matxHalfComplex, matxFp16, and matxBf16 are now defined
97+
// in their respective header files (half_complex.h and half.h) with proper guards.
12498
}
12599

126100
namespace matx {

0 commit comments

Comments
 (0)