25 | 25 |
26 | 26 | #include "common.cuh" |
27 | 27 |
28 | | -namespace nvexec::_strm { |
29 | | - |
30 | | - namespace _schfr { |
31 | | - template <class CvrefSenderId, class ReceiverId> |
32 | | - struct operation_state_t { |
33 | | - using Sender = __cvref_t<CvrefSenderId>; |
34 | | - using Receiver = stdexec::__t<ReceiverId>; |
35 | | - using Env = operation_state_base_t<ReceiverId>::env_t; |
36 | | - |
37 | | - struct __t : operation_state_base_t<ReceiverId> { |
38 | | - struct receiver_t; |
39 | | - using __id = operation_state_t; |
40 | | - using variant_t = variant_storage_t<Sender, Env>; |
41 | | - using task_t = continuation_task_t<receiver_t, variant_t>; |
42 | | - using enqueue_receiver = |
43 | | - stdexec::__t<stream_enqueue_receiver<stdexec::__cvref_id<Env>, variant_t>>; |
44 | | - using inner_op_state_t = connect_result_t<Sender, enqueue_receiver>; |
45 | | - |
46 | | - struct receiver_t { |
47 | | - using receiver_concept = stdexec::receiver_t; |
48 | | - |
49 | | - template <class... _Args> |
50 | | - void set_value(_Args&&... __args) noexcept { |
51 | | - stdexec::set_value(std::move(op_state_.rcvr_), static_cast<_Args&&>(__args)...); |
| 28 | +namespace nvexec { |
| 29 | + struct CANNOT_DISPATCH_THE_SCHEDULE_FROM_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER; |
| 30 | + struct BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT; |
| 31 | + struct ADD_A_CONTINUES_ON_TRANSITION_TO_THE_CUDA_STREAM_SCHEDULER_BEFORE_THE_SCHEDULE_FROM_ALGORITHM; |
| 32 | + |
| 33 | + namespace _strm { |
| 34 | + |
| 35 | + namespace _schfr { |
| 36 | + template <class CvrefSenderId, class ReceiverId> |
| 37 | + struct operation_state_t { |
| 38 | + using Sender = __cvref_t<CvrefSenderId>; |
| 39 | + using Receiver = stdexec::__t<ReceiverId>; |
| 40 | + using Env = operation_state_base_t<ReceiverId>::env_t; |
| 41 | + |
| 42 | + struct __t : operation_state_base_t<ReceiverId> { |
| 43 | + struct receiver_t; |
| 44 | + using __id = operation_state_t; |
| 45 | + using variant_t = variant_storage_t<Sender, Env>; |
| 46 | + using task_t = continuation_task_t<receiver_t, variant_t>; |
| 47 | + using enqueue_receiver = |
| 48 | + stdexec::__t<stream_enqueue_receiver<stdexec::__cvref_id<Env>, variant_t>>; |
| 49 | + using inner_op_state_t = connect_result_t<Sender, enqueue_receiver>; |
| 50 | + |
| 51 | + struct receiver_t { |
| 52 | + using receiver_concept = stdexec::receiver_t; |
| 53 | + |
| 54 | + template <class... _Args> |
| 55 | + void set_value(_Args&&... __args) noexcept { |
| 56 | + stdexec::set_value(std::move(op_state_.rcvr_), static_cast<_Args&&>(__args)...); |
| 57 | + } |
| 58 | + |
| 59 | + template <class _Error> |
| 60 | + void set_error(_Error&& __err) noexcept { |
| 61 | + stdexec::set_error(std::move(op_state_.rcvr_), static_cast<_Error&&>(__err)); |
| 62 | + } |
| 63 | + |
| 64 | + void set_stopped() noexcept { |
| 65 | + stdexec::set_stopped(std::move(op_state_.rcvr_)); |
| 66 | + } |
| 67 | + |
| 68 | + [[nodiscard]] |
| 69 | + auto get_env() const noexcept -> Env { |
| 70 | + return op_state_.make_env(); |
| 71 | + } |
| 72 | + |
| 73 | + __t& op_state_; |
| 74 | + }; |
| 75 | + |
| 76 | + __t(Sender&& sender, Receiver&& rcvr, context_state_t context_state) |
| 77 | + : operation_state_base_t<ReceiverId>(static_cast<Receiver&&>(rcvr), context_state) |
| 78 | + , context_state_(context_state) |
| 79 | + , storage_(make_host<variant_t>(this->status_, context_state.pinned_resource_)) |
| 80 | + , task_( |
| 81 | + make_host<task_t>( |
| 82 | + this->status_, |
| 83 | + context_state.pinned_resource_, |
| 84 | + receiver_t{*this}, |
| 85 | + storage_.get(), |
| 86 | + this->get_stream(), |
| 87 | + context_state.pinned_resource_) |
| 88 | + .release()) |
| 89 | + , env_(make_host(this->status_, context_state_.pinned_resource_, this->make_env())) |
| 90 | + , inner_op_{connect( |
| 91 | + static_cast<Sender&&>(sender), |
| 92 | + enqueue_receiver{ |
| 93 | + env_.get(), |
| 94 | + storage_.get(), |
| 95 | + task_, |
| 96 | + context_state_.hub_->producer()})} { |
| 97 | + if (this->status_ == cudaSuccess) { |
| 98 | + this->status_ = task_->status_; |
| 99 | + } |
52 | 100 | } |
53 | 101 |
54 | | - template <class _Error> |
55 | | - void set_error(_Error&& __err) noexcept { |
56 | | - stdexec::set_error(std::move(op_state_.rcvr_), static_cast<_Error&&>(__err)); |
57 | | - } |
| 102 | + STDEXEC_IMMOVABLE(__t); |
58 | 103 |
59 | | - void set_stopped() noexcept { |
60 | | - stdexec::set_stopped(std::move(op_state_.rcvr_)); |
61 | | - } |
| 104 | + void start() & noexcept { |
| 105 | + started_.test_and_set(::cuda::std::memory_order::relaxed); |
62 | 106 |
63 | | - [[nodiscard]] |
64 | | - auto get_env() const noexcept -> Env { |
65 | | - return op_state_.make_env(); |
66 | | - } |
| 107 | + if (status_ != cudaSuccess) { |
| 108 | + // Couldn't allocate memory for operation state, complete with error |
| 109 | + stdexec::set_error(std::move(this->rcvr_), std::move(status_)); |
| 110 | + return; |
| 111 | + } |
67 | 112 |
68 | | - __t& op_state_; |
69 | | - }; |
70 | | - |
71 | | - __t(Sender&& sender, Receiver&& rcvr, context_state_t context_state) |
72 | | - : operation_state_base_t<ReceiverId>(static_cast<Receiver&&>(rcvr), context_state) |
73 | | - , context_state_(context_state) |
74 | | - , storage_(make_host<variant_t>(this->status_, context_state.pinned_resource_)) |
75 | | - , task_( |
76 | | - make_host<task_t>( |
77 | | - this->status_, |
78 | | - context_state.pinned_resource_, |
79 | | - receiver_t{*this}, |
80 | | - storage_.get(), |
81 | | - this->get_stream(), |
82 | | - context_state.pinned_resource_) |
83 | | - .release()) |
84 | | - , env_(make_host(this->status_, context_state_.pinned_resource_, this->make_env())) |
85 | | - , inner_op_{connect( |
86 | | - static_cast<Sender&&>(sender), |
87 | | - enqueue_receiver{ |
88 | | - env_.get(), |
89 | | - storage_.get(), |
90 | | - task_, |
91 | | - context_state_.hub_->producer()})} { |
92 | | - if (this->status_ == cudaSuccess) { |
93 | | - this->status_ = task_->status_; |
| 113 | + stdexec::start(inner_op_); |
94 | 114 | } |
95 | | - } |
96 | | - |
97 | | - STDEXEC_IMMOVABLE(__t); |
98 | 115 |
99 | | - void start() & noexcept { |
100 | | - started_.test_and_set(::cuda::std::memory_order::relaxed); |
| 116 | + cudaError_t status_{cudaSuccess}; |
| 117 | + context_state_t context_state_; |
| 118 | + host_ptr<variant_t> storage_; |
| 119 | + task_t* task_; |
| 120 | + ::cuda::std::atomic_flag started_{}; |
| 121 | + host_ptr<__decay_t<Env>> env_{}; |
| 122 | + inner_op_state_t inner_op_; |
| 123 | + }; |
| 124 | + }; |
| 125 | + } // namespace _schfr |
101 | 126 |
102 | | - if (status_ != cudaSuccess) { |
103 | | - // Couldn't allocate memory for operation state, complete with error |
104 | | - stdexec::set_error(std::move(this->rcvr_), std::move(status_)); |
105 | | - return; |
106 | | - } |
| 127 | + template <class SenderId> |
| 128 | + struct schedule_from_sender_t { |
| 129 | + using Sender = stdexec::__t<SenderId>; |
107 | 130 |
108 | | - stdexec::start(inner_op_); |
109 | | - } |
110 | | - |
111 | | - cudaError_t status_{cudaSuccess}; |
112 | | - context_state_t context_state_; |
113 | | - host_ptr<variant_t> storage_; |
114 | | - task_t* task_; |
115 | | - ::cuda::std::atomic_flag started_{}; |
116 | | - host_ptr<__decay_t<Env>> env_{}; |
117 | | - inner_op_state_t inner_op_; |
118 | | - }; |
119 | | - }; |
120 | | - } // namespace _schfr |
| 131 | + struct __t : stream_sender_base { |
| 132 | + using __id = schedule_from_sender_t; |
121 | 133 |
122 | | - template <class SenderId> |
123 | | - struct schedule_from_sender_t { |
124 | | - using Sender = stdexec::__t<SenderId>; |
| 134 | + template <class Self, class Receiver> |
| 135 | + using op_state_th = |
| 136 | + stdexec::__t<_schfr::operation_state_t<__cvref_id<Self, Sender>, stdexec::__id<Receiver>>>; |
125 | 137 |
126 | | - struct __t : stream_sender_base { |
127 | | - using __id = schedule_from_sender_t; |
| 138 | + template <class... Ts> |
| 139 | + using _set_value_t = completion_signatures<set_value_t(stdexec::__decay_t<Ts>...)>; |
128 | 140 |
129 | | - template <class Self, class Receiver> |
130 | | - using op_state_th = |
131 | | - stdexec::__t<_schfr::operation_state_t<__cvref_id<Self, Sender>, stdexec::__id<Receiver>>>; |
| 141 | + template <class Ty> |
| 142 | + using _set_error_t = completion_signatures<set_error_t(stdexec::__decay_t<Ty>)>; |
132 | 143 |
133 | | - template <class... Ts> |
134 | | - using _set_value_t = completion_signatures<set_value_t(stdexec::__decay_t<Ts>...)>; |
| 144 | + template <class Self, class... Env> |
| 145 | + using _completion_signatures_t = transform_completion_signatures< |
| 146 | + __completion_signatures_of_t<__copy_cvref_t<Self, Sender>, Env...>, |
| 147 | + completion_signatures<set_stopped_t(), set_error_t(cudaError_t)>, |
| 148 | + _set_value_t, |
| 149 | + _set_error_t |
| 150 | + >; |
135 | 151 |
136 | | - template <class Ty> |
137 | | - using _set_error_t = completion_signatures<set_error_t(stdexec::__decay_t<Ty>)>; |
| 152 | + __t(context_state_t context_state, Sender sndr) |
| 153 | + : context_state_(context_state) |
| 154 | + , sndr_{static_cast<Sender&&>(sndr)} { |
| 155 | + } |
138 | 156 |
139 | | - template <class Self, class... Env> |
140 | | - using _completion_signatures_t = transform_completion_signatures< |
141 | | - __completion_signatures_of_t<__copy_cvref_t<Self, Sender>, Env...>, |
142 | | - completion_signatures<set_stopped_t(), set_error_t(cudaError_t)>, |
143 | | - _set_value_t, |
144 | | - _set_error_t |
145 | | - >; |
| 157 | + template <__decays_to<__t> Self, receiver Receiver> |
| 158 | + requires receiver_of<Receiver, _completion_signatures_t<Self, env_of_t<Receiver>>> |
| 159 | + STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver rcvr) |
| 160 | + -> op_state_th<Self, Receiver> { |
| 161 | + return op_state_th<Self, Receiver>{ |
| 162 | + static_cast<Self&&>(self).sndr_, static_cast<Receiver&&>(rcvr), self.context_state_}; |
| 163 | + } |
| 164 | + STDEXEC_EXPLICIT_THIS_END(connect) |
146 | 165 |
147 | | - __t(context_state_t context_state, Sender sndr) |
148 | | - : context_state_(context_state) |
149 | | - , sndr_{static_cast<Sender&&>(sndr)} { |
150 | | - } |
| 166 | + template <__decays_to<__t> Self, class... Env> |
| 167 | + STDEXEC_EXPLICIT_THIS_BEGIN(auto get_completion_signatures)(this Self&&, Env&&...) |
| 168 | + -> _completion_signatures_t<Self, Env...> { |
| 169 | + return {}; |
| 170 | + } |
| 171 | + STDEXEC_EXPLICIT_THIS_END(get_completion_signatures) |
151 | 172 |
152 | | - template <__decays_to<__t> Self, receiver Receiver> |
153 | | - requires receiver_of<Receiver, _completion_signatures_t<Self, env_of_t<Receiver>>> |
154 | | - STDEXEC_EXPLICIT_THIS_BEGIN(auto connect)(this Self&& self, Receiver rcvr) |
155 | | - -> op_state_th<Self, Receiver> { |
156 | | - return op_state_th<Self, Receiver>{ |
157 | | - static_cast<Self&&>(self).sndr_, static_cast<Receiver&&>(rcvr), self.context_state_}; |
158 | | - } |
159 | | - STDEXEC_EXPLICIT_THIS_END(connect) |
| 173 | + auto get_env() const noexcept -> stdexec::__fwd_env_t<stdexec::env_of_t<Sender>> { |
| 174 | + return stdexec::__fwd_env(stdexec::get_env(sndr_)); |
| 175 | + } |
160 | 176 |
161 | | - template <__decays_to<__t> Self, class... Env> |
162 | | - STDEXEC_EXPLICIT_THIS_BEGIN(auto get_completion_signatures)(this Self&&, Env&&...) -> _completion_signatures_t<Self, Env...> { |
163 | | - return {}; |
164 | | - } |
165 | | - STDEXEC_EXPLICIT_THIS_END(get_completion_signatures) |
| 177 | + context_state_t context_state_; |
| 178 | + Sender sndr_; |
| 179 | + }; |
| 180 | + }; |
166 | 181 |
167 | | - auto get_env() const noexcept -> stdexec::__fwd_env_t<stdexec::env_of_t<Sender>> { |
168 | | - return stdexec::__fwd_env(stdexec::get_env(sndr_)); |
| 182 | + template <class Env> |
| 183 | + struct transform_sender_for<stdexec::schedule_from_t, Env> { |
| 184 | + template <class Sender> |
| 185 | + using _current_scheduler_t = |
| 186 | + __result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, const Env&>; |
| 187 | + |
| 188 | + template <class Sender> |
| 189 | + auto operator()(__ignore, __ignore, Sender&& sndr) const { |
| 190 | + if constexpr (stream_completing_sender<Sender, Env>) { |
| 191 | + using _sender_t = __t<schedule_from_sender_t<__id<__decay_t<Sender>>>>; |
| 192 | + auto stream_sched = get_completion_scheduler<set_value_t>(get_env(sndr), env_); |
| 193 | + return _sender_t{stream_sched.context_state_, static_cast<Sender&&>(sndr)}; |
| 194 | + } else { |
| 195 | + return stdexec::__not_a_sender< |
| 196 | + stdexec::_WHAT_<>( |
| 197 | + CANNOT_DISPATCH_THE_SCHEDULE_FROM_ALGORITHM_TO_THE_CUDA_STREAM_SCHEDULER), |
| 198 | + stdexec::_WHY_(BECAUSE_THERE_IS_NO_CUDA_STREAM_SCHEDULER_IN_THE_ENVIRONMENT), |
| 199 | + stdexec::_WHERE_(stdexec::_IN_ALGORITHM_, stdexec::schedule_from_t), |
| 200 | + // stdexec::_TO_FIX_THIS_ERROR_( |
| 201 | + // ADD_A_CONTINUES_ON_TRANSITION_TO_THE_CUDA_STREAM_SCHEDULER_BEFORE_THE_SCHEDULE_FROM_ALGORITHM), |
| 202 | + stdexec::_WITH_SENDER_<Sender>, |
| 203 | + stdexec::_WITH_ENVIRONMENT_<Env> |
| 204 | + >{}; |
| 205 | + } |
169 | 206 | } |
170 | 207 |
171 | | - context_state_t context_state_; |
172 | | - Sender sndr_; |
| 208 | + const Env& env_; |
173 | 209 | }; |
174 | | - }; |
175 | | - |
176 | | - template <class Env> |
177 | | - struct transform_sender_for<stdexec::schedule_from_t, Env> { |
178 | | - template <class Sender> |
179 | | - using _current_scheduler_t = |
180 | | - __result_of<get_completion_scheduler<set_value_t>, env_of_t<Sender>, const Env&>; |
181 | | - |
182 | | - template <class Sender> |
183 | | - auto operator()(__ignore, __ignore, Sender&& sndr) const { |
184 | | - static_assert(stream_completing_sender<Sender, Env>); |
185 | | - using _sender_t = __t<schedule_from_sender_t<__id<__decay_t<Sender>>>>; |
186 | | - auto stream_sched = get_completion_scheduler<set_value_t>(get_env(sndr), env_); |
187 | | - return _sender_t{stream_sched.context_state_, static_cast<Sender&&>(sndr)}; |
188 | | - } |
189 | | - |
190 | | - const Env& env_; |
191 | | - }; |
192 | | - |
193 | | -} // namespace nvexec::_strm |
| 210 | + |
| 211 | + } // namespace _strm |
| 212 | +} // namespace nvexec |
194 | 213 |
195 | 214 | namespace stdexec::__detail { |
196 | 215 | template <class SenderId> |
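For context, a minimal usage sketch (not part of this diff) of the pattern the stream domain's `schedule_from` customization handles: work that completes on the CUDA stream scheduler and then transitions to another scheduler via `continues_on`, which is lowered through `schedule_from`. The entry points used (`nvexec::stream_context`, `exec::static_thread_pool`, `stdexec::schedule`, `stdexec::then`, `stdexec::continues_on`, `stdexec::sync_wait`) are the existing public APIs; the sketch assumes a CUDA-aware build (e.g., nvc++). When the predecessor does not complete on a stream scheduler, the `if constexpr` branch added above now yields the descriptive `__not_a_sender` diagnostic in place of the old bare `static_assert`.

```cpp
// Sketch only: exercises the GPU -> CPU transition that schedule_from_sender_t
// implements. Assumes nvc++ or an equivalent CUDA-aware build of stdexec/nvexec.
#include <stdexec/execution.hpp>
#include <nvexec/stream_context.cuh>
#include <exec/static_thread_pool.hpp>

int main() {
  nvexec::stream_context stream_ctx;   // owns the CUDA stream scheduler
  exec::static_thread_pool pool{1};    // a CPU scheduler to transition to

  auto gpu = stream_ctx.get_scheduler();
  auto cpu = pool.get_scheduler();

  auto work = stdexec::schedule(gpu)
            | stdexec::then([] { return 1; })    // completes on the stream scheduler
            | stdexec::continues_on(cpu)         // lowered to schedule_from; stream domain path
            | stdexec::then([](int i) { return i + 1; });

  auto [result] = stdexec::sync_wait(std::move(work)).value();
  return result == 2 ? 0 : 1;
}
```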