2626from rapidsmpf .utils .cudf import cudf_to_pylibcudf_table
2727
2828if TYPE_CHECKING :
29+ from collections .abc import Callable
2930 from typing import Any
3031
3132 import dask_cudf
3233
3334 import cudf
3435
36+ from rapidsmpf .integrations .core import (
37+ BCastJoinInfo ,
38+ ShufflerIntegration ,
39+ )
3540 from rapidsmpf .shuffler import Shuffler
3641
3742
@@ -153,6 +158,30 @@ def extract_partition(
153158 )
154159
155160
161+ def _get_cluster_kind (
162+ cluster_kind : Literal ["distributed" , "single" , "auto" ],
163+ ) -> Literal ["distributed" , "single" ]:
164+ """Validate and return the kind of cluster to use."""
165+ if cluster_kind not in ("distributed" , "single" , "auto" ):
166+ raise ValueError (
167+ f"Expected one of 'distributed', 'single', or 'auto'. Got { cluster_kind } "
168+ )
169+
170+ if cluster_kind == "auto" :
171+ try :
172+ from distributed import get_client
173+
174+ get_client ()
175+ except (ImportError , ValueError ):
176+ # Failed to import distributed/dask-cuda or find a Dask client.
177+ # Use single shuffle instead.
178+ cluster_kind = "single"
179+ else :
180+ cluster_kind = "distributed"
181+
182+ return cluster_kind
183+
184+
156185def dask_cudf_shuffle (
157186 df : dask_cudf .DataFrame ,
158187 on : list [str ],
@@ -195,10 +224,10 @@ def dask_cudf_shuffle(
195224 This API is currently intended for demonstration and
196225 testing purposes only.
197226 """
198- if cluster_kind not in ( "distributed" , "single" , "auto" ) :
199- raise ValueError (
200- f"Expected one of 'distributed', 'single', or 'auto'. Got { cluster_kind } "
201- )
227+ if ( cluster_kind := _get_cluster_kind ( cluster_kind )) == "distributed" :
228+ shuffle = rapidsmpf . integrations . dask . rapidsmpf_shuffle_graph
229+ else :
230+ shuffle = rapidsmpf . integrations . single . rapidsmpf_shuffle_graph
202231
203232 df0 = df .optimize ()
204233 count_in = df0 .npartitions
@@ -218,23 +247,6 @@ def dask_cudf_shuffle(
218247 else :
219248 sort_boundary_names = ()
220249
221- if cluster_kind == "auto" :
222- try :
223- from distributed import get_client
224-
225- get_client ()
226- except (ImportError , ValueError ):
227- # Failed to import distributed/dask-cuda or find a Dask client.
228- # Use single shuffle instead.
229- cluster_kind = "single"
230- else :
231- cluster_kind = "distributed"
232-
233- if cluster_kind == "distributed" :
234- shuffle = rapidsmpf .integrations .dask .rapidsmpf_shuffle_graph
235- else :
236- shuffle = rapidsmpf .integrations .single .rapidsmpf_shuffle_graph
237-
238250 shuffle_graph_args = (
239251 name_in ,
240252 name_out ,
@@ -269,3 +281,165 @@ def dask_cudf_shuffle(
269281 )
270282 else :
271283 return shuffled
284+
285+
class DaskCudfJoinIntegration:
    """Dask-cuDF protocol for unified join integration."""

    @staticmethod
    def get_shuffler_integration() -> ShufflerIntegration[cudf.DataFrame]:
        """Return a new ``DaskCudfIntegration`` instance for shuffling."""
        return DaskCudfIntegration()

    @staticmethod
    def join_partition(
        left_input: Callable[[int], cudf.DataFrame],
        right_input: Callable[[int], cudf.DataFrame],
        bcast_info: BCastJoinInfo | None,
        options: Any,
    ) -> cudf.DataFrame:
        """
        Produce a joined DataFrame partition.

        Parameters
        ----------
        left_input
            A callable that produces chunks of the left partition.
            The ``bcast_info.bcast_count`` parameter corresponds
            to the number of chunks the callable can produce.
        right_input
            A callable that produces chunks of the right partition.
            The ``bcast_info.bcast_count`` parameter corresponds
            to the number of chunks the callable can produce.
        bcast_info
            The broadcast join information.
            This should be None for a regular hash join.
        options
            Additional join options. The entries ``"left_on"``,
            ``"right_on"`` and ``"how"`` are read by item access and
            forwarded to ``merge``.

        Returns
        -------
        A joined DataFrame partition.

        Raises
        ------
        NotImplementedError
            If ``bcast_info`` is not None (broadcast joins are not
            implemented).

        Notes
        -----
        This method is used to produce a single joined table chunk.
        For a regular hash join only chunk ``0`` of each side is
        requested.
        """
        # Look up the merge arguments first so a malformed ``options``
        # fails before either input callable is invoked.
        merge_kwargs = {key: options[key] for key in ("left_on", "right_on", "how")}

        if bcast_info is not None:  # pragma: no cover
            raise NotImplementedError("Broadcast join not implemented.")

        # Regular hash join: merge the first chunk of each side.
        merge_left = left_input(0).merge
        return merge_left(right_input(0), **merge_kwargs)
338+
339+
def dask_cudf_join(
    left: dask_cudf.DataFrame,
    right: dask_cudf.DataFrame,
    left_on: list[str],
    right_on: list[str],
    *,
    how: Literal["inner", "left", "right"] = "inner",
    left_pre_shuffled: bool = False,
    right_pre_shuffled: bool = False,
    cluster_kind: Literal["distributed", "single", "auto"] = "auto",
    config_options: Options = Options(),
) -> dask_cudf.DataFrame:
    """
    Join two Dask-cuDF DataFrames with RapidsMPF.

    Parameters
    ----------
    left
        Left Dask-cuDF DataFrame.
    right
        Right Dask-cuDF DataFrame.
    left_on
        Left column names to join on.
    right_on
        Right column names to join on.
    how
        The type of join to perform.
        Options are ``{'inner', 'left', 'right'}``.
    left_pre_shuffled
        Whether the left collection is already shuffled.
    right_pre_shuffled
        Whether the right collection is already shuffled.
    cluster_kind
        What kind of Dask cluster to use. Available
        options are ``{'distributed', 'single', 'auto'}``.
        If 'auto' (the default), 'distributed' will be
        used if a global Dask client is found.
        Note: Only ``'distributed'`` is supported for now.
    config_options
        RapidsMPF configuration options.

    Returns
    -------
    A joined Dask-cuDF DataFrame collection.

    Raises
    ------
    ValueError
        If ``cluster_kind`` is not an accepted value.
    NotImplementedError
        If ``cluster_kind`` resolves to ``'single'``.

    Notes
    -----
    This API is currently intended for demonstration and
    testing purposes only.
    """
    if (cluster_kind := _get_cluster_kind(cluster_kind)) == "distributed":
        from rapidsmpf.integrations.dask.join import rapidsmpf_join_graph
    else:  # pragma: no cover
        # TODO: Support single-worker joins.
        raise NotImplementedError("Single-worker join not implemented.")

    left0 = left.optimize()
    right0 = right.optimize()
    left_partition_count_in = left0.npartitions
    right_partition_count_in = right0.npartitions

    # The pre-shuffled flags are passed to the graph builder, so they must
    # be part of the token: otherwise two calls differing only in those
    # flags would share the same output key names.
    token = tokenize(
        left0,
        right0,
        left_on,
        right_on,
        how,
        left_pre_shuffled,
        right_pre_shuffled,
    )
    left_name_in = left0._name
    right_name_in = right0._name
    name_out = f"unified-join-{token}"
    graph = rapidsmpf_join_graph(
        left_name_in,
        right_name_in,
        name_out,
        left_partition_count_in,
        right_partition_count_in,
        DaskCudfJoinIntegration(),
        # Options that may be used for shuffling, broadcasting,
        # or repartitioning the left side.
        {
            "column_names": left0.columns,
            "on": left_on,
        },
        # Options that may be used for shuffling, broadcasting,
        # or repartitioning the right side.
        {
            "column_names": right0.columns,
            "on": right_on,
        },
        # Options that may be used for joining.
        {
            "left_on": left_on,
            "right_on": right_on,
            "how": how,
        },
        left_pre_shuffled=left_pre_shuffled,
        right_pre_shuffled=right_pre_shuffled,
        config_options=config_options,
    )
    # The join graph only references the input keys by name; merge in the
    # input collections' own graphs so those keys can be computed.
    graph.update(left0.dask)
    graph.update(right0.dask)

    # Derive the output schema from the regular (lazy) merge of the inputs.
    meta = left0.merge(right0, left_on=left_on, right_on=right_on, how=how)._meta
    count_out = max(left_partition_count_in, right_partition_count_in)
    return dd.from_graph(
        graph,
        meta,
        (None,) * (count_out + 1),
        [(name_out, pid) for pid in range(count_out)],
        "rapidsmpf",
    )
0 commit comments