Commit 8ab7858

Add Dask integration for a join operation (rapidsai#517)
In order to leverage the recent `AllGather` work within Dask-based execution in multi-GPU Polars, we need a new integration API in RapidsMPF. Although this new API **could** simply specialize in broadcast joins, it seems reasonable to introduce a general join integration that handles any realistic combination of shuffling/broadcasting the child tables. Even if we are not performing a broadcast join, it may be beneficial to extract left- and right-hand partitions within the same task that is performing the local join. This way, we don't leave un-spillable data on the worker between the shuffle-extraction task(s) and the corresponding join task.

This PR adds a general integration API for join operations. It includes some of the pieces that will be needed for broadcast-based joins. However, it only implements the hash-join code path for now (to limit the size of the diff). Note that this PR does not implement "single-worker" execution either. Both broadcast joins and single-worker execution will be addressed in follow-up work.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Tom Augspurger (https://github.com/TomAugspurger)

URL: rapidsai#517
1 parent a61eb7e commit 8ab7858
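
For orientation, here is a hypothetical usage sketch of the new `dask_cudf_join` API added by this commit. The toy frames and `LocalCUDACluster` setup are illustrative assumptions, not part of the diff:

    # Hypothetical usage sketch; cluster setup and toy data are illustrative.
    import cudf
    import dask_cudf
    from dask_cuda import LocalCUDACluster
    from distributed import Client

    from rapidsmpf.examples.dask import dask_cudf_join

    client = Client(LocalCUDACluster())  # cluster_kind="auto" resolves to "distributed"

    left = dask_cudf.from_cudf(
        cudf.DataFrame({"key": [0, 1, 2, 3], "x": [10, 11, 12, 13]}), npartitions=2
    )
    right = dask_cudf.from_cudf(
        cudf.DataFrame({"key": [1, 2, 3, 4], "y": [20, 21, 22, 23]}), npartitions=2
    )

    # Hash-shuffle both sides on "key" and join locally on each worker.
    joined = dask_cudf_join(left, right, left_on=["key"], right_on=["key"], how="inner")
    print(joined.compute())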

File tree

5 files changed: +864 −79 lines changed


python/rapidsmpf/rapidsmpf/examples/dask.py

Lines changed: 195 additions & 21 deletions
@@ -26,12 +26,17 @@
 from rapidsmpf.utils.cudf import cudf_to_pylibcudf_table

 if TYPE_CHECKING:
+    from collections.abc import Callable
     from typing import Any

     import dask_cudf

     import cudf

+    from rapidsmpf.integrations.core import (
+        BCastJoinInfo,
+        ShufflerIntegration,
+    )
     from rapidsmpf.shuffler import Shuffler

@@ -153,6 +158,30 @@ def extract_partition(
     )


+def _get_cluster_kind(
+    cluster_kind: Literal["distributed", "single", "auto"],
+) -> Literal["distributed", "single"]:
+    """Validate and return the kind of cluster to use."""
+    if cluster_kind not in ("distributed", "single", "auto"):
+        raise ValueError(
+            f"Expected one of 'distributed', 'single', or 'auto'. Got {cluster_kind}"
+        )
+
+    if cluster_kind == "auto":
+        try:
+            from distributed import get_client
+
+            get_client()
+        except (ImportError, ValueError):
+            # Failed to import distributed/dask-cuda or find a Dask client.
+            # Use single shuffle instead.
+            cluster_kind = "single"
+        else:
+            cluster_kind = "distributed"
+
+    return cluster_kind
+
+
 def dask_cudf_shuffle(
     df: dask_cudf.DataFrame,
     on: list[str],
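
The `_get_cluster_kind` helper added above consolidates the validation and auto-detection logic that `dask_cudf_shuffle` previously inlined (see the next two hunks). A short behavior sketch, with illustrative calls:

    # Illustrative calls; return values follow the helper's logic above.
    _get_cluster_kind("single")       # -> "single"
    _get_cluster_kind("distributed")  # -> "distributed"
    _get_cluster_kind("bogus")        # -> raises ValueError
    _get_cluster_kind("auto")         # -> "distributed" if a global Dask client
                                      #    is found, otherwise "single"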
@@ -195,10 +224,10 @@ def dask_cudf_shuffle(
     This API is currently intended for demonstration and
     testing purposes only.
     """
-    if cluster_kind not in ("distributed", "single", "auto"):
-        raise ValueError(
-            f"Expected one of 'distributed', 'single', or 'auto'. Got {cluster_kind}"
-        )
+    if (cluster_kind := _get_cluster_kind(cluster_kind)) == "distributed":
+        shuffle = rapidsmpf.integrations.dask.rapidsmpf_shuffle_graph
+    else:
+        shuffle = rapidsmpf.integrations.single.rapidsmpf_shuffle_graph

     df0 = df.optimize()
     count_in = df0.npartitions
@@ -218,23 +247,6 @@ def dask_cudf_shuffle(
     else:
         sort_boundary_names = ()

-    if cluster_kind == "auto":
-        try:
-            from distributed import get_client
-
-            get_client()
-        except (ImportError, ValueError):
-            # Failed to import distributed/dask-cuda or find a Dask client.
-            # Use single shuffle instead.
-            cluster_kind = "single"
-        else:
-            cluster_kind = "distributed"
-
-    if cluster_kind == "distributed":
-        shuffle = rapidsmpf.integrations.dask.rapidsmpf_shuffle_graph
-    else:
-        shuffle = rapidsmpf.integrations.single.rapidsmpf_shuffle_graph
-
     shuffle_graph_args = (
         name_in,
         name_out,
@@ -269,3 +281,165 @@ def dask_cudf_shuffle(
         )
     else:
         return shuffled
+
+
+class DaskCudfJoinIntegration:
+    """Dask-cuDF protocol for unified join integration."""
+
+    @staticmethod
+    def get_shuffler_integration() -> ShufflerIntegration[cudf.DataFrame]:
+        """Return the shuffler integration."""
+        return DaskCudfIntegration()
+
+    @staticmethod
+    def join_partition(
+        left_input: Callable[[int], cudf.DataFrame],
+        right_input: Callable[[int], cudf.DataFrame],
+        bcast_info: BCastJoinInfo | None,
+        options: Any,
+    ) -> cudf.DataFrame:
+        """
+        Produce a joined DataFrame partition.
+
+        Parameters
+        ----------
+        left_input
+            A callable that produces chunks of the left partition.
+            The ``bcast_info.bcast_count`` parameter corresponds
+            to the number of chunks the callable can produce.
+        right_input
+            A callable that produces chunks of the right partition.
+            The ``bcast_info.bcast_count`` parameter corresponds
+            to the number of chunks the callable can produce.
+        bcast_info
+            The broadcast join information.
+            This should be None for a regular hash join.
+        options
+            Additional join options.
+
+        Returns
+        -------
+        A joined DataFrame partition.
+
+        Notes
+        -----
+        This method is used to produce a single joined table chunk.
+        """
+        join_kwargs = {
+            "left_on": options["left_on"],
+            "right_on": options["right_on"],
+            "how": options["how"],
+        }
+
+        if bcast_info is None:
+            return left_input(0).merge(right_input(0), **join_kwargs)
+        else:  # pragma: no cover
+            raise NotImplementedError("Broadcast join not implemented.")
+
+
+def dask_cudf_join(
+    left: dask_cudf.DataFrame,
+    right: dask_cudf.DataFrame,
+    left_on: list[str],
+    right_on: list[str],
+    *,
+    how: Literal["inner", "left", "right"] = "inner",
+    left_pre_shuffled: bool = False,
+    right_pre_shuffled: bool = False,
+    cluster_kind: Literal["distributed", "single", "auto"] = "auto",
+    config_options: Options = Options(),
+) -> dask_cudf.DataFrame:
+    """
+    Join two Dask-cuDF DataFrames with RapidsMPF.
+
+    Parameters
+    ----------
+    left
+        Left Dask-cuDF DataFrame.
+    right
+        Right Dask-cuDF DataFrame.
+    left_on
+        Left column names to join on.
+    right_on
+        Right column names to join on.
+    how
+        The type of join to perform.
+        Options are ``{'inner', 'left', 'right'}``.
+    left_pre_shuffled
+        Whether the left collection is already shuffled.
+    right_pre_shuffled
+        Whether the right collection is already shuffled.
+    cluster_kind
+        What kind of Dask cluster to use. Available
+        options are ``{'distributed', 'single', 'auto'}``.
+        If 'auto' (the default), 'distributed' will be
+        used if a global Dask client is found.
+        Note: Only ``'distributed'`` is supported for now.
+    config_options
+        RapidsMPF configuration options.
+
+    Returns
+    -------
+    A joined Dask-cuDF DataFrame collection.
+
+    Notes
+    -----
+    This API is currently intended for demonstration and
+    testing purposes only.
+    """
+    if (cluster_kind := _get_cluster_kind(cluster_kind)) == "distributed":
+        from rapidsmpf.integrations.dask.join import rapidsmpf_join_graph
+    else:  # pragma: no cover
+        # TODO: Support single-worker joins.
+        raise NotImplementedError("Single-worker join not implemented.")
+
+    left0 = left.optimize()
+    right0 = right.optimize()
+    left_partition_count_in = left0.npartitions
+    right_partition_count_in = right0.npartitions
+
+    token = tokenize(left0, right0, left_on, right_on, how)
+    left_name_in = left0._name
+    right_name_in = right0._name
+    name_out = f"unified-join-{token}"
+    graph = rapidsmpf_join_graph(
+        left_name_in,
+        right_name_in,
+        name_out,
+        left_partition_count_in,
+        right_partition_count_in,
+        DaskCudfJoinIntegration(),
+        # Options that may be used for shuffling, broadcasting,
+        # or repartitioning the left side.
+        {
+            "column_names": left0.columns,
+            "on": left_on,
+        },
+        # Options that may be used for shuffling, broadcasting,
+        # or repartitioning the right side.
+        {
+            "column_names": right0.columns,
+            "on": right_on,
+        },
+        # Options that may be used for joining.
+        {
+            "left_on": left_on,
+            "right_on": right_on,
+            "how": how,
+        },
+        left_pre_shuffled=left_pre_shuffled,
+        right_pre_shuffled=right_pre_shuffled,
+        config_options=config_options,
+    )
+    graph.update(left0.dask)
+    graph.update(right0.dask)
+
+    meta = left0.merge(right0, left_on=left_on, right_on=right_on, how=how)._meta
+    count_out = max(left_partition_count_in, right_partition_count_in)
+    return dd.from_graph(
+        graph,
+        meta,
+        (None,) * (count_out + 1),
+        [(name_out, pid) for pid in range(count_out)],
+        "rapidsmpf",
+    )
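
Because only the hash-join code path is implemented, `DaskCudfJoinIntegration.join_partition` with `bcast_info=None` reduces to one local cudf merge per output partition. A minimal sketch of that path; the chunk callables below are hypothetical stand-ins for the shuffle-extraction results, not part of this commit:

    import cudf

    from rapidsmpf.examples.dask import DaskCudfJoinIntegration

    # Stand-ins for the extraction callables the task graph normally supplies;
    # in the real graph these pull shuffled chunks out of the RapidsMPF shuffler.
    def left_chunk(i: int) -> cudf.DataFrame:
        return cudf.DataFrame({"key": [1, 2], "x": [10, 20]})

    def right_chunk(i: int) -> cudf.DataFrame:
        return cudf.DataFrame({"key": [2, 3], "y": [30, 40]})

    options = {"left_on": ["key"], "right_on": ["key"], "how": "inner"}

    # With bcast_info=None, this is just left_chunk(0).merge(right_chunk(0), ...).
    out = DaskCudfJoinIntegration.join_partition(
        left_chunk, right_chunk, bcast_info=None, options=options
    )
    print(out)  # one row: key=2, x=20, y=30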
