
Commit db079df

Improve bucketing series generation by casting only the required columns (#1664)
Co-authored-by: Lucas Hanson <[email protected]>
1 parent 1563caa commit db079df
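
Why casting only the required columns helps, in a minimal sketch (not part of the commit; the frame and column names are illustrative): calling astype("O") on the whole DataFrame converts every column to Python objects before the row-wise hash, while selecting the bucketing columns first leaves the rest of the frame's dtypes untouched.

    import pandas as pd

    df = pd.DataFrame({"id": [1, 2, 3], "price": [1.5, 2.5, 3.5]})

    whole = df.astype("O")           # previous behavior: every column becomes object
    subset = df[["id"]].astype("O")  # this commit: only the bucketing columns are cast

    print(df.dtypes.tolist())      # [dtype('int64'), dtype('float64')]
    print(whole.dtypes.tolist())   # [dtype('O'), dtype('O')]
    print(subset.dtypes.tolist())  # [dtype('O')]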

1 file changed: +54 -47

awswrangler/s3/_write_dataset.py

Lines changed: 54 additions & 47 deletions
@@ -14,6 +14,59 @@
 _logger: logging.Logger = logging.getLogger(__name__)
 
 
+def _get_bucketing_series(df: pd.DataFrame, bucketing_info: Tuple[List[str], int]) -> pd.Series:
+    bucket_number_series = (
+        df[bucketing_info[0]]
+        # Prevent "upcasting" mixed types by casting to object
+        .astype("O").apply(
+            lambda row: _get_bucket_number(bucketing_info[1], [row[col_name] for col_name in bucketing_info[0]]),
+            axis="columns",
+        )
+    )
+    return bucket_number_series.astype(pd.CategoricalDtype(range(bucketing_info[1])))
+
+
+def _simulate_overflow(value: int, bits: int = 31, signed: bool = False) -> int:
+    base = 1 << bits
+    value %= base
+    return value - base if signed and value.bit_length() == bits else value
+
+
+def _get_bucket_number(number_of_buckets: int, values: List[Union[str, int, bool]]) -> int:
+    hash_code = 0
+    for value in values:
+        hash_code = 31 * hash_code + _get_value_hash(value)
+        hash_code = _simulate_overflow(hash_code)
+
+    return hash_code % number_of_buckets
+
+
+def _get_value_hash(value: Union[str, int, bool]) -> int:
+    if isinstance(value, (int, np.int_)):
+        value = int(value)
+        bigint_min, bigint_max = -(2**63), 2**63 - 1
+        int_min, int_max = -(2**31), 2**31 - 1
+        if not bigint_min <= value <= bigint_max:
+            raise ValueError(f"{value} exceeds the range that Athena can handle as bigint.")
+        if not int_min <= value <= int_max:
+            value = (value >> 32) ^ value
+        if value < 0:
+            return -value - 1
+        return int(value)
+    if isinstance(value, (str, np.str_)):
+        value_hash = 0
+        for byte in value.encode():
+            value_hash = value_hash * 31 + byte
+            value_hash = _simulate_overflow(value_hash)
+        return value_hash
+    if isinstance(value, (bool, np.bool_)):
+        return int(value)
+
+    raise exceptions.InvalidDataFrame(
+        "Column specified for bucketing contains invalid data type. Only string, int and bool are supported."
+    )
+
+
 def _to_partitions(
     func: Callable[..., List[str]],
     concurrent_partitioning: bool,
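
For orientation, a minimal standalone sketch of the string hashing that the new _get_value_hash implements. The _simulate_overflow helper mirrors the diff above; _string_hash is a hypothetical name for the str branch extracted on its own, and the sample strings and bucket count are illustrative.

    def _simulate_overflow(value: int, bits: int = 31, signed: bool = False) -> int:
        # Emulate fixed-width integer wraparound (Java-style 32-bit hashes).
        base = 1 << bits
        value %= base
        return value - base if signed and value.bit_length() == bits else value

    def _string_hash(value: str) -> int:
        # Same 31-multiplier scheme as the str branch of _get_value_hash.
        value_hash = 0
        for byte in value.encode():
            value_hash = value_hash * 31 + byte
            value_hash = _simulate_overflow(value_hash)
        return value_hash

    print(_string_hash("a"))       # 97: the single byte value
    print(_string_hash("ab"))      # 97 * 31 + 98 = 3105
    print(_string_hash("ab") % 4)  # 1: the bucket number for 4 buckets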
@@ -110,12 +163,7 @@ def _to_buckets(
     **func_kwargs: Any,
 ) -> List[str]:
     _proxy: _WriteProxy = proxy if proxy else _WriteProxy(use_threads=False)
-    bucket_number_series = df.astype("O").apply(
-        lambda row: _get_bucket_number(bucketing_info[1], [row[col_name] for col_name in bucketing_info[0]]),
-        axis="columns",
-    )
-    bucket_number_series = bucket_number_series.astype(pd.CategoricalDtype(range(bucketing_info[1])))
-    for bucket_number, subgroup in df.groupby(by=bucket_number_series, observed=False):
+    for bucket_number, subgroup in df.groupby(by=_get_bucketing_series(df=df, bucketing_info=bucketing_info)):
         _proxy.write(
             func=func,
             df=subgroup,
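
One detail worth noting in this hunk: the series returned by _get_bucketing_series carries a CategoricalDtype over the full bucket range, which lets the groupby produce a group for every bucket number, empty buckets included. A minimal sketch, assuming pandas' categorical grouping with observed=False (the default in pandas versions current at the time; the data and bucket assignments below are illustrative):

    import pandas as pd

    df = pd.DataFrame({"x": ["a", "b", "c"]})

    # Hard-coded bucket assignments for brevity; in the diff they come from
    # _get_bucketing_series. The dtype spans all four buckets.
    buckets = pd.Series([0, 2, 0]).astype(pd.CategoricalDtype(range(4)))

    for bucket_number, subgroup in df.groupby(by=buckets, observed=False):
        print(bucket_number, len(subgroup))  # 0 2 / 1 0 / 2 1 / 3 0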
@@ -132,47 +180,6 @@ def _to_buckets(
     return paths
 
 
-def _simulate_overflow(value: int, bits: int = 31, signed: bool = False) -> int:
-    base = 1 << bits
-    value %= base
-    return value - base if signed and value.bit_length() == bits else value
-
-
-def _get_bucket_number(number_of_buckets: int, values: List[Union[str, int, bool]]) -> int:
-    hash_code = 0
-    for value in values:
-        hash_code = 31 * hash_code + _get_value_hash(value)
-        hash_code = _simulate_overflow(hash_code)
-
-    return hash_code % number_of_buckets
-
-
-def _get_value_hash(value: Union[str, int, bool]) -> int:
-    if isinstance(value, (int, np.int_)):
-        value = int(value)
-        bigint_min, bigint_max = -(2**63), 2**63 - 1
-        int_min, int_max = -(2**31), 2**31 - 1
-        if not bigint_min <= value <= bigint_max:
-            raise ValueError(f"{value} exceeds the range that Athena can handle as bigint.")
-        if not int_min <= value <= int_max:
-            value = (value >> 32) ^ value
-        if value < 0:
-            return -value - 1
-        return int(value)
-    if isinstance(value, (str, np.str_)):
-        value_hash = 0
-        for byte in value.encode():
-            value_hash = value_hash * 31 + byte
-            value_hash = _simulate_overflow(value_hash)
-        return value_hash
-    if isinstance(value, (bool, np.bool_)):
-        return int(value)
-
-    raise exceptions.InvalidDataFrame(
-        "Column specified for bucketing contains invalid data type. Only string, int and bool are supported."
-    )
-
-
 def _to_dataset(
     func: Callable[..., List[str]],
     concurrent_partitioning: bool,
