Commit e578a24

Fix Athena ctas_approach issue with immutability. #335

1 parent 98b801d

File tree

6 files changed: +24 −15 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
 
 [![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
-[![Coverage](https://img.shields.io/badge/coverage-92%25-brightgreen.svg)](https://pypi.org/project/awswrangler/)
+[![Coverage](https://img.shields.io/badge/coverage-93%25-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 ![Static Checking](https://github.com/awslabs/aws-data-wrangler/workflows/Static%20Checking/badge.svg?branch=master)
 [![Documentation Status](https://readthedocs.org/projects/aws-data-wrangler/badge/?version=latest)](https://aws-data-wrangler.readthedocs.io/?badge=latest)

awswrangler/_utils.py

Lines changed: 4 additions & 4 deletions

@@ -257,12 +257,12 @@ def list_sampling(lst: List[Any], sampling: float) -> List[Any]:
 
 def ensure_df_is_mutable(df: pd.DataFrame) -> pd.DataFrame:
     """Ensure that all columns have the writeable flag True."""
-    columns: List[str] = df.columns.to_list()
-    for column in columns:
+    for column in df.columns.to_list():
         if hasattr(df[column].values, "flags") is True:
             if df[column].values.flags.writeable is False:
-                df = df.copy(deep=True)
-                break
+                s: pd.Series = df[column]
+                df[column] = None
+                df[column] = s
     return df
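Rather than deep-copying the whole DataFrame as soon as one read-only column is found (the old copy-and-break approach), the loop now re-assigns only the offending columns, which makes pandas allocate a fresh, writable block per column. A standalone sketch of the failure mode and the re-assignment trick (illustrative only, not the library's code):

# Simulate what a zero-copy Arrow-to-pandas conversion can produce:
# a column backed by a read-only NumPy buffer.
import numpy as np
import pandas as pd

arr = np.arange(3)
arr.setflags(write=False)
df = pd.DataFrame(arr, columns=["c0"])  # whether the buffer is shared depends on the pandas version

if hasattr(df["c0"].values, "flags") and df["c0"].values.flags.writeable is False:
    # The commit's per-column fix: re-assigning the column forces pandas
    # to allocate a fresh, writable block instead of copying the whole frame.
    s = df["c0"]
    df["c0"] = None
    df["c0"] = s

# Column updates now work; on older pandas a read-only block here could
# raise "ValueError: assignment destination is read-only".
df["c0"] = df["c0"] + 1
print(df["c0"].tolist())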
awswrangler/s3/_read_parquet.py

Lines changed: 10 additions & 9 deletions

@@ -169,7 +169,7 @@ def _arrowtable2df(
     path: str,
     path_root: Optional[str],
 ) -> pd.DataFrame:
-    return _apply_partitions(
+    df: pd.DataFrame = _apply_partitions(
         df=table.to_pandas(
             use_threads=use_threads,
             split_blocks=True,
@@ -185,6 +185,7 @@ def _arrowtable2df(
         path=path,
         path_root=path_root,
     )
+    return _utils.ensure_df_is_mutable(df=df)
 
 
 def _read_parquet_chunked(
@@ -254,7 +255,7 @@ def _read_parquet_chunked(
             yield next_slice
 
 
-def _read_parquet_file_single_thread(
+def _read_parquet_file(
     path: str,
     columns: Optional[List[str]],
     categories: Optional[List[str]],
@@ -285,7 +286,7 @@ def _count_row_groups(
     return pq_file.num_row_groups
 
 
-def _read_parquet_file_multi_thread(
+def _read_parquet_row_group(
     row_group: int,
     path: str,
     columns: Optional[List[str]],
@@ -306,7 +307,7 @@ def _read_parquet_file_multi_thread(
     return pq_file.read_row_group(i=row_group, columns=columns, use_threads=False, use_pandas_metadata=False)
 
 
-def _read_parquet_file(
+def _read_parquet(
     path: str,
     columns: Optional[List[str]],
     categories: Optional[List[str]],
@@ -318,7 +319,7 @@ def _read_parquet_file(
     use_threads: bool,
 ) -> pd.DataFrame:
     if use_threads is False:
-        table: pa.Table = _read_parquet_file_single_thread(
+        table: pa.Table = _read_parquet_file(
             path=path,
             columns=columns,
             categories=categories,
@@ -333,7 +334,7 @@ def _read_parquet_file(
     with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor:
         tables: Tuple[pa.Table, ...] = tuple(
             executor.map(
-                _read_parquet_file_multi_thread,
+                _read_parquet_row_group,
                 range(num_row_groups),
                 itertools.repeat(path),
                 itertools.repeat(columns),
@@ -529,7 +530,7 @@ def read_parquet(
     if chunked is not False:
         return _read_parquet_chunked(paths=paths, chunked=chunked, validate_schema=validate_schema, **args)
     if len(paths) == 1:
-        return _read_parquet_file(path=paths[0], **args)
+        return _read_parquet(path=paths[0], **args)
     if validate_schema is True:
         _validate_schemas_from_files(
             paths=paths,
@@ -540,8 +541,8 @@ def read_parquet(
         )
     if use_threads is True:
         args["use_threads"] = True
-        return _read_concurrent(func=_read_parquet_file, ignore_index=True, paths=paths, **args)
-    return _union(dfs=[_read_parquet_file(path=p, **args) for p in paths], ignore_index=True)
+        return _read_concurrent(func=_read_parquet, ignore_index=True, paths=paths, **args)
+    return _union(dfs=[_read_parquet(path=p, **args) for p in paths], ignore_index=True)
 
 
 @apply_configs
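The renames straighten out the call chain: read_parquet dispatches to _read_parquet, which either reads the whole file through _read_parquet_file (use_threads=False) or fans out one _read_parquet_row_group task per row group on a thread pool. A simplified sketch of that fan-out, assuming a local path and no column projection (the real functions also thread through columns, categories, and the boto3 session):

import concurrent.futures
import itertools
import os

import pyarrow as pa
import pyarrow.parquet as pq


def read_row_group(row_group: int, path: str) -> pa.Table:
    # Each task opens its own handle and reads a single row group with
    # use_threads=False, so all parallelism comes from the pool itself.
    return pq.ParquetFile(path).read_row_group(i=row_group, use_threads=False)


def read_parquet_concurrent(path: str) -> pa.Table:
    num_row_groups = pq.ParquetFile(path).num_row_groups
    cpus = os.cpu_count() or 1
    with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor:
        tables = tuple(executor.map(read_row_group, range(num_row_groups), itertools.repeat(path)))
    return pa.concat_tables(tables)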

test.sh

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ microtime() {
 
 START=$(microtime)
 
-./validation.sh
+./validate.sh
 tox -e ALL
 coverage html --directory coverage
 rm -rf .coverage*

tests/test_athena_parquet.py

Lines changed: 8 additions & 0 deletions

@@ -424,3 +424,11 @@ def test_read_parquet_filter_partitions(path, glue_table, glue_database, use_threads):
     assert df2.c0.iloc[0] == i
     assert df2.c1.iloc[0] == i
     assert df2.c2.iloc[0] == i
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_read_parquet_mutability(path, glue_table, glue_database, use_threads):
+    sql = "SELECT timestamp '2012-08-08 01:00' AS c0"
+    df = wr.athena.read_sql_query(sql, "default", use_threads=use_threads)
+    df["c0"] = df["c0"] + pd.DateOffset(months=-2)
+    assert df.c0[0].value == 1339117200000000000
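The new test exercises the operation that previously failed on an immutable ctas_approach result: adding a DateOffset to an existing column. The expected value can be checked offline, with no AWS access:

import pandas as pd

# 2012-08-08 01:00 minus two months is 2012-06-08 01:00;
# Timestamp.value is nanoseconds since the Unix epoch.
ts = pd.Timestamp("2012-08-08 01:00") + pd.DateOffset(months=-2)
assert ts == pd.Timestamp("2012-06-08 01:00")
assert ts.value == 1339117200000000000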
validation.sh → validate.sh

File renamed without changes.
