Skip to content

Commit afb1aeb

Browse files
authored
Fix argmin/argmax comparators with CUB tuple accumulators (#1096)
1 parent f32b0cd commit afb1aeb

File tree

1 file changed

+41
-4
lines changed
  • include/matx/transforms

1 file changed

+41
-4
lines changed

include/matx/transforms/cub.h

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "matx/core/tensor.h"
4747
#include "matx/core/iterator.h"
4848
#include "matx/core/operator_utils.h"
49+
#include "matx/core/type_utils_both.h"
4950
#include "matx/transforms/cccl_iterators.h"
5051

5152

@@ -1078,17 +1079,53 @@ inline void ExecSort(OutputTensor &a_out,
10781079
#ifdef __CUDACC__
10791080
struct CustomArgMaxCmp
10801081
{
1082+
private:
10811083
template <typename T>
1082-
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ T operator()(const T &a, const T &b) const {
1083-
return thrust::get<1>(a) < thrust::get<1>(b) ? b : a;
1084+
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto ToTuple(const T &value) const
1085+
{
1086+
if constexpr (is_tuple_c<T>)
1087+
{
1088+
return value;
1089+
}
1090+
else
1091+
{
1092+
return cuda::std::make_tuple(value.first, value.second);
1093+
}
1094+
}
1095+
1096+
public:
1097+
template <typename InitT, typename InputT>
1098+
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto operator()(const InitT &a, const InputT &b) const
1099+
{
1100+
auto at = ToTuple(a);
1101+
auto bt = ToTuple(b);
1102+
return thrust::get<1>(at) < thrust::get<1>(bt) ? bt : at;
10841103
}
10851104
};
10861105

10871106
struct CustomArgMinCmp
10881107
{
1108+
private:
10891109
template <typename T>
1090-
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ T operator()(const T &a, const T &b) const {
1091-
return thrust::get<1>(a) >= thrust::get<1>(b) ? b : a;
1110+
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto ToTuple(const T &value) const
1111+
{
1112+
if constexpr (is_tuple_c<T>)
1113+
{
1114+
return value;
1115+
}
1116+
else
1117+
{
1118+
return cuda::std::make_tuple(value.first, value.second);
1119+
}
1120+
}
1121+
1122+
public:
1123+
template <typename InitT, typename InputT>
1124+
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto operator()(const InitT &a, const InputT &b) const
1125+
{
1126+
auto at = ToTuple(a);
1127+
auto bt = ToTuple(b);
1128+
return thrust::get<1>(at) >= thrust::get<1>(bt) ? bt : at;
10921129
}
10931130
};
10941131

0 commit comments

Comments
 (0)