Skip to content

Commit 2bc41d5

Browse files
Enables sort_by_key with long double keys on the host (#7869)
- Sets use_primitive_sort to false for long double - Implements TestSortByKeyLongDouble to verify that sorting using long double keys compiles and runs on the host - typo fix: __message__ --> message Fixes #7865
1 parent ef3072b commit 2bc41d5

File tree

3 files changed

+18
-1
lines changed

3 files changed

+18
-1
lines changed

libcudacxx/include/cuda/std/__cccl/assert.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ void __assert_fail(const char* __assertion, const char* __file, unsigned int __l
8989
# if _CCCL_OS(APPLE)
9090
# define _CCCL_ASSERT_IMPL_HOST(expression, message) \
9191
_CCCL_BUILTIN_EXPECT(static_cast<bool>(expression), 1) \
92-
? (void) 0 : __assert_rtn(__func__, __FILE__, __LINE__, __message__)
92+
? (void) 0 : __assert_rtn(__func__, __FILE__, __LINE__, message)
9393
# elif _CCCL_OS(ANDROID)
9494
# define _CCCL_ASSERT_IMPL_HOST(expression, message) \
9595
_CCCL_BUILTIN_EXPECT(static_cast<bool>(expression), 1) \

thrust/testing/sort_by_key.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,16 @@ void TestSortByKeyBoolDescending()
137137
ASSERT_EQUAL(h_values, d_values);
138138
}
139139
DECLARE_UNITTEST(TestSortByKeyBoolDescending);
140+
141+
void TestSortByKeyLongDouble()
142+
{
143+
thrust::host_vector<long double> h_keys = {10.0L, 9.0L, 8.0L, 7.0L, 6.0L, 5.0L, 4.0L, 3.0L, 2.0L, 1.0L};
144+
thrust::host_vector<int> h_values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
145+
thrust::host_vector<int> h_values_expected = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
146+
147+
thrust::sort_by_key(h_keys.begin(), h_keys.end(), h_values.begin());
148+
149+
ASSERT_EQUAL(thrust::is_sorted(h_keys.begin(), h_keys.end()), true);
150+
ASSERT_EQUAL(h_values, h_values_expected);
151+
}
152+
DECLARE_UNITTEST(TestSortByKeyLongDouble);

thrust/thrust/system/detail/sequential/sort.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ namespace sort_detail
3434
template <typename KeyType, typename Compare>
3535
inline constexpr bool use_primitive_sort =
3636
::cuda::std::is_arithmetic_v<KeyType>
37+
// RadixEncoder is not specialized for long double
38+
&& !::cuda::std::is_same_v<KeyType, long double>
39+
// radix_sort_dispatcher only supports key sizes of 1, 2, 4, or 8 bytes
40+
&& (sizeof(KeyType) == 1 || sizeof(KeyType) == 2 || sizeof(KeyType) == 4 || sizeof(KeyType) == 8)
3741
&& (::cuda::std::is_same_v<Compare, ::cuda::std::less<KeyType>>
3842
|| ::cuda::std::is_same_v<Compare, ::cuda::std::greater<KeyType>>);
3943

0 commit comments

Comments
 (0)