66import uuid
77from abc import ABC , abstractmethod
88from enum import Enum
9- from typing import TYPE_CHECKING , Any , Callable , NamedTuple
9+ from typing import TYPE_CHECKING , Any , Callable , NamedTuple , Tuple
1010
1111import boto3
1212import pandas as pd
3535}
3636
3737
38+ def _compose_filename_prefix_for_mode (* , mode : str , filename_prefix : str = None ) -> Tuple [str , str ]:
39+ if mode == "overwrite_files" :
40+ if filename_prefix is None :
41+ filename_prefix = "part"
42+ random_filename_suffix = ""
43+ mode = "append"
44+ else :
45+ random_filename_suffix = uuid .uuid4 ().hex
46+
47+ if filename_prefix is None :
48+ filename_prefix = ""
49+ filename_prefix = filename_prefix + random_filename_suffix
50+ return filename_prefix , mode
51+
52+
3853def _extract_dtypes_from_table_input (table_input : dict [str , Any ]) -> dict [str , str ]:
3954 dtypes : dict [str , str ] = {}
4055 for col in table_input ["StorageDescriptor" ]["Columns" ]:
@@ -71,6 +86,7 @@ def _validate_args(
7186 parameters : dict [str , str ] | None ,
7287 columns_comments : dict [str , str ] | None ,
7388 columns_parameters : dict [str , dict [str , str ]] | None ,
89+ max_rows_by_file : int | None ,
7490 execution_engine : Enum ,
7591) -> None :
7692 if df .empty is True :
@@ -88,6 +104,10 @@ def _validate_args(
88104 raise exceptions .InvalidArgumentCombination ("Please, pass dataset=True to be able to use bucketing_info." )
89105 if mode is not None :
90106 raise exceptions .InvalidArgumentCombination ("Please pass dataset=True to be able to use mode." )
107+ if mode == "overwrite_files" and (max_rows_by_file or bucketing_info ):
108+ raise exceptions .InvalidArgumentValue (
109+ "When mode is set to 'overwrite_files', the "
110+ "`max_rows_by_file` and `bucketing_info` arguments cannot be set." )
91111 if any (arg is not None for arg in (table , description , parameters , columns_comments , columns_parameters )):
92112 raise exceptions .InvalidArgumentCombination (
93113 "Please pass dataset=True to be able to use any one of these "
@@ -278,20 +298,8 @@ def write( # noqa: PLR0913
278298 dtype = dtype if dtype else {}
279299 partitions_values : dict [str , list [str ]] = {}
280300
281- if mode == "overwrite_files" :
282- assert max_rows_by_file in [None , 0 ]
283-
284- if filename_prefix is None :
285- filename_prefix = "part"
286- random_filename_suffix = ""
287- mode = "append"
288- else :
289- random_filename_suffix = uuid .uuid4 ().hex
290-
291- if filename_prefix is None :
292- filename_prefix = ""
293- filename_prefix = filename_prefix + random_filename_suffix
294-
301+ mode , filename_prefix = _compose_filename_prefix_for_mode (
302+ mode = mode , filename_prefix = filename_prefix )
295303 mode = "append" if mode is None else mode
296304 cpus : int = _utils .ensure_cpu_count (use_threads = use_threads )
297305 s3_client = _utils .client (service_name = "s3" , session = boto3_session )
0 commit comments