Skip to content

Commit 59f726a

Browse files
Move Thrust to new detail::scan::dispatch*
This currently makes thrust.test.scan fail, which needs to be investigated, since it worked before in the presence of the warpspeed implementation
1 parent 203e807 commit 59f726a

File tree

3 files changed

+67
-42
lines changed

3 files changed

+67
-42
lines changed

cub/cub/device/dispatch/dispatch_scan.cuh

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ struct DeviceScanKernelSource
140140
template <typename LegacyActivePolicy>
141141
_CCCL_API constexpr auto convert_policy() -> scan_policy
142142
{
143+
// this does not convert any warpspeed policy data, which is fine because we merged warpspeed scan during the CCCL 3.4
144+
// development cycle, so it never had user exposure through the policy_hub, and we can just only support it through
145+
// the policy_selector, which CUB and Thrust.
143146
using scan_policy_t = typename LegacyActivePolicy::ScanPolicyT;
144147
return scan_policy{
145148
scan_policy_t::BLOCK_THREADS,
@@ -152,14 +155,34 @@ _CCCL_API constexpr auto convert_policy() -> scan_policy
152155
}
153156

154157
// TODO(griwes): remove in CCCL 4.0 when we drop the scan dispatcher after publishing the tuning API
155-
template <typename PolicyHub, typename InputValueT, typename OutputValueT, typename AccumT>
158+
template <typename PolicyHub>
156159
struct policy_selector_from_hub
157160
{
158-
// Called from device code during dispatch, and from host code when clang-cuda evaluates
159-
// scan_policy_selector concept checks.
160-
_CCCL_API constexpr auto operator()(::cuda::arch_id /*arch*/) const -> scan_policy
161+
private:
162+
struct extract_policy_dispatch_t
161163
{
162-
return convert_policy<typename PolicyHub::MaxPolicy::ActivePolicy>();
164+
scan_policy& policy;
165+
166+
template <typename ActivePolicyT>
167+
_CCCL_API constexpr cudaError_t Invoke()
168+
{
169+
policy = convert_policy<ActivePolicyT>();
170+
return cudaSuccess;
171+
}
172+
};
173+
174+
// Called from host (compile-time) and device code during dispatch
175+
_CCCL_API constexpr auto operator()(::cuda::arch_id arch) const -> scan_policy
176+
{
177+
NV_IF_ELSE_TARGET(NV_IS_HOST,
178+
({
179+
const int ptx_version = static_cast<int>(arch) * 10;
180+
scan_policy policy{};
181+
extract_policy_dispatch_t dispatch{policy};
182+
PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch);
183+
return policy;
184+
}),
185+
({ return convert_policy<typename PolicyHub::MaxPolicy::ActivePolicy>(); }));
163186
}
164187
};
165188
} // namespace detail::scan
@@ -208,8 +231,7 @@ template <
208231
typename PolicyHub = detail::scan::
209232
policy_hub<detail::it_value_t<InputIteratorT>, detail::it_value_t<OutputIteratorT>, AccumT, OffsetT, ScanOpT>,
210233
typename KernelSource = detail::scan::DeviceScanKernelSource<
211-
detail::scan::
212-
policy_selector_from_hub<PolicyHub, detail::it_value_t<InputIteratorT>, detail::it_value_t<OutputIteratorT>, AccumT>,
234+
detail::scan::policy_selector_from_hub<PolicyHub>,
213235
THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator_t<InputIteratorT>,
214236
THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator_t<OutputIteratorT>,
215237
ScanOpT,

thrust/thrust/system/cuda/detail/dispatch.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ THRUST_NAMESPACE_END
7777
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
7878
_THRUST_INDEX_TYPE_DISPATCH(std::int64_t, status, call, count, arguments)
7979

80+
//! @brief Always dispatches to unsigned 64 bit offset version of an algorithm
81+
# define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
82+
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
83+
_THRUST_INDEX_TYPE_DISPATCH(std::uint64_t, status, call, count, arguments)
84+
8085
//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two counts
8186
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
8287
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW2(count1, count2) \
@@ -124,6 +129,12 @@ THRUST_NAMESPACE_END
124129
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(std::int32_t, count) \
125130
_THRUST_INDEX_TYPE_DISPATCH(std::int32_t, status, call, count, arguments)
126131

132+
//! @brief Always dispatches to unsigned 32 bit offset version of an algorithm but throws if count would overflow
133+
# define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
134+
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
135+
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(std::uint32_t, count) \
136+
_THRUST_INDEX_TYPE_DISPATCH(std::uint32_t, status, call, count, arguments)
137+
127138
//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two counts
128139
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
129140
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW2(count1, count2) \
@@ -161,6 +172,16 @@ THRUST_NAMESPACE_END
161172
else \
162173
_THRUST_INDEX_TYPE_DISPATCH(std::int64_t, status, call, count, arguments)
163174

175+
//! Dispatch between unsigned 32-bit and 64-bit index_type based versions of the same algorithm implementation. This
176+
//! version assumes that callables for both branches consist of the same tokens, and is intended to be used with
177+
//! Thrust-style dispatch interfaces, that always deduce the size type from the arguments.
178+
# define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
179+
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
180+
if _THRUST_INDEX_TYPE_DISPATCH_SELECT (std::int32_t, count) \
181+
_THRUST_INDEX_TYPE_DISPATCH(std::uint32_t, status, call, count, arguments) \
182+
else \
183+
_THRUST_INDEX_TYPE_DISPATCH(std::uint64_t, status, call, count, arguments)
184+
164185
//! Dispatch between 32-bit and 64-bit index_type based versions of the same algorithm implementation. This version
165186
//! assumes that callables for both branches consist of the same tokens, and is intended to be used with Thrust-style
166187
//! dispatch interfaces, that always deduce the size type from the arguments.

thrust/thrust/system/cuda/detail/scan.h

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@ template <typename Derived, typename InputIt, typename Size, typename OutputIt,
4040
_CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
4141
thrust::cuda_cub::execution_policy<Derived>& policy, InputIt first, Size num_items, OutputIt result, ScanOp scan_op)
4242
{
43-
using AccumT = thrust::detail::it_value_t<InputIt>;
44-
using Dispatch32 = cub::DispatchScan<InputIt, OutputIt, ScanOp, cub::NullType, std::uint32_t, AccumT>;
45-
using Dispatch64 = cub::DispatchScan<InputIt, OutputIt, ScanOp, cub::NullType, std::uint64_t, AccumT>;
46-
43+
using AccumT = thrust::detail::it_value_t<InputIt>;
4744
cudaStream_t stream = thrust::cuda_cub::stream(policy);
4845
cudaError_t status;
4946

@@ -56,10 +53,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
5653
// Determine temporary storage requirements:
5754
size_t tmp_size = 0;
5855
{
59-
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
56+
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH(
6057
status,
61-
Dispatch32::Dispatch,
62-
Dispatch64::Dispatch,
58+
cub::detail::scan::dispatch_with_accum<AccumT>,
6359
num_items,
6460
(nullptr, tmp_size, first, result, scan_op, cub::NullType{}, num_items_fixed, stream));
6561
thrust::cuda_cub::throw_on_error(
@@ -72,10 +68,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
7268
{
7369
// Allocate temporary storage:
7470
thrust::detail::temporary_array<std::uint8_t, Derived> tmp{policy, tmp_size};
75-
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
71+
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH(
7672
status,
77-
Dispatch32::Dispatch,
78-
Dispatch64::Dispatch,
73+
cub::detail::scan::dispatch_with_accum<AccumT>,
7974
num_items,
8075
(tmp.data().get(), tmp_size, first, result, scan_op, cub::NullType{}, num_items_fixed, stream));
8176
thrust::cuda_cub::throw_on_error(status, "after dispatching inclusive_scan kernel");
@@ -96,15 +91,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
9691
InitValueT init,
9792
ScanOp scan_op)
9893
{
99-
using InputValueT = cub::detail::InputValue<InitValueT>;
100-
using ValueT = cub::detail::it_value_t<InputIt>;
101-
using AccumT = ::cuda::std::__accumulator_t<ScanOp, ValueT, InitValueT>;
102-
103-
using Dispatch32 =
104-
cub::DispatchScan<InputIt, OutputIt, ScanOp, InputValueT, std::uint32_t, AccumT, cub::ForceInclusive::Yes>;
105-
using Dispatch64 =
106-
cub::DispatchScan<InputIt, OutputIt, ScanOp, InputValueT, std::uint64_t, AccumT, cub::ForceInclusive::Yes>;
107-
94+
using InputValueT = cub::detail::InputValue<InitValueT>;
95+
using ValueT = cub::detail::it_value_t<InputIt>;
96+
using AccumT = ::cuda::std::__accumulator_t<ScanOp, ValueT, InitValueT>;
10897
cudaStream_t stream = thrust::cuda_cub::stream(policy);
10998
cudaError_t status;
11099

@@ -117,10 +106,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
117106
// Determine temporary storage requirements:
118107
size_t tmp_size = 0;
119108
{
120-
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
109+
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH(
121110
status,
122-
Dispatch32::Dispatch,
123-
Dispatch64::Dispatch,
111+
(cub::detail::scan::dispatch_with_accum<AccumT, cub::ForceInclusive::Yes>),
124112
num_items,
125113
(nullptr, tmp_size, first, result, scan_op, InputValueT(init), num_items_fixed, stream));
126114
thrust::cuda_cub::throw_on_error(
@@ -133,10 +121,9 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
133121
{
134122
// Allocate temporary storage:
135123
thrust::detail::temporary_array<std::uint8_t, Derived> tmp{policy, tmp_size};
136-
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
124+
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH(
137125
status,
138-
Dispatch32::Dispatch,
139-
Dispatch64::Dispatch,
126+
(cub::detail::scan::dispatch_with_accum<AccumT, cub::ForceInclusive::Yes>),
140127
num_items,
141128
(tmp.data().get(), tmp_size, first, result, scan_op, InputValueT(init), num_items_fixed, stream));
142129
thrust::cuda_cub::throw_on_error(status, "after dispatching inclusive_scan kernel");
@@ -157,10 +144,7 @@ _CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl(
157144
InitValueT init,
158145
ScanOp scan_op)
159146
{
160-
using InputValueT = cub::detail::InputValue<InitValueT>;
161-
using Dispatch32 = cub::DispatchScan<InputIt, OutputIt, ScanOp, InputValueT, std::uint32_t, InitValueT>;
162-
using Dispatch64 = cub::DispatchScan<InputIt, OutputIt, ScanOp, InputValueT, std::uint64_t, InitValueT>;
163-
147+
using InputValueT = cub::detail::InputValue<InitValueT>;
164148
cudaStream_t stream = thrust::cuda_cub::stream(policy);
165149
cudaError_t status;
166150

@@ -173,10 +157,9 @@ _CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl(
173157
// Determine temporary storage requirements:
174158
size_t tmp_size = 0;
175159
{
176-
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
160+
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH(
177161
status,
178-
Dispatch32::Dispatch,
179-
Dispatch64::Dispatch,
162+
cub::detail::scan::dispatch_with_accum<InitValueT>,
180163
num_items,
181164
(nullptr, tmp_size, first, result, scan_op, InputValueT(init), num_items_fixed, stream));
182165
thrust::cuda_cub::throw_on_error(
@@ -189,10 +172,9 @@ _CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl(
189172
{
190173
// Allocate temporary storage:
191174
thrust::detail::temporary_array<std::uint8_t, Derived> tmp{policy, tmp_size};
192-
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(
175+
THRUST_UNSIGNED_INDEX_TYPE_DISPATCH(
193176
status,
194-
Dispatch32::Dispatch,
195-
Dispatch64::Dispatch,
177+
cub::detail::scan::dispatch_with_accum<InitValueT>,
196178
num_items,
197179
(tmp.data().get(), tmp_size, first, result, scan_op, InputValueT(init), num_items_fixed, stream));
198180
thrust::cuda_cub::throw_on_error(status, "after dispatching exclusive_scan kernel");

0 commit comments

Comments
 (0)