Skip to content

Commit 7051e00

Browse files
authored
Replaced all uses of SFINAE with concepts for better error messages (#1081)
* Replaced all uses of SFINAE with concepts for better error messages
1 parent bf9c4c0 commit 7051e00

39 files changed

+679
-631
lines changed

docs_input/api/type_traits.rst

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,42 @@ Type Traits
44
MatX type traits help advanced developers to make compile-time decisions about types. Most of these are used extensively
55
inside of MatX, and are not needed in most user applications.
66

7+
MatX now uses C++20 concepts for type traits. Legacy variable templates (ending in ``_v``) and functions (ending in ``_t()``)
8+
are maintained for backward compatibility.
9+
10+
Type Manipulation
11+
=================
12+
713
.. doxygentypedef:: matx::promote_half_t
814
.. doxygenstruct:: matx::remove_cvref
15+
16+
Concepts (C++20)
17+
================
18+
19+
.. doxygenconcept:: matx::is_tensor
20+
.. doxygenconcept:: matx::is_matx_op_c
21+
.. doxygenconcept:: matx::is_executor
22+
.. doxygenconcept:: matx::is_matx_reduction
23+
.. doxygenconcept:: matx::is_matx_index_reduction
24+
.. doxygenconcept:: matx::is_cuda_complex
25+
.. doxygenconcept:: matx::is_complex
26+
.. doxygenconcept:: matx::is_complex_half
27+
.. doxygenconcept:: matx::is_half
28+
.. doxygenconcept:: matx::is_matx_half
29+
.. doxygenconcept:: matx::is_matx_type
30+
.. doxygenconcept:: matx::is_matx_shape
31+
.. doxygenconcept:: matx::is_matx_storage
32+
.. doxygenconcept:: matx::is_matx_storage_container
33+
.. doxygenconcept:: matx::is_matx_descriptor
34+
35+
Legacy Compatibility
36+
====================
37+
38+
Legacy functions and variables for backward compatibility:
39+
940
.. doxygenfunction:: matx::is_matx_op
1041
.. doxygenfunction:: matx::is_executor_t
11-
.. doxygenvariable:: matx::is_tensor_view_v
12-
.. doxygenvariable:: matx::is_matx_reduction_v
13-
.. doxygenvariable:: matx::is_matx_index_reduction_v
14-
.. doxygenvariable:: matx::is_cuda_complex_v
15-
.. doxygenvariable:: matx::is_complex_v
16-
.. doxygenvariable:: matx::is_complex_half_v
1742
.. doxygenfunction:: matx::IsHalfType
18-
.. doxygenvariable:: matx::is_half_v
19-
.. doxygenvariable:: matx::is_matx_type_v
20-
.. doxygenvariable:: matx::is_matx_shape_v
21-
.. doxygenvariable:: matx::is_matx_storage_v
22-
.. doxygenvariable:: matx::is_matx_storage_container_v
23-
.. doxygenvariable:: matx::is_matx_descriptor_v
43+
44+
Note: Legacy variable templates (``is_tensor_v``, ``is_matx_reduction_v``, etc.) are available
45+
for backward compatibility but are not documented here. Use the concepts above instead for new code.

include/matx/core/capabilities.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,21 @@ namespace detail {
9999

100100
using DefaultCapabilities = CapabilityParams<ElementsPerThread::ONE, false>;
101101

102-
// C++17-compatible trait to detect scoped enums
102+
// Concept to detect scoped enums
103+
template<typename T>
104+
concept is_scoped_enum_c = cuda::std::is_enum_v<T> &&
105+
!cuda::std::is_convertible_v<T, cuda::std::underlying_type_t<T>>;
106+
107+
// Legacy struct for backwards compatibility
103108
template<typename T, typename = void>
104109
struct is_scoped_enum : cuda::std::false_type {};
105110

106111
template<typename T>
107-
struct is_scoped_enum<T, cuda::std::enable_if_t<cuda::std::is_enum_v<T>>>
108-
: cuda::std::bool_constant<!cuda::std::is_convertible_v<T, cuda::std::underlying_type_t<T>>> {};
112+
requires is_scoped_enum_c<T>
113+
struct is_scoped_enum<T> : cuda::std::true_type {};
109114

110115
template<typename T>
111-
constexpr bool is_scoped_enum_v = is_scoped_enum<T>::value;
116+
constexpr bool is_scoped_enum_v = is_scoped_enum_c<T>;
112117

113118
// Trait to get default values and identities based on capability
114119
template <OperatorCapability Cap>

include/matx/core/half_complex.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,9 @@ template <typename T> struct alignas(sizeof(T) * 2) matxHalfComplex {
100100
* @param x_ Real value
101101
* @param y_ Imaginary value
102102
*/
103-
template <typename T2, cuda::std::enable_if_t<cuda::std::is_same_v<cuda::std::decay<T2>, matxFp16> ||
104-
cuda::std::is_same_v<cuda::std::decay<T2>, matxBf16>, bool> = true>
103+
template <typename T2>
104+
requires (cuda::std::is_same_v<cuda::std::decay_t<T2>, matxFp16> ||
105+
cuda::std::is_same_v<cuda::std::decay_t<T2>, matxBf16>)
105106
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex(T2 &&x_, T2 &&y_) noexcept
106107
: x(static_cast<T>(x_)),
107108
y(static_cast<T>(y_))
@@ -167,14 +168,14 @@ template <typename T> struct alignas(sizeof(T) * 2) matxHalfComplex {
167168
* @param rhs Value to copy from
168169
* @return Reference to copied object
169170
*/
170-
template <typename X, cuda::std::enable_if_t< cuda::std::is_same_v<cuda::std::decay<X>, cuda::std::complex<float>> ||
171-
cuda::std::is_same_v<cuda::std::decay<X>, cuda::std::complex<double>>, bool> = true>
171+
template <typename X>
172+
requires (cuda::std::is_same_v<cuda::std::decay_t<X>, cuda::std::complex<float>> ||
173+
cuda::std::is_same_v<cuda::std::decay_t<X>, cuda::std::complex<double>>)
172174
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ matxHalfComplex<T> &
173175
operator=(X rhs)
174176
{
175-
matxHalfComplex<X> tmp{rhs};
176-
x = static_cast<T>(tmp.real());
177-
y = static_cast<T>(tmp.imag());
177+
x = static_cast<T>(rhs.real());
178+
y = static_cast<T>(rhs.imag());
178179
return *this;
179180
}
180181

include/matx/core/iterator.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,12 @@ struct RandomOperatorIterator {
6565
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(const OperatorType &t, stride_type offset) : t_(t), offset_(offset) {}
6666
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}
6767

68-
template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
68+
template<typename T = OperatorType>
69+
requires (!std::is_same_v<T, OperatorBaseType>)
6970
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}
7071

71-
template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
72+
template<typename T = OperatorType>
73+
requires (!std::is_same_v<T, OperatorBaseType>)
7274
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}
7375

7476
/**
@@ -208,10 +210,12 @@ struct RandomOperatorOutputIterator {
208210
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(const OperatorType &t, stride_type offset) : t_(t), offset_(offset) {}
209211
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}
210212

211-
template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
213+
template<typename T = OperatorType>
214+
requires (!std::is_same_v<T, OperatorBaseType>)
212215
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}
213216

214-
template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
217+
template<typename T = OperatorType>
218+
requires (!std::is_same_v<T, OperatorBaseType>)
215219
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorOutputIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}
216220

217221
[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*()
@@ -354,10 +358,12 @@ struct RandomOperatorThrustIterator {
354358
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(const OperatorType &t, stride_type offset) : t_(t), offset_(offset) {}
355359
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(OperatorType &&t, stride_type offset) : t_(t), offset_(offset) {}
356360

357-
template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
361+
template<typename T = OperatorType>
362+
requires (!std::is_same_v<T, OperatorBaseType>)
358363
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(const OperatorBaseType &t, stride_type offset) : t_(t), offset_(offset) {}
359364

360-
template<typename T = OperatorType, std::enable_if_t<!std::is_same<T, OperatorBaseType>::value, bool> = true>
365+
template<typename T = OperatorType>
366+
requires (!std::is_same_v<T, OperatorBaseType>)
361367
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ RandomOperatorThrustIterator(OperatorBaseType &&t, stride_type offset) : t_(t), offset_(offset) {}
362368

363369
[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*() const

include/matx/core/make_tensor.h

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ auto make_tensor( const index_t (&shape)[RANK],
7575
* @param shape Shape specification for the tensor
7676
* @returns New tensor
7777
**/
78-
template <typename T, typename ShapeType,
79-
std::enable_if_t<!is_matx_descriptor_v<ShapeType> && !std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
78+
template <typename T, typename ShapeType>
79+
requires (!is_matx_descriptor<ShapeType> && !std::is_array_v<remove_cvref_t<ShapeType>>)
8080
auto make_tensor(Storage<T> storage, ShapeType &&shape) {
8181
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
8282

@@ -95,7 +95,8 @@ auto make_tensor(Storage<T> storage, ShapeType &&shape) {
9595
* @param space memory space to allocate in. Default is manged memory.
9696
* @param stream cuda stream to allocate in (only applicable to async allocations)
9797
**/
98-
template <typename TensorType, std::enable_if_t< is_tensor_view_v<TensorType>, bool> = true>
98+
template <typename TensorType>
99+
requires is_tensor<TensorType>
99100
void make_tensor( TensorType &tensor,
100101
const index_t (&shape)[TensorType::Rank()],
101102
matxMemorySpace_t space = MATX_MANAGED_MEMORY,
@@ -156,10 +157,10 @@ auto make_tensor_p( const index_t (&shape)[RANK],
156157
* @returns New tensor
157158
*
158159
**/
159-
template <typename T, typename ShapeType,
160-
std::enable_if_t< !is_matx_shape_v<ShapeType> &&
161-
!is_matx_descriptor_v<ShapeType> &&
162-
!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
160+
template <typename T, typename ShapeType>
161+
requires (!is_matx_shape<ShapeType> &&
162+
!is_matx_descriptor<ShapeType> &&
163+
!std::is_array_v<remove_cvref_t<ShapeType>>)
163164
auto make_tensor( ShapeType &&shape,
164165
matxMemorySpace_t space = MATX_MANAGED_MEMORY,
165166
cudaStream_t stream = 0) {
@@ -191,8 +192,8 @@ auto make_tensor( ShapeType &&shape,
191192
* @returns New tensor
192193
*
193194
**/
194-
template <typename TensorType,typename ShapeType,
195-
std::enable_if_t<is_tensor_view_v<TensorType> && !std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
195+
template <typename TensorType, typename ShapeType>
196+
requires (is_tensor<TensorType> && !std::is_array_v<remove_cvref_t<ShapeType>>)
196197
auto make_tensor( TensorType &tensor,
197198
ShapeType &&shape,
198199
matxMemorySpace_t space = MATX_MANAGED_MEMORY,
@@ -218,9 +219,9 @@ auto make_tensor( TensorType &tensor,
218219
* @returns Pointer to new tensor
219220
*
220221
**/
221-
template <typename T, typename ShapeType,
222-
std::enable_if_t< !is_matx_shape_v<ShapeType> &&
223-
!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
222+
template <typename T, typename ShapeType>
223+
requires (!is_matx_shape<ShapeType> &&
224+
!std::is_array_v<remove_cvref_t<ShapeType>>)
224225
auto make_tensor_p( ShapeType &&shape,
225226
matxMemorySpace_t space = MATX_MANAGED_MEMORY,
226227
cudaStream_t stream = 0) {
@@ -266,8 +267,8 @@ auto make_tensor( [[maybe_unused]] const std::initializer_list<detail::no_size_t
266267
* @returns New tensor
267268
*
268269
**/
269-
template <typename TensorType,
270-
std::enable_if_t<is_tensor_view_v<TensorType>, bool> = true>
270+
template <typename TensorType>
271+
requires is_tensor<TensorType>
271272
auto make_tensor( TensorType &tensor,
272273
matxMemorySpace_t space = MATX_MANAGED_MEMORY,
273274
cudaStream_t stream = 0) {
@@ -339,8 +340,8 @@ auto make_tensor( T *data,
339340
* Shape of tensor
340341
* @returns New tensor
341342
**/
342-
template <typename TensorType,
343-
std::enable_if_t<is_tensor_view_v<TensorType>, bool> = true>
343+
template <typename TensorType>
344+
requires is_tensor<TensorType>
344345
auto make_tensor( TensorType &tensor,
345346
typename TensorType::value_type *data,
346347
const index_t (&shape)[TensorType::Rank()]) {
@@ -370,8 +371,8 @@ auto make_tensor( TensorType &tensor,
370371
* If this class owns memory of data
371372
* @returns New tensor
372373
**/
373-
template <typename T, typename ShapeType,
374-
std::enable_if_t<!is_matx_descriptor_v<ShapeType> && !std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
374+
template <typename T, typename ShapeType>
375+
requires (!is_matx_descriptor<ShapeType> && !std::is_array_v<remove_cvref_t<ShapeType>>)
375376
auto make_tensor( T *data,
376377
ShapeType &&shape,
377378
bool owning = false) {
@@ -398,8 +399,8 @@ auto make_tensor( T *data,
398399
* Shape of tensor
399400
* @returns New tensor
400401
**/
401-
template <typename TensorType,
402-
std::enable_if_t<is_tensor_view_v<TensorType>, bool> = true>
402+
template <typename TensorType>
403+
requires is_tensor<TensorType>
403404
auto make_tensor( TensorType &tensor,
404405
typename TensorType::value_type *data,
405406
typename TensorType::shape_container &&shape) {
@@ -440,8 +441,8 @@ auto make_tensor( T *ptr,
440441
* Pointer to data
441442
* @returns New tensor
442443
**/
443-
template <typename TensorType,
444-
std::enable_if_t<is_tensor_view_v<TensorType>, bool> = true>
444+
template <typename TensorType>
445+
requires is_tensor<TensorType>
445446
auto make_tensor( TensorType &tensor,
446447
typename TensorType::value_type *ptr) {
447448
MATX_LOG_DEBUG("make_tensor(tensor&, ptr, 0D): ptr={}", reinterpret_cast<void*>(ptr));
@@ -462,8 +463,8 @@ auto make_tensor( TensorType &tensor,
462463
* If this class owns memory of data
463464
* @returns New tensor
464465
**/
465-
template <typename T, typename ShapeType,
466-
std::enable_if_t<!is_matx_descriptor_v<ShapeType> && !std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
466+
template <typename T, typename ShapeType>
467+
requires (!is_matx_descriptor<ShapeType> && !std::is_array_v<remove_cvref_t<ShapeType>>)
467468
auto make_tensor_p( T *const data,
468469
ShapeType &&shape,
469470
bool owning = false) {
@@ -515,9 +516,9 @@ auto make_tensor( const index_t (&shape)[RANK],
515516
* Custom allocator (PMR allocator, custom allocator pointer, etc.)
516517
* @returns New tensor
517518
**/
518-
template <typename T, typename ShapeType, typename Allocator,
519-
std::enable_if_t<!is_matx_shape_v<ShapeType> && !is_matx_descriptor_v<ShapeType> &&
520-
!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
519+
template <typename T, typename ShapeType, typename Allocator>
520+
requires (!is_matx_shape<ShapeType> && !is_matx_descriptor<ShapeType> &&
521+
!std::is_array_v<remove_cvref_t<ShapeType>>)
521522
auto make_tensor( ShapeType &&shape,
522523
Allocator&& alloc) {
523524
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
@@ -540,8 +541,8 @@ auto make_tensor( ShapeType &&shape,
540541
* @param alloc
541542
* Custom allocator (PMR allocator, custom allocator pointer, etc.)
542543
**/
543-
template <typename TensorType, typename Allocator,
544-
std::enable_if_t<is_tensor_view_v<TensorType>, bool> = true>
544+
template <typename TensorType, typename Allocator>
545+
requires is_tensor<TensorType>
545546
void make_tensor( TensorType &tensor,
546547
const index_t (&shape)[TensorType::Rank()],
547548
Allocator&& alloc) {
@@ -569,9 +570,9 @@ void make_tensor( TensorType &tensor,
569570
* @param alloc
570571
* Custom allocator (PMR allocator, custom allocator pointer, etc.)
571572
**/
572-
template <typename TensorType, typename ShapeType, typename Allocator,
573-
std::enable_if_t<is_tensor_view_v<TensorType> &&
574-
!std::is_array_v<typename remove_cvref<ShapeType>::type>, bool> = true>
573+
template <typename TensorType, typename ShapeType, typename Allocator>
574+
requires (is_tensor<TensorType> &&
575+
!std::is_array_v<remove_cvref_t<ShapeType>>)
575576
void make_tensor( TensorType &tensor,
576577
ShapeType &&shape,
577578
Allocator&& alloc) {
@@ -595,7 +596,8 @@ void make_tensor( TensorType &tensor,
595596
* If this class owns memory of data
596597
* @returns New tensor
597598
**/
598-
template <typename T, typename D, std::enable_if_t<is_matx_descriptor_v<typename remove_cvref<D>::type>, bool> = true>
599+
template <typename T, typename D>
600+
requires is_matx_descriptor<remove_cvref_t<D>>
599601
auto make_tensor( T* const data,
600602
D &&desc,
601603
bool owning = false) {
@@ -620,8 +622,8 @@ auto make_tensor( T* const data,
620622
* Tensor descriptor (tensor_desc_t)
621623
* @returns New tensor
622624
**/
623-
template <typename TensorType,
624-
std::enable_if_t<is_tensor_view_v<TensorType>, bool> = true>
625+
template <typename TensorType>
626+
requires is_tensor<TensorType>
625627
auto make_tensor( TensorType &tensor,
626628
typename TensorType::value_type* const data,
627629
typename TensorType::desc_type &&desc) {
@@ -642,7 +644,8 @@ auto make_tensor( TensorType &tensor,
642644
* @param stream cuda stream to allocate in (only applicable to async allocations)
643645
* @returns New tensor
644646
**/
645-
template <typename T, typename D, std::enable_if_t<is_matx_descriptor_v<typename remove_cvref<D>::type>, bool> = true>
647+
template <typename T, typename D>
648+
requires is_matx_descriptor<remove_cvref_t<D>>
646649
auto make_tensor( D &&desc,
647650
matxMemorySpace_t space = MATX_MANAGED_MEMORY,
648651
cudaStream_t stream = 0) {
@@ -666,8 +669,8 @@ auto make_tensor( D &&desc,
666669
* @param stream cuda stream to allocate in (only applicable to async allocations)
667670
* @returns New tensor
668671
**/
669-
template <typename TensorType,
670-
std::enable_if_t<is_tensor_view_v<TensorType> && is_matx_descriptor_v<typename TensorType::desc_type>, bool> = true>
672+
template <typename TensorType>
673+
requires (is_tensor<TensorType> && is_matx_descriptor<typename TensorType::desc_type>)
671674
auto make_tensor( TensorType &&tensor,
672675
typename TensorType::desc_type &&desc,
673676
matxMemorySpace_t space = MATX_MANAGED_MEMORY,
@@ -731,8 +734,8 @@ auto make_tensor( T *const data,
731734
* Strides of tensor
732735
* @returns New tensor
733736
**/
734-
template <typename TensorType,
735-
std::enable_if_t<is_tensor_view_v<TensorType>, bool> = true>
737+
template <typename TensorType>
738+
requires is_tensor<TensorType>
736739
auto make_tensor( TensorType &tensor,
737740
typename TensorType::value_type *const data,
738741
const index_t (&shape)[TensorType::Rank()],
@@ -771,8 +774,8 @@ auto make_static_tensor() {
771774
return tensor_t<T, desc.Rank(), decltype(desc)>{std::move(storage), std::move(desc)};
772775
}
773776

774-
template <typename TensorType,
775-
std::enable_if_t<is_tensor_view_v<TensorType>, bool> = true>
777+
template <typename TensorType>
778+
requires is_tensor<TensorType>
776779
auto make_tensor( TensorType &tensor,
777780
const DLManagedTensor dlp_tensor) {
778781
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

0 commit comments

Comments
 (0)