Skip to content

Commit 04f83db

Browse files
authored
clean up snr.cuh and a bunch of nvexec algorithms (NVIDIA#1926)
* clean up snr.cuh and a bunch of nvexec algorithms
1 parent b6f8a0e commit 04f83db

File tree

16 files changed: 336 additions and 342 deletions.

examples/nvexec/maxwell/snr.cuh

Lines changed: 142 additions & 241 deletions
Large diffs are not rendered by default.

include/nvexec/stream/bulk.cuh

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -401,28 +401,35 @@ namespace nv::execution::_strm
401401
template <>
402402
struct transform_sender_for<STDEXEC::bulk_t>
403403
{
404-
template <class Env, class Data, stream_completing_sender<Env> Sender>
404+
template <class Env, class Data, class Sender>
405405
auto operator()(Env const & env, __ignore, Data data, Sender&& sndr) const
406406
{
407-
auto [policy, shape, fun] = static_cast<Data&&>(data);
408-
using shape_t = decltype(shape);
409-
using fun_t = decltype(fun);
410-
auto sched = get_completion_scheduler<set_value_t>(get_env(sndr), env);
411-
if constexpr (__std::same_as<decltype(sched), stream_scheduler>)
407+
if constexpr (stream_completing_sender<Sender, Env>)
412408
{
413-
// Use the bulk sender for a single GPU
414-
using _sender_t = bulk_sender<__decay_t<Sender>, shape_t, fun_t>;
415-
return _sender_t{{}, static_cast<Sender&&>(sndr), shape, static_cast<fun_t&&>(fun)};
409+
auto [policy, shape, fun] = static_cast<Data&&>(data);
410+
using shape_t = decltype(shape);
411+
using fun_t = decltype(fun);
412+
auto sched = get_completion_scheduler<set_value_t>(get_env(sndr), env);
413+
if constexpr (__std::same_as<decltype(sched), stream_scheduler>)
414+
{
415+
// Use the bulk sender for a single GPU
416+
using _sender_t = bulk_sender<__decay_t<Sender>, shape_t, fun_t>;
417+
return _sender_t{{}, static_cast<Sender&&>(sndr), shape, static_cast<fun_t&&>(fun)};
418+
}
419+
else
420+
{
421+
// Use the bulk sender for multiple GPUs
422+
using _sender_t = multi_gpu_bulk_sender<__decay_t<Sender>, shape_t, fun_t>;
423+
return _sender_t{{},
424+
sched.num_devices_,
425+
static_cast<Sender&&>(sndr),
426+
shape,
427+
static_cast<fun_t&&>(fun)};
428+
}
416429
}
417430
else
418431
{
419-
// Use the bulk sender for multiple GPUs
420-
using _sender_t = multi_gpu_bulk_sender<__decay_t<Sender>, shape_t, fun_t>;
421-
return _sender_t{{},
422-
sched.num_devices_,
423-
static_cast<Sender&&>(sndr),
424-
shape,
425-
static_cast<fun_t&&>(fun)};
432+
return _strm::_no_stream_scheduler_in_env<STDEXEC::bulk_t, Sender, Env>();
426433
}
427434
}
428435
};

include/nvexec/stream/common.cuh

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ namespace nv::execution
5454
};
5555

5656
#if defined(__clang__) && defined(__CUDA__) && !defined(STDEXEC_CLANG_TIDY_INVOKED)
57-
__host__ inline auto get_device_type() noexcept -> device_type
57+
inline __host__ auto get_device_type() noexcept -> device_type
5858
{
5959
return device_type::host;
6060
}
@@ -64,7 +64,7 @@ namespace nv::execution
6464
return device_type::device;
6565
}
6666
#else
67-
__host__ __device__ inline auto get_device_type() noexcept -> device_type
67+
inline __host__ __device__ auto get_device_type() noexcept -> device_type
6868
{
6969
NV_IF_TARGET(NV_IS_HOST, (return device_type::host;), (return device_type::device;));
7070
}
@@ -75,6 +75,12 @@ namespace nv::execution
7575
return get_device_type() == device_type::device;
7676
}
7777

78+
struct stream_context;
79+
struct stream_domain;
80+
81+
struct CANNOT_DISPATCH_THIS_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER;
82+
struct BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT;
83+
7884
namespace _strm
7985
{
8086
// Used by stream_domain to late-customize senders for execution
@@ -84,30 +90,64 @@ namespace nv::execution
8490

8591
template <class Tag>
8692
struct apply_sender_for;
93+
94+
struct context;
95+
96+
template <class Scheduler, class Env>
97+
concept gpu_stream_scheduler =
98+
scheduler<Scheduler>
99+
&& __std::derived_from<__result_of<get_completion_domain<set_value_t>, Scheduler, Env>,
100+
stream_domain>
101+
&& requires(Scheduler sched) {
102+
{ sched.ctx_ } -> __decays_to<context>;
103+
};
104+
105+
template <class Sender, class Env>
106+
concept stream_completing_sender =
107+
sender<Sender>
108+
&& gpu_stream_scheduler<
109+
__result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, Env>,
110+
Env>;
111+
112+
template <class Sender, class Env>
113+
concept has_stream_transform =
114+
STDEXEC::__callable<STDEXEC::__structured_apply_t,
115+
transform_sender_for<STDEXEC::tag_of_t<Sender>>,
116+
Sender,
117+
Env const &>;
118+
119+
template <class Sender, class Env>
120+
concept has_nothrow_stream_transform =
121+
STDEXEC::__nothrow_callable<STDEXEC::__structured_apply_t,
122+
transform_sender_for<STDEXEC::tag_of_t<Sender>>,
123+
Sender,
124+
Env const &>;
125+
126+
template <class Tag, class Sender, class Env>
127+
auto _no_stream_scheduler_in_env() noexcept
128+
{
129+
using namespace STDEXEC;
130+
return __not_a_sender<_WHAT_(CANNOT_DISPATCH_THIS_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER),
131+
_WHY_(BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT),
132+
_WHERE_(_IN_ALGORITHM_, Tag),
133+
_WITH_PRETTY_SENDER_<Sender>,
134+
_WITH_ENVIRONMENT_(Env)>{};
135+
}
87136
} // namespace _strm
88137
} // namespace nv::execution
89138

90139
namespace nvexec = nv::execution;
91140

92141
namespace nv::execution
93142
{
94-
struct stream_context;
95-
96143
// The stream_domain is how the stream scheduler customizes the sender algorithms. All of the
97144
// algorithms use the current scheduler's domain to transform senders before starting them.
98145
struct stream_domain : STDEXEC::default_domain
99146
{
100147
template <::exec::sender_for Sender, class Tag = STDEXEC::tag_of_t<Sender>, class Env>
101-
requires STDEXEC::__callable<STDEXEC::__structured_apply_t,
102-
_strm::transform_sender_for<Tag>,
103-
Sender,
104-
Env const &>
148+
requires _strm::has_stream_transform<Sender, Env>
105149
static auto transform_sender(STDEXEC::set_value_t, Sender&& sndr, Env const & env)
106-
noexcept(STDEXEC::__nothrow_callable<STDEXEC::__structured_apply_t,
107-
_strm::transform_sender_for<Tag>,
108-
Sender,
109-
Env const &>)
110-
150+
noexcept(_strm::has_nothrow_stream_transform<Sender, Env>)
111151
{
112152
return STDEXEC::__structured_apply(_strm::transform_sender_for<Tag>{},
113153
static_cast<Sender&&>(sndr),
@@ -278,15 +318,6 @@ namespace nv::execution
278318
template <class Sender, class Shape, class Fn>
279319
struct multi_gpu_bulk_sender;
280320

281-
template <class Scheduler, class Env>
282-
concept gpu_stream_scheduler =
283-
scheduler<Scheduler>
284-
&& __std::derived_from<__result_of<get_completion_domain<set_value_t>, Scheduler, Env>,
285-
stream_domain>
286-
&& requires(Scheduler sched) {
287-
{ sched.ctx_ } -> __decays_to<context>;
288-
};
289-
290321
struct stream_sender_base
291322
{
292323
using sender_concept = STDEXEC::sender_t;
@@ -907,13 +938,6 @@ namespace nv::execution
907938
ctx);
908939
}
909940

910-
template <class Sender, class Env>
911-
concept stream_completing_sender =
912-
sender<Sender>
913-
&& gpu_stream_scheduler<
914-
__result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, Env>,
915-
Env>;
916-
917941
template <class InnerReceiverProvider, class OuterReceiver>
918942
using inner_receiver_t = __call_result_t<InnerReceiverProvider, opstate_base<OuterReceiver>&>;
919943

@@ -957,8 +981,10 @@ namespace nv::execution
957981
inline constexpr _strm::get_stream_t get_stream{};
958982

959983
#if CUDART_VERSION >= 13'00'0
960-
__host__ inline cudaError_t
961-
cudaMemPrefetchAsync(const void* dev_ptr, size_t count, int dst_device, cudaStream_t stream = 0)
984+
inline __host__ cudaError_t cudaMemPrefetchAsync(void const * dev_ptr,
985+
size_t count,
986+
int dst_device,
987+
cudaStream_t stream = 0)
962988
{
963989
return ::cudaMemPrefetchAsync(dev_ptr,
964990
count,

include/nvexec/stream/ensure_started.cuh

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -409,11 +409,18 @@ namespace nv::execution::_strm
409409
template <class Sender>
410410
using _sender_t = ensure_started_sender<__decay_t<Sender>>;
411411

412-
template <class Env, stream_completing_sender<Env> Sender>
413-
auto operator()(Env const & env, __ignore, __ignore, Sender&& sndr) const -> _sender_t<Sender>
412+
template <class Env, class Sender>
413+
auto operator()(Env const & env, __ignore, __ignore, Sender&& sndr) const
414414
{
415-
auto sched = get_completion_scheduler<set_value_t>(get_env(sndr), env);
416-
return _sender_t<Sender>{sched.ctx_, static_cast<Sender&&>(sndr)};
415+
if constexpr (stream_completing_sender<Sender, Env>)
416+
{
417+
auto sched = get_completion_scheduler<set_value_t>(get_env(sndr), env);
418+
return _sender_t<Sender>{sched.ctx_, static_cast<Sender&&>(sndr)};
419+
}
420+
else
421+
{
422+
return _strm::_no_stream_scheduler_in_env<exec::ensure_started_t, Sender, Env>();
423+
}
417424
}
418425
};
419426
} // namespace nv::execution::_strm

include/nvexec/stream/let_xxx.cuh

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,9 @@ namespace nv::execution::_strm
8080
using _sch_env_t = __result_of<_mk_sch_env, CvSender, Receiver, SetTag>;
8181

8282
inline constexpr auto _mk_env2 =
83-
[]<class SchEnv, class Receiver>([[maybe_unused]]
84-
SchEnv const & sch_env,
83+
[]<class SchEnv, class Receiver>(SchEnv const & sch_env,
8584
_strm::opstate_base<Receiver> const & opstate)
8685
{
87-
//return opstate.make_env();
8886
return __env::__join(sch_env, opstate.make_env());
8987
};
9088

@@ -210,22 +208,25 @@ namespace nv::execution::_strm
210208
using _mk_opstate_variant_fn = __mtransform<__muncurry<_mk_opstate_fn_t>, __qq<__variant>>;
211209
using _opstate_variant_t = __mapply<_mk_opstate_variant_fn, _result_tuples_t>;
212210
using _propagate_receiver_t = _let::_propagate_receiver_t<CvSender, Receiver, Fun, SetTag>;
211+
using _sch_t =
212+
__result_of<get_completion_scheduler<set_value_t>, env_of_t<CvSender>, env_of_t<Receiver>>;
213213

214214
explicit _opstate(CvSender&& sndr, Receiver rcvr, Fun fun)
215215
: _opstate(static_cast<CvSender&&>(sndr),
216216
static_cast<Receiver&&>(rcvr),
217217
static_cast<Fun&&>(fun),
218+
get_completion_scheduler<set_value_t>(get_env(sndr), get_env(rcvr)),
218219
_mk_sch_env(sndr, rcvr, SetTag{}))
219220
{}
220221

221-
explicit _opstate(CvSender&& sndr, Receiver&& rcvr, Fun fun, _env2_t env2)
222+
explicit _opstate(CvSender&& sndr, Receiver&& rcvr, Fun fun, _sch_t sch, _env2_t env2)
222223
: _opstate_base_t<CvSender, Receiver, Fun, SetTag>(
223224
static_cast<CvSender&&>(sndr),
224225
static_cast<Receiver&&>(rcvr),
225226
[this](__ignore) noexcept { return _receiver_t{{}, this}; },
226-
get_completion_scheduler<set_value_t>(get_env(sndr), get_env(rcvr)).ctx_)
227+
sch.ctx_)
227228
, fun_(static_cast<Fun&&>(fun))
228-
, env2_(env2)
229+
, env2_(static_cast<_env2_t&&>(env2))
229230
{}
230231

231232
STDEXEC_IMMOVABLE(_opstate);
@@ -308,10 +309,18 @@ namespace nv::execution::_strm
308309
template <class SetTag>
309310
struct _transform_let_sender
310311
{
311-
template <class Env, class Fun, stream_completing_sender<Env> Sender>
312+
template <class Env, class Fun, class Sender>
312313
auto operator()(Env const &, __ignore, Fun fn, Sender&& sndr) const
313314
{
314-
return let_sender{static_cast<Sender&&>(sndr), static_cast<Fun&&>(fn), SetTag{}};
315+
if constexpr (stream_completing_sender<Sender, Env>)
316+
{
317+
return let_sender{static_cast<Sender&&>(sndr), static_cast<Fun&&>(fn), SetTag{}};
318+
}
319+
else
320+
{
321+
using _let_t = decltype(STDEXEC::__let::__let_from_set<SetTag>);
322+
return _strm::_no_stream_scheduler_in_env<_let_t, Sender, Env>();
323+
}
315324
}
316325
};
317326

include/nvexec/stream/repeat_n.cuh

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,16 @@ namespace nv::execution::_strm
8383
, sched_(std::move(sched))
8484
, count_(count)
8585
{
86-
_connect();
86+
if (count_ != 0)
87+
{
88+
_connect();
89+
}
8790
}
8891

89-
void _connect()
92+
auto& _connect()
9093
{
9194
inner_opstate_.__emplace_from(STDEXEC::connect,
9295
exec::sequence(STDEXEC::schedule(sched_), sndr_),
93-
//STDEXEC::on(sched_, sndr_),
9496
receiver{*this});
9597
}
9698

@@ -114,8 +116,7 @@ namespace nv::execution::_strm
114116
}
115117
else
116118
{
117-
_connect();
118-
STDEXEC::start(*inner_opstate_);
119+
STDEXEC::start(_connect());
119120
}
120121
}
121122
else
@@ -167,6 +168,11 @@ namespace nv::execution::_strm
167168
STDEXEC::set_error_t(cudaError_t)>();
168169
}
169170

171+
explicit sender(CvSender&& sndr, std::size_t count)
172+
: sndr_(static_cast<CvSender&&>(sndr))
173+
, count_(count)
174+
{}
175+
170176
template <STDEXEC::receiver Receiver>
171177
auto connect(Receiver rcvr) && -> repeat_n::opstate<CvSender, Receiver>
172178
{
@@ -186,6 +192,7 @@ namespace nv::execution::_strm
186192
return STDEXEC::get_env(sndr_);
187193
}
188194

195+
private:
189196
CvSender sndr_; // could be a value or a reference
190197
std::size_t count_;
191198
};

include/nvexec/stream/schedule_from.cuh

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@
2727

2828
namespace nv::execution
2929
{
30-
struct CANNOT_DISPATCH_THE_SCHEDULE_FROM_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER;
31-
struct BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT;
32-
struct ADD_A_CONTINUES_ON_TRANSITION_TO_THE_CUDA_STREAM_SCHEDULER_BEFORE_THE_SCHEDULE_FROM_ALGORITHM;
33-
3430
namespace _strm
3531
{
3632
namespace _schfr
@@ -188,15 +184,7 @@ namespace nv::execution
188184
}
189185
else
190186
{
191-
return STDEXEC::__not_a_sender<
192-
STDEXEC::_WHAT_(
193-
CANNOT_DISPATCH_THE_SCHEDULE_FROM_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER),
194-
STDEXEC::_WHY_(BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT),
195-
STDEXEC::_WHERE_(STDEXEC::_IN_ALGORITHM_, STDEXEC::schedule_from_t),
196-
// STDEXEC::_TO_FIX_THIS_ERROR_(
197-
// ADD_A_CONTINUES_ON_TRANSITION_TO_THE_CUDA_STREAM_SCHEDULER_BEFORE_THE_SCHEDULE_FROM_ALGORITHM),
198-
STDEXEC::_WITH_PRETTY_SENDER_<Sender>,
199-
STDEXEC::_WITH_ENVIRONMENT_(Env)>{};
187+
return _strm::_no_stream_scheduler_in_env<STDEXEC::schedule_from_t, Sender, Env>();
200188
}
201189
}
202190
};

include/nvexec/stream/split.cuh

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -391,11 +391,18 @@ namespace nv::execution::_strm
391391
template <class Sender>
392392
using _sender_t = split_sender<__decay_t<Sender>>;
393393

394-
template <class Env, stream_completing_sender<Env> Sender>
395-
auto operator()(Env const & env, __ignore, __ignore, Sender&& sndr) const -> _sender_t<Sender>
394+
template <class Env, class Sender>
395+
auto operator()(Env const & env, __ignore, __ignore, Sender&& sndr) const
396396
{
397-
auto sched = get_completion_scheduler<set_value_t>(get_env(sndr), env);
398-
return _sender_t<Sender>{sched.ctx_, static_cast<Sender&&>(sndr)};
397+
if constexpr (stream_completing_sender<Sender, Env>)
398+
{
399+
auto sched = get_completion_scheduler<set_value_t>(get_env(sndr), env);
400+
return _sender_t<Sender>{sched.ctx_, static_cast<Sender&&>(sndr)};
401+
}
402+
else
403+
{
404+
return _strm::_no_stream_scheduler_in_env<exec::split_t, _sender_t<Sender>, Env>();
405+
}
399406
}
400407
};
401408
} // namespace nv::execution::_strm

0 commit comments

Comments
 (0)