@@ -98,31 +98,159 @@ class StreamingShuffler : public BaseStreamingFixture {
9898};
9999
100100TEST_F (StreamingShuffler, Basic) {
101- run_test ([&](auto ctx, auto ch_in, auto ch_out, std::vector<Node>& nodes) {
102- nodes.emplace_back (
103- node::shuffler (
101+ EXPECT_NO_FATAL_FAILURE (
102+ run_test ([&](auto ctx, auto ch_in, auto ch_out, std::vector<Node>& nodes) {
103+ nodes.emplace_back (
104+ node::shuffler (
105+ std::move (ctx),
106+ stream,
107+ std::move (ch_in),
108+ std::move (ch_out),
109+ op_id,
110+ num_partitions
111+ )
112+ );
113+ })
114+ );
115+ }
116+
117+ namespace {
118+
119+ void sync_streams (
120+ rmm::cuda_stream_view primary,
121+ rmm::cuda_stream_view secondary,
122+ cudaEvent_t const & event
123+ ) {
124+ if (primary.value () != secondary.value ()) {
125+ RAPIDSMPF_CUDA_TRY (cudaEventRecord (event, secondary));
126+ RAPIDSMPF_CUDA_TRY (cudaStreamWaitEvent (primary, event));
127+ }
128+ }
129+
130+ // emulate shuffler node with callbacks
131+ std::pair<Node, Node> shuffler_nb (
132+ std::shared_ptr<Context> ctx,
133+ rmm::cuda_stream_view stream,
134+ std::shared_ptr<Channel> ch_in,
135+ std::shared_ptr<Channel> ch_out,
136+ OpID op_id,
137+ shuffler::PartID total_num_partitions
138+ ) {
139+ // make a shared_ptr to the shuffler so that it can be passed into multiple coroutines
140+ auto shuffler = std::make_shared<rapidsmpf::shuffler::Shuffler>(
141+ ctx->comm (),
142+ ctx->progress_thread (),
143+ op_id,
144+ total_num_partitions,
145+ stream,
146+ ctx->br (),
147+ ctx->statistics (),
148+ shuffler::Shuffler::round_robin
149+ );
150+
151+ // insert task: insert the partition map chunks into the shuffler
152+ auto insert_task =
153+ [](
154+ auto shuffler, auto ctx, auto total_num_partitions, auto stream, auto ch_in
155+ ) -> Node {
156+ ShutdownAtExit c{ch_in};
157+ co_await ctx->executor ()->schedule ();
158+ CudaEvent event;
159+
160+ while (true ) {
161+ auto msg = co_await ch_in->receive ();
162+ if (msg.empty ()) {
163+ break ;
164+ }
165+ auto partition_map = msg.template release <PartitionMapChunk>();
166+
167+ // Make sure that the input chunk's stream is in sync with shuffler's stream.
168+ sync_streams (stream, partition_map.stream , event);
169+
170+ shuffler->insert (std::move (partition_map.data ));
171+ }
172+
173+ // Tell the shuffler that we have no more input data.
174+ std::vector<rapidsmpf::shuffler::PartID> finished (total_num_partitions);
175+ std::iota (finished.begin (), finished.end (), 0 );
176+ shuffler->insert_finished (std::move (finished));
177+ co_return ;
178+ };
179+
180+ // extract task: extract the packed chunks from the shuffler and send them to the
181+ // output channel
182+ auto extract_task = [](auto shuffler, auto ctx, auto ch_out) -> Node {
183+ ShutdownAtExit c{ch_out};
184+ co_await ctx->executor ()->schedule ();
185+
186+ coro::mutex mtx{};
187+ coro::condition_variable cv{};
188+ bool finished{false };
189+
190+ shuffler->register_finished_callback (
191+ [shuffler, ctx, ch_out, &mtx, &cv, &finished](auto pid) {
192+ // task to extract and send each finished partition
193+ auto extract_and_send = [](auto shuffler,
194+ auto ctx,
195+ auto ch_out,
196+ auto pid,
197+ coro::condition_variable& cv,
198+ coro::mutex& mtx,
199+ bool & finished) -> Node {
200+ co_await ctx->executor ()->schedule ();
201+ auto packed_chunks = shuffler->extract (pid);
202+ co_await ch_out->send (
203+ std::make_unique<PartitionVectorChunk>(
204+ pid, std::move (packed_chunks)
205+ )
206+ );
207+
208+ // signal that all partitions have been finished
209+ if (shuffler->finished ()) {
210+ {
211+ auto lock = co_await mtx.scoped_lock ();
212+ finished = true ;
213+ }
214+ co_await cv.notify_one ();
215+ }
216+ };
217+ // schedule a detached task to extract and send the packed chunks
218+ ctx->executor ()->spawn (
219+ extract_and_send (shuffler, ctx, ch_out, pid, cv, mtx, finished)
220+ );
221+ }
222+ );
223+
224+ // wait for all partitions to be finished
225+ {
226+ auto lock = co_await mtx.scoped_lock ();
227+ co_await cv.wait (lock, [&finished]() { return finished; });
228+ }
229+
230+ co_await ch_out->drain (ctx->executor ());
231+ };
232+
233+ return {
234+ insert_task (shuffler, ctx, total_num_partitions, stream, std::move (ch_in)),
235+ extract_task (std::move (shuffler), std::move (ctx), std::move (ch_out))
236+ };
237+ }
238+
239+ } // namespace
240+
241+ TEST_F (StreamingShuffler, callbacks) {
242+ EXPECT_NO_FATAL_FAILURE (
243+ run_test ([&](auto ctx, auto ch_in, auto ch_out, std::vector<Node>& nodes) {
244+ auto [insert_node, extract_node] = shuffler_nb (
104245 std::move (ctx),
105246 stream,
106247 std::move (ch_in),
107248 std::move (ch_out),
108249 op_id,
109250 num_partitions
110- )
111- );
112- });
113- }
114-
115- TEST_F (StreamingShuffler, callbacks) {
116- run_test ([&](auto ctx, auto ch_in, auto ch_out, std::vector<Node>& nodes) {
117- auto [insert_node, extract_node] = node::shuffler_nb (
118- std::move (ctx),
119- stream,
120- std::move (ch_in),
121- std::move (ch_out),
122- op_id,
123- num_partitions
124- );
125- nodes.emplace_back (std::move (insert_node));
126- nodes.emplace_back (std::move (extract_node));
127- });
251+ );
252+ nodes.emplace_back (std::move (insert_node));
253+ nodes.emplace_back (std::move (extract_node));
254+ })
255+ );
128256}
0 commit comments