|
46 | 46 | #include "matx/core/tensor.h" |
47 | 47 | #include "matx/core/iterator.h" |
48 | 48 | #include "matx/core/operator_utils.h" |
| 49 | +#include "matx/core/type_utils_both.h" |
49 | 50 | #include "matx/transforms/cccl_iterators.h" |
50 | 51 |
|
51 | 52 |
|
@@ -1078,17 +1079,53 @@ inline void ExecSort(OutputTensor &a_out, |
1078 | 1079 | #ifdef __CUDACC__ |
1079 | 1080 | struct CustomArgMaxCmp |
1080 | 1081 | { |
| 1082 | +private: |
1081 | 1083 | template <typename T> |
1082 | | - __MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ T operator()(const T &a, const T &b) const { |
1083 | | - return thrust::get<1>(a) < thrust::get<1>(b) ? b : a; |
| 1084 | + __MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto ToTuple(const T &value) const |
| 1085 | + { |
| 1086 | + if constexpr (is_tuple_c<T>) |
| 1087 | + { |
| 1088 | + return value; |
| 1089 | + } |
| 1090 | + else |
| 1091 | + { |
| 1092 | + return cuda::std::make_tuple(value.first, value.second); |
| 1093 | + } |
| 1094 | + } |
| 1095 | + |
| 1096 | +public: |
| 1097 | + template <typename InitT, typename InputT> |
| 1098 | + __MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto operator()(const InitT &a, const InputT &b) const |
| 1099 | + { |
| 1100 | + auto at = ToTuple(a); |
| 1101 | + auto bt = ToTuple(b); |
| 1102 | + return thrust::get<1>(at) < thrust::get<1>(bt) ? bt : at; |
1084 | 1103 | } |
1085 | 1104 | }; |
1086 | 1105 |
|
1087 | 1106 | struct CustomArgMinCmp |
1088 | 1107 | { |
| 1108 | +private: |
1089 | 1109 | template <typename T> |
1090 | | - __MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ T operator()(const T &a, const T &b) const { |
1091 | | - return thrust::get<1>(a) >= thrust::get<1>(b) ? b : a; |
| 1110 | + __MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto ToTuple(const T &value) const |
| 1111 | + { |
| 1112 | + if constexpr (is_tuple_c<T>) |
| 1113 | + { |
| 1114 | + return value; |
| 1115 | + } |
| 1116 | + else |
| 1117 | + { |
| 1118 | + return cuda::std::make_tuple(value.first, value.second); |
| 1119 | + } |
| 1120 | + } |
| 1121 | + |
| 1122 | +public: |
| 1123 | + template <typename InitT, typename InputT> |
| 1124 | + __MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto operator()(const InitT &a, const InputT &b) const |
| 1125 | + { |
| 1126 | + auto at = ToTuple(a); |
| 1127 | + auto bt = ToTuple(b); |
| 1128 | + return thrust::get<1>(at) >= thrust::get<1>(bt) ? bt : at; |
1092 | 1129 | } |
1093 | 1130 | }; |
1094 | 1131 |
|
|
0 commit comments