20 changes: 18 additions & 2 deletions include/matx/core/capabilities.h
@@ -70,6 +70,7 @@ namespace detail {
MAX_EPT_VEC_LOAD, // The maximum EPT for a vector load.
ELEMENT_WISE, // Whether the operator is element-wise (safe with aliasing)
ALIASED_MEMORY, // Whether the operator's input and output pointers alias
GLOBAL_KERNEL, // Whether the kernel operates entirely at a global level per chunk of data; false when at least one operator works at a block level
// Add more capabilities as needed
};

@@ -123,7 +124,7 @@ namespace detail {
struct capability_attributes<OperatorCapability::SUPPORTS_JIT> {
using type = bool;
using input_type = VoidCapabilityType;
static constexpr bool default_value = true;
static constexpr bool default_value = false;
static constexpr bool or_identity = false;
static constexpr bool and_identity = true;
};
@@ -144,7 +145,16 @@ namespace detail {
static constexpr bool default_value = false;
static constexpr bool or_identity = false;
static constexpr bool and_identity = true;
};
};

template <>
struct capability_attributes<OperatorCapability::GLOBAL_KERNEL> {
using type = bool;
using input_type = VoidCapabilityType;
static constexpr bool default_value = true;
static constexpr bool or_identity = false;
static constexpr bool and_identity = true;
};

template <>
struct capability_attributes<OperatorCapability::ALIASED_MEMORY> {
@@ -250,6 +260,10 @@ namespace detail {
if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) {
return detail::type_to_string<OperatorType>();
}
else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) {
// If this is not a matx operator (like a constant or a lambda), we assume it supports JIT.
return true;
}
else {
return capability_attributes<Cap>::default_value;
}
@@ -274,6 +288,8 @@ namespace detail {
return CapabilityQueryType::AND_QUERY; // If any sub-operator supports JIT, the expression might be JIT-able.
case OperatorCapability::ASYNC_LOADS_REQUESTED:
return CapabilityQueryType::OR_QUERY; // If any sub-operator requires asynchronous loads, the expression might require asynchronous loads.
case OperatorCapability::GLOBAL_KERNEL:
return CapabilityQueryType::AND_QUERY; // The expression is a global kernel only if every sub-operator operates at a global level.
case OperatorCapability::ELEMENTS_PER_THREAD:
return CapabilityQueryType::RANGE_QUERY; // The expression should use the range of elements per thread of its children.
case OperatorCapability::SET_ELEMENTS_PER_THREAD:
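For context, a minimal standalone sketch (not part of this diff; the operator types and the global_kernel member are hypothetical) of what the AND_QUERY semantics chosen for GLOBAL_KERNEL mean in practice: the fold is seeded with and_identity, and a single block-level sub-operator pulls the whole expression down to block level.

#include <cstdio>

// Hypothetical leaf operators: one works at a global level, one at a block level.
struct GlobalOp { static constexpr bool global_kernel = true;  };
struct BlockOp  { static constexpr bool global_kernel = false; };

// Hypothetical binary expression combining its children with AND_QUERY semantics,
// seeded with the and_identity (true) declared in capability_attributes.
template <typename OpA, typename OpB>
struct BinaryExpr {
  static constexpr bool global_kernel =
      true && OpA::global_kernel && OpB::global_kernel;
};

int main() {
  std::printf("%d\n", BinaryExpr<GlobalOp, GlobalOp>::global_kernel); // 1: every child is global
  std::printf("%d\n", BinaryExpr<GlobalOp, BlockOp>::global_kernel);  // 0: one block-level child
  return 0;
}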
2 changes: 1 addition & 1 deletion include/matx/core/get_grid_dims.h
@@ -173,7 +173,7 @@ inline bool get_grid_dims(dim3 &blocks, dim3 &threads, const cuda::std::array<in

// For JIT code we want to use a grid-stride loop always
template <int RANK>
inline bool get_grid_dims_jit(dim3 &blocks, dim3 &threads, const cuda::std::array<index_t, RANK> &sizes, index_t ept, int groups_per_block,
inline bool get_grid_dims_block(dim3 &blocks, dim3 &threads, const cuda::std::array<index_t, RANK> &sizes, index_t ept, int groups_per_block,
int max_cta_size = 1024, bool force_size = false)
{
bool stride = false;
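For reference, a usage sketch of the renamed helper (not part of this diff; the umbrella header, the matx::detail namespace qualification, and the parameter values are assumptions based only on the signature shown above):

#include <matx.h>  // assumed umbrella header

void configure_launch() {
  dim3 blocks, threads;
  cuda::std::array<matx::index_t, 2> sizes{4096, 4096};
  // ept and groups_per_block are arbitrary illustrative values; the boolean
  // appears to report whether a grid-stride launch is in effect, which the
  // comment above says is always the case for JIT code.
  bool stride = matx::detail::get_grid_dims_block<2>(
      blocks, threads, sizes, /*ept=*/4, /*groups_per_block=*/1);
  (void)stride;
}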
155 changes: 153 additions & 2 deletions include/matx/core/half.h
@@ -32,6 +32,8 @@

#pragma once

#include <cuda/std/cstdint>
#include <cuda/std/bit>
#include <cuda/std/cmath>
#include <cuda/std/type_traits>

@@ -41,6 +43,83 @@

namespace matx {

// Constexpr helper functions for float to half conversion
namespace detail {

/**
* @brief Constexpr conversion from float to FP16 bits
*
* @param f Input float value
* @return uint16_t FP16 bit representation
*/
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ uint16_t float_to_fp16_bits(float f) {
// Use bit_cast for constexpr context
uint32_t bits = cuda::std::bit_cast<uint32_t>(f);

uint32_t sign = (bits >> 16) & 0x8000;
int32_t exponent = static_cast<int32_t>(((bits >> 23) & 0xff)) - 127 + 15;
uint32_t mantissa = (bits >> 13) & 0x3ff;

// Handle special cases
if (exponent <= 0) {
// Subnormal or zero
if (exponent < -10) {
// Too small, flush to zero
return static_cast<uint16_t>(sign);
}
// Subnormal
mantissa = (mantissa | 0x400) >> (1 - exponent);
return static_cast<uint16_t>(sign | mantissa);
} else if (exponent >= 0x1f) {
// Overflow to infinity or NaN
if (exponent == 0x1f + (127 - 15) && (bits & 0x007fffffu) != 0) {
// NaN: test the full float mantissa so payloads confined to the low 13 bits
// are not mistaken for infinity; return a quiet NaN keeping the top payload bits
return static_cast<uint16_t>(sign | 0x7e00 | mantissa);
}
// Infinity
return static_cast<uint16_t>(sign | 0x7c00);
}

return static_cast<uint16_t>(sign | (static_cast<uint32_t>(exponent) << 10) | mantissa);
}

/**
* @brief Constexpr conversion from float to BF16 bits
*
* @param f Input float value
* @return uint16_t BF16 bit representation (top 16 bits of float)
*/
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ uint16_t float_to_bf16_bits(float f) {
// BF16 is just the top 16 bits of a float32
// With rounding to nearest even
uint32_t bits = cuda::std::bit_cast<uint32_t>(f);

// NaN: adding the rounding bias could carry into the exponent and turn a NaN
// into infinity (or wrap to negative zero), so return a quiet NaN directly
if ((bits & 0x7f800000u) == 0x7f800000u && (bits & 0x007fffffu) != 0) {
return static_cast<uint16_t>((bits >> 16) | 0x0040);
}

// Round to nearest even
uint32_t rounding_bias = 0x00007FFF + ((bits >> 16) & 1);
bits += rounding_bias;
uint16_t result = static_cast<uint16_t>(bits >> 16);

return result;
}

/**
* @brief Helper to convert float to half type at compile time
*
* @tparam T The target half type (__half or __nv_bfloat16)
* @param f Input float value
* @return T Half-precision value
*/
template <typename T>
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ T float_to_half_constexpr(float f) {
if constexpr (cuda::std::is_same_v<T, __half>) {
return cuda::std::bit_cast<__half>(float_to_fp16_bits(f));
} else {
return cuda::std::bit_cast<__nv_bfloat16>(float_to_bf16_bits(f));
}
}

} // namespace detail

/**
* Template class for half precision numbers (__half and __nv_bfloat16). CUDA
* does not have standardized classes/operators available on both host and
@@ -64,12 +143,49 @@ template <typename T> struct alignas(sizeof(T)) matxHalf {
__MATX_INLINE__ matxHalf(const matxHalf<T> &x_) noexcept = default;

/**
* @brief Copy constructor from arbitrary type
* @brief Constexpr constructor from float
*
* @param f Float value to convert
*/
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(float f) noexcept
: x(detail::float_to_half_constexpr<T>(f))
{
}

/**
* @brief Constexpr constructor from double
*
* @param d Double value to convert
*/
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(double d) noexcept
: x(detail::float_to_half_constexpr<T>(static_cast<float>(d)))
{
}

/**
* @brief Constructor from integral types (constexpr)
*
* @tparam T2 Integral type to copy from
* @param x_ Value to copy
*/
template <typename T2,
cuda::std::enable_if_t<cuda::std::is_integral_v<T2>, int> = 0>
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(T2 x_) noexcept
: x(detail::float_to_half_constexpr<T>(static_cast<float>(x_)))
{
}

/**
* @brief Copy constructor from arbitrary type (non-constexpr for non-arithmetic types)
*
* @tparam T2 Type to copy from
* @param x_ Value to copy
*/
template <typename T2>
template <typename T2,
cuda::std::enable_if_t<
!cuda::std::is_same_v<cuda::std::decay_t<T2>, float> &&
!cuda::std::is_same_v<cuda::std::decay_t<T2>, double> &&
!cuda::std::is_integral_v<T2>, int> = 0>
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalf(const T2 &x_) noexcept
: x(static_cast<float>(x_))
{
@@ -1316,3 +1432,38 @@ using matxFp16 = matxHalf<__half>; ///< Alias for fp16
using matxBf16 = matxHalf<__nv_bfloat16>; ///< Alias for bf16

}; // namespace matx

#ifndef __CUDACC_RTC__
// Add std::formatter specializations for matxFp16 and matxBf16
#include <format>

namespace std {

/**
* @brief std::formatter specialization for matxFp16
*
* Enables matxFp16 to work with std::format by converting to float
*/
template <>
struct formatter<matx::matxFp16> : formatter<float> {
template <typename FormatContext>
auto format(const matx::matxFp16& val, FormatContext& ctx) const {
return formatter<float>::format(static_cast<float>(val), ctx);
}
};

/**
* @brief std::formatter specialization for matxBf16
*
* Enables matxBf16 to work with std::format by converting to float
*/
template <>
struct formatter<matx::matxBf16> : formatter<float> {
template <typename FormatContext>
auto format(const matx::matxBf16& val, FormatContext& ctx) const {
return formatter<float>::format(static_cast<float>(val), ctx);
}
};

} // namespace std
#endif
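A short host-side usage sketch of the two additions above (not part of this diff; the umbrella header name and a C++20-or-later standard library with <format> are assumptions, and the constexpr construction relies on __half/__nv_bfloat16 being constexpr bit-castable, as this change expects):

#include <format>
#include <iostream>
#include <matx.h>  // assumed umbrella header exposing matxFp16 / matxBf16

int main() {
  // The new constexpr constructors convert at compile time through
  // float_to_fp16_bits / float_to_bf16_bits.
  constexpr matx::matxFp16 a{1.5f};  // float path
  constexpr matx::matxBf16 b{2};     // integral path
  constexpr matx::matxFp16 c{0.25};  // double path

  // The formatter specializations forward to formatter<float>, so standard
  // float format specs such as {:.3f} apply directly to the half types.
  std::cout << std::format("a={:.3f} b={} c={}\n", a, b, c);
  return 0;
}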
67 changes: 63 additions & 4 deletions include/matx/core/half_complex.h
@@ -60,7 +60,7 @@ template <typename T> struct alignas(sizeof(T) * 2) matxHalfComplex {
*
* @param x_ Object to copy from
*/
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__
matxHalfComplex(const cuda::std::complex<float> &x_) noexcept
: x(x_.real()), y(x_.imag())
{
@@ -73,7 +73,7 @@ template <typename T> struct alignas(sizeof(T) * 2) matxHalfComplex {
* @param x_ Value of scalar
*/
template <typename T2>
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(const T2 &x_) noexcept
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(const T2 &x_) noexcept
: x(static_cast<float>(x_)), y(0.0f)
{
}
@@ -87,7 +87,7 @@ template <typename T> struct alignas(sizeof(T) * 2) matxHalfComplex {
* @param y_ Imaginary value
*/
template <typename T2, typename T3>
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(const T2 &x_,
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(const T2 &x_,
const T3 &y_) noexcept
: x(static_cast<float>(x_)), y(static_cast<float>(y_))
{
@@ -103,7 +103,7 @@ template <typename T> struct alignas(sizeof(T) * 2) matxHalfComplex {
template <typename T2>
requires (cuda::std::is_same_v<cuda::std::decay_t<T2>, matxFp16> ||
cuda::std::is_same_v<cuda::std::decay_t<T2>, matxBf16>)
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(T2 &&x_, T2 &&y_) noexcept
constexpr __MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(T2 &&x_, T2 &&y_) noexcept
: x(static_cast<T>(x_)),
y(static_cast<T>(y_))
{
@@ -1056,3 +1056,62 @@ using matxFp16Complex = matxHalfComplex<matxFp16>; ///< Alias for a MatX fp16 co
using matxBf16Complex = matxHalfComplex<matxBf16>; ///< Alias for a MatX bf16 complex wrapper

}; // namespace matx

#ifndef __CUDACC_RTC__
// Add std::formatter specializations for matxFp16Complex and matxBf16Complex
#include <format>

namespace std {

/**
* @brief std::formatter specialization for matxFp16Complex
*
* Enables matxFp16Complex to work with std::format by converting to complex<float>
*/
template <>
struct formatter<matx::matxFp16Complex> {
template <typename ParseContext>
constexpr auto parse(ParseContext& ctx) {
return ctx.begin();
}

template <typename FormatContext>
auto format(const matx::matxFp16Complex& val, FormatContext& ctx) const {
float real_val = static_cast<float>(val.real());
float imag_val = static_cast<float>(val.imag());

if (imag_val >= 0) {
return std::format_to(ctx.out(), "({}+{}i)", real_val, imag_val);
} else {
return std::format_to(ctx.out(), "({}{}i)", real_val, imag_val);
}
}
};

/**
* @brief std::formatter specialization for matxBf16Complex
*
* Enables matxBf16Complex to work with std::format by converting to complex<float>
*/
template <>
struct formatter<matx::matxBf16Complex> {
template <typename ParseContext>
constexpr auto parse(ParseContext& ctx) {
return ctx.begin();
}

template <typename FormatContext>
auto format(const matx::matxBf16Complex& val, FormatContext& ctx) const {
float real_val = static_cast<float>(val.real());
float imag_val = static_cast<float>(val.imag());

if (imag_val >= 0) {
return std::format_to(ctx.out(), "({}+{}i)", real_val, imag_val);
} else {
return std::format_to(ctx.out(), "({}{}i)", real_val, imag_val);
}
}
};

} // namespace std
#endif // __CUDACC_RTC__
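Similarly, a usage sketch for the complex formatters (not part of this diff; header name assumed). The output is "(a+bi)" when the imaginary part is non-negative and "(a-bi)" otherwise:

#include <format>
#include <iostream>
#include <matx.h>  // assumed umbrella header

int main() {
  // Both components of each value are exactly representable at half precision.
  matx::matxFp16Complex z{1.5f, -0.25f};
  matx::matxBf16Complex w{0.5f, 2.0f};
  std::cout << std::format("{} {}\n", z, w);  // prints "(1.5-0.25i) (0.5+2i)"
  return 0;
}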
2 changes: 1 addition & 1 deletion include/matx/core/jit_includes.h
@@ -34,9 +34,9 @@

// This file is used for jitify/NVRTC preprocessing. Do NOT include any files in here that can't be
// parsed on the device, and try to keep this minimal to avoid unnecessary dependencies.
#include <cuda/barrier>
#include <cuda/std/__algorithm/min.h>
#include <cuda/std/__algorithm/max.h>
#include <cuda/barrier>
#include "matx/core/defines.h"
#include "matx/core/type_utils_both.h"
#include "matx/core/vector.h"