Skip to content

Commit aee092f

Browse files
authored
Merge pull request #1739 from ericniebler/improve-nvexec-schedule-from-diagnostic
improve diagnostic from nvexec's `schedule_from` alg when no stream scheduler is found
2 parents fb34ccb + 82d78da commit aee092f

File tree

1 file changed

+163
-144
lines changed

1 file changed

+163
-144
lines changed

include/nvexec/stream/schedule_from.cuh

Lines changed: 163 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -25,172 +25,191 @@
2525

2626
#include "common.cuh"
2727

28-
namespace nvexec::_strm {
29-
30-
namespace _schfr {
31-
template <class CvrefSenderId, class ReceiverId>
32-
struct operation_state_t {
33-
using Sender = __cvref_t<CvrefSenderId>;
34-
using Receiver = stdexec::__t<ReceiverId>;
35-
using Env = operation_state_base_t<ReceiverId>::env_t;
36-
37-
struct __t : operation_state_base_t<ReceiverId> {
38-
struct receiver_t;
39-
using __id = operation_state_t;
40-
using variant_t = variant_storage_t<Sender, Env>;
41-
using task_t = continuation_task_t<receiver_t, variant_t>;
42-
using enqueue_receiver =
43-
stdexec::__t<stream_enqueue_receiver<stdexec::__cvref_id<Env>, variant_t>>;
44-
using inner_op_state_t = connect_result_t<Sender, enqueue_receiver>;
45-
46-
struct receiver_t {
47-
using receiver_concept = stdexec::receiver_t;
48-
49-
template <class... _Args>
50-
void set_value(_Args&&... __args) noexcept {
51-
stdexec::set_value(std::move(op_state_.rcvr_), static_cast<_Args&&>(__args)...);
28+
namespace nvexec {
29+
struct CANNOT_DISPATCH_THE_SCHEDULE_FROM_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER;
30+
struct BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT;
31+
struct ADD_A_CONTINUES_ON_TRANSITION_TO_THE_CUDA_STREAM_SCHEDULER_BEFORE_THE_SCHEDULE_FROM_ALGORITHM;
32+
33+
namespace _strm {
34+
35+
namespace _schfr {
36+
template <class CvrefSenderId, class ReceiverId>
37+
struct operation_state_t {
38+
using Sender = __cvref_t<CvrefSenderId>;
39+
using Receiver = stdexec::__t<ReceiverId>;
40+
using Env = operation_state_base_t<ReceiverId>::env_t;
41+
42+
struct __t : operation_state_base_t<ReceiverId> {
43+
struct receiver_t;
44+
using __id = operation_state_t;
45+
using variant_t = variant_storage_t<Sender, Env>;
46+
using task_t = continuation_task_t<receiver_t, variant_t>;
47+
using enqueue_receiver =
48+
stdexec::__t<stream_enqueue_receiver<stdexec::__cvref_id<Env>, variant_t>>;
49+
using inner_op_state_t = connect_result_t<Sender, enqueue_receiver>;
50+
51+
struct receiver_t {
52+
using receiver_concept = stdexec::receiver_t;
53+
54+
template <class... _Args>
55+
void set_value(_Args&&... __args) noexcept {
56+
stdexec::set_value(std::move(op_state_.rcvr_), static_cast<_Args&&>(__args)...);
57+
}
58+
59+
template <class _Error>
60+
void set_error(_Error&& __err) noexcept {
61+
stdexec::set_error(std::move(op_state_.rcvr_), static_cast<_Error&&>(__err));
62+
}
63+
64+
void set_stopped() noexcept {
65+
stdexec::set_stopped(std::move(op_state_.rcvr_));
66+
}
67+
68+
[[nodiscard]]
69+
auto get_env() const noexcept -> Env {
70+
return op_state_.make_env();
71+
}
72+
73+
__t& op_state_;
74+
};
75+
76+
__t(Sender&& sender, Receiver&& rcvr, context_state_t context_state)
77+
: operation_state_base_t<ReceiverId>(static_cast<Receiver&&>(rcvr), context_state)
78+
, context_state_(context_state)
79+
, storage_(make_host<variant_t>(this->status_, context_state.pinned_resource_))
80+
, task_(
81+
make_host<task_t>(
82+
this->status_,
83+
context_state.pinned_resource_,
84+
receiver_t{*this},
85+
storage_.get(),
86+
this->get_stream(),
87+
context_state.pinned_resource_)
88+
.release())
89+
, env_(make_host(this->status_, context_state_.pinned_resource_, this->make_env()))
90+
, inner_op_{connect(
91+
static_cast<Sender&&>(sender),
92+
enqueue_receiver{
93+
env_.get(),
94+
storage_.get(),
95+
task_,
96+
context_state_.hub_->producer()})} {
97+
if (this->status_ == cudaSuccess) {
98+
this->status_ = task_->status_;
99+
}
52100
}
53101

54-
template <class _Error>
55-
void set_error(_Error&& __err) noexcept {
56-
stdexec::set_error(std::move(op_state_.rcvr_), static_cast<_Error&&>(__err));
57-
}
102+
STDEXEC_IMMOVABLE(__t);
58103

59-
void set_stopped() noexcept {
60-
stdexec::set_stopped(std::move(op_state_.rcvr_));
61-
}
104+
void start() & noexcept {
105+
started_.test_and_set(::cuda::std::memory_order::relaxed);
62106

63-
[[nodiscard]]
64-
auto get_env() const noexcept -> Env {
65-
return op_state_.make_env();
66-
}
107+
if (status_ != cudaSuccess) {
108+
// Couldn't allocate memory for operation state, complete with error
109+
stdexec::set_error(std::move(this->rcvr_), std::move(status_));
110+
return;
111+
}
67112

68-
__t& op_state_;
69-
};
70-
71-
__t(Sender&& sender, Receiver&& rcvr, context_state_t context_state)
72-
: operation_state_base_t<ReceiverId>(static_cast<Receiver&&>(rcvr), context_state)
73-
, context_state_(context_state)
74-
, storage_(make_host<variant_t>(this->status_, context_state.pinned_resource_))
75-
, task_(
76-
make_host<task_t>(
77-
this->status_,
78-
context_state.pinned_resource_,
79-
receiver_t{*this},
80-
storage_.get(),
81-
this->get_stream(),
82-
context_state.pinned_resource_)
83-
.release())
84-
, env_(make_host(this->status_, context_state_.pinned_resource_, this->make_env()))
85-
, inner_op_{connect(
86-
static_cast<Sender&&>(sender),
87-
enqueue_receiver{
88-
env_.get(),
89-
storage_.get(),
90-
task_,
91-
context_state_.hub_->producer()})} {
92-
if (this->status_ == cudaSuccess) {
93-
this->status_ = task_->status_;
113+
stdexec::start(inner_op_);
94114
}
95-
}
96-
97-
STDEXEC_IMMOVABLE(__t);
98115

99-
void start() & noexcept {
100-
started_.test_and_set(::cuda::std::memory_order::relaxed);
116+
cudaError_t status_{cudaSuccess};
117+
context_state_t context_state_;
118+
host_ptr<variant_t> storage_;
119+
task_t* task_;
120+
::cuda::std::atomic_flag started_{};
121+
host_ptr<__decay_t<Env>> env_{};
122+
inner_op_state_t inner_op_;
123+
};
124+
};
125+
} // namespace _schfr
101126

102-
if (status_ != cudaSuccess) {
103-
// Couldn't allocate memory for operation state, complete with error
104-
stdexec::set_error(std::move(this->rcvr_), std::move(status_));
105-
return;
106-
}
127+
template <class SenderId>
128+
struct schedule_from_sender_t {
129+
using Sender = stdexec::__t<SenderId>;
107130

108-
stdexec::start(inner_op_);
109-
}
110-
111-
cudaError_t status_{cudaSuccess};
112-
context_state_t context_state_;
113-
host_ptr<variant_t> storage_;
114-
task_t* task_;
115-
::cuda::std::atomic_flag started_{};
116-
host_ptr<__decay_t<Env>> env_{};
117-
inner_op_state_t inner_op_;
118-
};
119-
};
120-
} // namespace _schfr
131+
struct __t : stream_sender_base {
132+
using __id = schedule_from_sender_t;
121133

122-
template <class SenderId>
123-
struct schedule_from_sender_t {
124-
using Sender = stdexec::__t<SenderId>;
134+
template <class Self, class Receiver>
135+
using op_state_th =
136+
stdexec::__t<_schfr::operation_state_t<__cvref_id<Self, Sender>, stdexec::__id<Receiver>>>;
125137

126-
struct __t : stream_sender_base {
127-
using __id = schedule_from_sender_t;
138+
template <class... Ts>
139+
using _set_value_t = completion_signatures<set_value_t(stdexec::__decay_t<Ts>...)>;
128140

129-
template <class Self, class Receiver>
130-
using op_state_th =
131-
stdexec::__t<_schfr::operation_state_t<__cvref_id<Self, Sender>, stdexec::__id<Receiver>>>;
141+
template <class Ty>
142+
using _set_error_t = completion_signatures<set_error_t(stdexec::__decay_t<Ty>)>;
132143

133-
template <class... Ts>
134-
using _set_value_t = completion_signatures<set_value_t(stdexec::__decay_t<Ts>...)>;
144+
template <class Self, class... Env>
145+
using _completion_signatures_t = transform_completion_signatures<
146+
__completion_signatures_of_t<__copy_cvref_t<Self, Sender>, Env...>,
147+
completion_signatures<set_stopped_t(), set_error_t(cudaError_t)>,
148+
_set_value_t,
149+
_set_error_t
150+
>;
135151

136-
template <class Ty>
137-
using _set_error_t = completion_signatures<set_error_t(stdexec::__decay_t<Ty>)>;
152+
__t(context_state_t context_state, Sender sndr)
153+
: context_state_(context_state)
154+
, sndr_{static_cast<Sender&&>(sndr)} {
155+
}
138156

139-
template <class Self, class... Env>
140-
using _completion_signatures_t = transform_completion_signatures<
141-
__completion_signatures_of_t<__copy_cvref_t<Self, Sender>, Env...>,
142-
completion_signatures<set_stopped_t(), set_error_t(cudaError_t)>,
143-
_set_value_t,
144-
_set_error_t
145-
>;
157+
template <__decays_to<__t> Self, receiver Receiver>
158+
requires receiver_of<Receiver, _completion_signatures_t<Self, env_of_t<Receiver>>>
159+
STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver rcvr)
160+
-> op_state_th<Self, Receiver> {
161+
return op_state_th<Self, Receiver>{
162+
static_cast<Self&&>(self).sndr_, static_cast<Receiver&&>(rcvr), self.context_state_};
163+
}
164+
STDEXEC_EXPLICIT_THIS_END(connect)
146165

147-
__t(context_state_t context_state, Sender sndr)
148-
: context_state_(context_state)
149-
, sndr_{static_cast<Sender&&>(sndr)} {
150-
}
166+
template <__decays_to<__t> Self, class... Env>
167+
STDEXEC_EXPLICIT_THIS_BEGIN(auto get_completion_signatures)(this Self&&, Env&&...)
168+
-> _completion_signatures_t<Self, Env...> {
169+
return {};
170+
}
171+
STDEXEC_EXPLICIT_THIS_END(get_completion_signatures)
151172

152-
template <__decays_to<__t> Self, receiver Receiver>
153-
requires receiver_of<Receiver, _completion_signatures_t<Self, env_of_t<Receiver>>>
154-
STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver rcvr)
155-
-> op_state_th<Self, Receiver> {
156-
return op_state_th<Self, Receiver>{
157-
static_cast<Self&&>(self).sndr_, static_cast<Receiver&&>(rcvr), self.context_state_};
158-
}
159-
STDEXEC_EXPLICIT_THIS_END(connect)
173+
auto get_env() const noexcept -> stdexec::__fwd_env_t<stdexec::env_of_t<Sender>> {
174+
return stdexec::__fwd_env(stdexec::get_env(sndr_));
175+
}
160176

161-
template <__decays_to<__t> Self, class... Env>
162-
STDEXEC_EXPLICIT_THIS_BEGIN(auto get_completion_signatures)(this Self&&, Env&&...) -> _completion_signatures_t<Self, Env...> {
163-
return {};
164-
}
165-
STDEXEC_EXPLICIT_THIS_END(get_completion_signatures)
177+
context_state_t context_state_;
178+
Sender sndr_;
179+
};
180+
};
166181

167-
auto get_env() const noexcept -> stdexec::__fwd_env_t<stdexec::env_of_t<Sender>> {
168-
return stdexec::__fwd_env(stdexec::get_env(sndr_));
182+
template <class Env>
183+
struct transform_sender_for<stdexec::schedule_from_t, Env> {
184+
template <class Sender>
185+
using _current_scheduler_t =
186+
__result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, const Env&>;
187+
188+
template <class Sender>
189+
auto operator()(__ignore, __ignore, Sender&& sndr) const {
190+
if constexpr (stream_completing_sender<Sender, Env>) {
191+
using _sender_t = __t<schedule_from_sender_t<__id<__decay_t<Sender>>>>;
192+
auto stream_sched = get_completion_scheduler<set_value_t>(get_env(sndr), env_);
193+
return _sender_t{stream_sched.context_state_, static_cast<Sender&&>(sndr)};
194+
} else {
195+
return stdexec::__not_a_sender<
196+
stdexec::_WHAT_<>(
197+
CANNOT_DISPATCH_THE_SCHEDULE_FROM_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER),
198+
stdexec::_WHY_(BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT),
199+
stdexec::_WHERE_(stdexec::_IN_ALGORITHM_, stdexec::schedule_from_t),
200+
// stdexec::_TO_FIX_THIS_ERROR_(
201+
// ADD_A_CONTINUES_ON_TRANSITION_TO_THE_CUDA_STREAM_SCHEDULER_BEFORE_THE_SCHEDULE_FROM_ALGORITHM),
202+
stdexec::_WITH_SENDER_<Sender>,
203+
stdexec::_WITH_ENVIRONMENT_<Env>
204+
>{};
205+
}
169206
}
170207

171-
context_state_t context_state_;
172-
Sender sndr_;
208+
const Env& env_;
173209
};
174-
};
175-
176-
template <class Env>
177-
struct transform_sender_for<stdexec::schedule_from_t, Env> {
178-
template <class Sender>
179-
using _current_scheduler_t =
180-
__result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, const Env&>;
181-
182-
template <class Sender>
183-
auto operator()(__ignore, __ignore, Sender&& sndr) const {
184-
static_assert(stream_completing_sender<Sender, Env>);
185-
using _sender_t = __t<schedule_from_sender_t<__id<__decay_t<Sender>>>>;
186-
auto stream_sched = get_completion_scheduler<set_value_t>(get_env(sndr), env_);
187-
return _sender_t{stream_sched.context_state_, static_cast<Sender&&>(sndr)};
188-
}
189-
190-
const Env& env_;
191-
};
192-
193-
} // namespace nvexec::_strm
210+
211+
} // namespace _strm
212+
} // namespace nvexec
194213

195214
namespace stdexec::__detail {
196215
template <class SenderId>

0 commit comments

Comments
 (0)