11"""Amazon S3 Write Dataset (PRIVATE)."""
22
33import logging
4+ from functools import singledispatch
45from typing import Any , Callable , Dict , List , Optional , Tuple , Union
56
67import boto3
1314
1415if config .distributed :
1516 import modin .pandas as pd
16- from modin .distributed .dataframe .pandas import from_partitions , unwrap_partitions
1717 from modin .pandas import DataFrame as ModinDataFrame
1818else :
1919 import pandas as pd
2020
2121_logger : logging .Logger = logging .getLogger (__name__ )
2222
2323
24- def _get_subgroup_prefix (keys : Tuple [str , None ], partition_cols : List [str ], path_root : str ) -> str :
25- subdir = "/" .join ([f"{ name } ={ val } " for name , val in zip (partition_cols , keys )])
26- return f"{ path_root } { subdir } /"
27-
28-
2924def _get_bucketing_series (df : pd .DataFrame , bucketing_info : Tuple [List [str ], int ]) -> pd .Series :
3025 bucket_number_series = df .astype ("O" ).apply (
3126 lambda row : _get_bucket_number (bucketing_info [1 ], [row [col_name ] for col_name in bucketing_info [0 ]]),
@@ -75,6 +70,11 @@ def _get_value_hash(value: Union[str, int, bool]) -> int:
     )


+def _get_subgroup_prefix(keys: Tuple[str, None], partition_cols: List[str], path_root: str) -> str:
+    subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)])
+    return f"{path_root}{subdir}/"
+
+
 def _delete_objects(
     keys: Tuple[str, None],
     path_root: str,
@@ -168,7 +168,7 @@ def _write_partitions_distributed(
         )
     else:
         paths = write_func(  # type: ignore
-            df=df_group.drop(partition_cols, axis="columns"),
+            df_group.drop(partition_cols, axis="columns"),
             path_root=prefix,
             filename_prefix=filename_prefix,
             boto3_session=boto3_session,
@@ -178,10 +178,11 @@ def _write_partitions_distributed(
     return prefix, df_group.name, paths


+@singledispatch
 def _to_partitions(
+    df: pd.DataFrame,
     func: Callable[..., List[str]],
     concurrent_partitioning: bool,
-    df: pd.DataFrame,
     path_root: str,
     use_threads: Union[bool, int],
     mode: str,
@@ -221,8 +222,8 @@ def _to_partitions(
             )
         if bucketing_info:
             _to_buckets(
+                subgroup,
                 func=func,
-                df=subgroup,
                 path_root=prefix,
                 bucketing_info=bucketing_info,
                 boto3_session=boto3_session,
@@ -233,11 +234,11 @@ def _to_partitions(
             )
         else:
             proxy.write(
-                func=func,
-                df=subgroup,
+                func,
+                boto3_session,
+                subgroup,
                 path_root=prefix,
                 filename_prefix=filename_prefix,
-                boto3_session=boto3_session,
                 use_threads=use_threads,
                 **func_kwargs,
             )
@@ -247,9 +248,9 @@ def _to_partitions(


 def _to_partitions_distributed(  # pylint: disable=unused-argument
+    df: pd.DataFrame,
     func: Callable[..., List[str]],
     concurrent_partitioning: bool,
-    df: pd.DataFrame,
     path_root: str,
     use_threads: Union[bool, int],
     mode: str,
@@ -283,7 +284,7 @@ def _to_partitions_distributed( # pylint: disable=unused-argument
         boto3_session=None,
         **func_kwargs,
     )
-    paths: List[str] = [path for metadata in df_write_metadata.values for _, _, path in metadata]
+    paths: List[str] = [path for metadata in df_write_metadata.values for _, _, paths in metadata for path in paths]
     partitions_values: Dict[str, List[str]] = {
         prefix: list(str(p) for p in partitions) if isinstance(partitions, tuple) else [str(partitions)]
         for metadata in df_write_metadata.values
@@ -292,9 +293,10 @@ def _to_partitions_distributed( # pylint: disable=unused-argument
     return paths, partitions_values


+@singledispatch
 def _to_buckets(
-    func: Callable[..., List[str]],
     df: pd.DataFrame,
+    func: Callable[..., List[str]],
     path_root: str,
     bucketing_info: Tuple[List[str], int],
     filename_prefix: str,
@@ -307,11 +309,11 @@ def _to_buckets(
     df_groups = df.groupby(by=_get_bucketing_series(df=df, bucketing_info=bucketing_info))
     for bucket_number, subgroup in df_groups:
         _proxy.write(
-            func=func,
-            df=subgroup,
+            func,
+            boto3_session,
+            subgroup,
             path_root=path_root,
             filename_prefix=f"{filename_prefix}_bucket-{bucket_number:05d}",
-            boto3_session=boto3_session,
             use_threads=use_threads,
             **func_kwargs,
         )
@@ -322,8 +324,8 @@ def _to_buckets(


 def _to_buckets_distributed(  # pylint: disable=unused-argument
-    func: Callable[..., List[str]],
     df: pd.DataFrame,
+    func: Callable[..., List[str]],
     path_root: str,
     bucketing_info: Tuple[List[str], int],
     filename_prefix: str,
@@ -335,7 +337,7 @@ def _to_buckets_distributed( # pylint: disable=unused-argument
     df_groups = df.groupby(by=_get_bucketing_series(df=df, bucketing_info=bucketing_info))
     paths: List[str] = []
     df_paths = df_groups.apply(
-        func,
+        func.dispatch(ModinDataFrame),  # type: ignore
        path_root=path_root,
         filename_prefix=filename_prefix,
         boto3_session=None,
@@ -398,24 +400,14 @@ def _to_dataset(
     else:
         delete_objects(path=path_root, use_threads=use_threads, boto3_session=boto3_session)

-    _to_partitions_fn: Callable[..., Tuple[List[str], Dict[str, List[str]]]] = _to_partitions
-    _to_buckets_fn: Callable[..., List[str]] = _to_buckets
-    if config.distributed and isinstance(df, ModinDataFrame):
-        # Ensure Modin dataframe is partitioned along row axis
-        # It avoids a situation where columns are split along multiple blocks
-        df = from_partitions(unwrap_partitions(df, axis=0), axis=0)
-
-        _to_partitions_fn = _to_partitions_distributed
-        _to_buckets_fn = _to_buckets_distributed
-
     # Writing
     partitions_values: Dict[str, List[str]] = {}
     paths: List[str]
     if partition_cols:
-        paths, partitions_values = _to_partitions_fn(
+        paths, partitions_values = _to_partitions(
+            df,
             func=func,
             concurrent_partitioning=concurrent_partitioning,
-            df=df,
             path_root=path_root,
             use_threads=use_threads,
             mode=mode,
@@ -433,9 +425,9 @@ def _to_dataset(
             **func_kwargs,
         )
     elif bucketing_info:
-        paths = _to_buckets_fn(
+        paths = _to_buckets(
+            df,
             func=func,
-            df=df,
             path_root=path_root,
             use_threads=use_threads,
             bucketing_info=bucketing_info,
@@ -446,7 +438,7 @@ def _to_dataset(
         )
     else:
         paths = func(
-            df=df,
+            df,
             path_root=path_root,
             filename_prefix=filename_prefix,
             use_threads=use_threads,
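
Review note on the pattern: this commit replaces the old `_to_partitions_fn` / `_to_buckets_fn` pointer-swapping in `_to_dataset` with `functools.singledispatch`, which selects an implementation from the runtime type of the first positional argument; keyword arguments never participate in dispatch. That is why `df` moves to the front of each signature and why the call sites now pass the dataframe positionally. How the distributed variants get registered is not visible in these hunks. Below is a minimal runnable sketch of the mechanism, using a hypothetical `FakeModinDataFrame` as a stand-in for `modin.pandas.DataFrame`:

    from functools import singledispatch

    import pandas as pd


    @singledispatch
    def _write(df: pd.DataFrame, path_root: str) -> str:
        # Default implementation, used for plain pandas frames.
        return f"pandas write to {path_root}"


    class FakeModinDataFrame(pd.DataFrame):
        """Hypothetical stand-in for modin.pandas.DataFrame."""


    @_write.register(FakeModinDataFrame)
    def _(df: FakeModinDataFrame, path_root: str) -> str:
        # Selected whenever the first positional argument is a FakeModinDataFrame.
        return f"distributed write to {path_root}"


    print(_write(pd.DataFrame(), "s3://bucket/"))        # pandas write to s3://bucket/
    print(_write(FakeModinDataFrame(), "s3://bucket/"))  # distributed write to s3://bucket/
    assert _write.dispatch(FakeModinDataFrame) is not _write.dispatch(pd.DataFrame)

`_to_buckets_distributed` uses the same machinery in reverse: `func.dispatch(ModinDataFrame)` returns the implementation registered for Modin frames without calling it, so the plain function can be handed straight to `df_groups.apply`.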
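A second fix worth flagging is the flattening change in `_to_partitions_distributed`: `_write_partitions_distributed` returns a list of paths per partition group, so the old comprehension bound whole lists to `path`. A quick illustration with hypothetical metadata tuples:

    # Each metadata entry is (prefix, partition key, paths written for that group).
    metadata = [
        ("s3://bucket/x=1/", "1", ["a.parquet", "b.parquet"]),
        ("s3://bucket/x=2/", "2", ["c.parquet"]),
    ]

    # Before the fix, `path` was bound to the whole list:
    nested = [path for _, _, path in metadata]
    assert nested == [["a.parquet", "b.parquet"], ["c.parquet"]]

    # The added `for path in paths` clause flattens one level:
    flat = [path for _, _, paths in metadata for path in paths]
    assert flat == ["a.parquet", "b.parquet", "c.parquet"]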