Commit c99dca2

Make ctas_approach work even with eventual consistency issues
1 parent 9903c9e commit c99dca2
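
For context: with ctas_approach enabled, the query is wrapped in a CTAS and the results are read back from the exact Parquet files listed in the CTAS manifest, rather than from an S3 listing of the output prefix, which can lag behind due to eventual consistency. A minimal usage sketch, assuming the 0.x-style API exercised by the tests in this commit (the top-level Session import and the table/database names are placeholders):

import awswrangler

session = awswrangler.Session()  # assumes Session is exported at package level

# Wrap the query in a CTAS; the resulting Parquet files are read directly from
# the paths listed in the "<query-id>-manifest.csv" object that Athena writes.
df = session.pandas.read_sql_athena(ctas_approach=True,
                                    sql="SELECT * FROM my_table",  # placeholder table
                                    database="my_database")        # placeholder database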

6 files changed, +210 −36 lines changed

awswrangler/athena.py

Lines changed: 15 additions & 1 deletion
@@ -16,7 +16,10 @@
 class Athena:
     def __init__(self, session):
         self._session = session
-        self._client_athena = session.boto3_session.client(service_name="athena", config=session.botocore_config)
+        self._client_athena = session.boto3_session.client(service_name="athena",
+                                                           use_ssl=True,
+                                                           config=session.botocore_config)
+        self._client_s3 = session.boto3_session.client(service_name="s3", use_ssl=True, config=session.botocore_config)
 
     def get_query_columns_metadata(self, query_execution_id: str) -> Dict[str, str]:
         """
@@ -256,3 +259,14 @@ def normalize_table_name(name):
         :return: normalized table name (str)
         """
         return Athena._normalize_name(name=name)
+
+    @staticmethod
+    def _parse_path(path):
+        path2 = path.replace("s3://", "")
+        parts = path2.partition("/")
+        return parts[0], parts[2]
+
+    def extract_manifest_paths(self, path: str) -> List[str]:
+        bucket_name, key_path = self._parse_path(path)
+        body: bytes = self._client_s3.get_object(Bucket=bucket_name, Key=key_path)["Body"].read()
+        return [x for x in body.decode('utf-8').split("\n") if x != ""]
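
To illustrate what the new helpers consume: the manifest referenced elsewhere in this commit (f"{s3_output}/tables/{query_id}-manifest.csv") is a plain-text object with one S3 path per line, one for each file the CTAS wrote. A minimal sketch of the parsing done by _parse_path and extract_manifest_paths, using hypothetical manifest contents:

# Hypothetical manifest body; the real one is written by Athena for the CTAS query.
body = (b"s3://my-bucket/output/tables/query-id/file_0\n"
        b"s3://my-bucket/output/tables/query-id/file_1\n")

# Same split as Athena._parse_path(): "s3://bucket/key" -> (bucket, key).
manifest = "s3://my-bucket/output/tables/query-id-manifest.csv"
bucket, _, key = manifest.replace("s3://", "").partition("/")

# Same filtering as extract_manifest_paths(): one non-empty line per output file.
paths = [line for line in body.decode("utf-8").split("\n") if line != ""]
print(bucket, key, paths)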

awswrangler/pandas.py

Lines changed: 130 additions & 6 deletions
@@ -499,7 +499,7 @@ def read_sql_athena(self,
                         workgroup: Optional[str] = None,
                         encryption: Optional[str] = None,
                         kms_key: Optional[str] = None,
-                        ctas_approach: bool = False,
+                        ctas_approach: bool = None,
                         procs_cpu_bound: Optional[int] = None,
                         max_result_size: Optional[int] = None):
         """
@@ -523,11 +523,12 @@ def read_sql_athena(self,
         :param workgroup: The name of the workgroup in which the query is being started. (By default uses the Session() workgroup)
         :param encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
-        :param ctas_approach: Wraps the query with a CTAS
+        :param ctas_approach: Wraps the query with a CTAS (Session's default is False)
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         :param max_result_size: Max number of bytes on each request to S3 (VALID ONLY FOR ctas_approach=False)
         :return: Pandas Dataframe or Iterator of Pandas Dataframes if max_result_size was passed
         """
+        ctas_approach = ctas_approach if ctas_approach is not None else self._session.athena_ctas_approach if self._session.athena_ctas_approach is not None else False
         if ctas_approach is True and max_result_size is not None:
             raise InvalidParameters("ctas_approach can't use max_result_size!")
         if s3_output is None:
@@ -580,7 +581,10 @@ def _read_sql_athena_ctas(self,
                                                        kms_key=kms_key)
         self._session.athena.wait_query(query_execution_id=query_id)
         self._session.glue.delete_table_if_exists(database=database, table=name)
-        return self.read_parquet(path=path, procs_cpu_bound=procs_cpu_bound)
+        manifest_path: str = f"{s3_output}/tables/{query_id}-manifest.csv"
+        paths: List[str] = self._session.athena.extract_manifest_paths(path=manifest_path)
+        logger.debug(f"paths: {paths}")
+        return self.read_parquet(path=paths, procs_cpu_bound=procs_cpu_bound)
 
     def _read_sql_athena_regular(self,
                                  sql: str,
@@ -1209,30 +1213,150 @@ def drop_duplicated_columns(dataframe: pd.DataFrame, inplace: bool = True) -> pd
         return dataframe.loc[:, ~duplicated_cols]
 
     def read_parquet(self,
-                     path: str,
+                     path: Union[str, List[str]],
                      columns: Optional[List[str]] = None,
                      filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None,
                      procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
         """
         Read parquet data from S3
 
+        :param path: AWS S3 path or List of paths (E.g. s3://bucket-name/folder_name/)
+        :param columns: Names of columns to read from the file
+        :param filters: List of filters to apply, like ``[[('x', '=', 0), ...], ...]``.
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        """
+        procs_cpu_bound = procs_cpu_bound if procs_cpu_bound is not None else self._session.procs_cpu_bound if self._session.procs_cpu_bound is not None else 1
+        logger.debug(f"procs_cpu_bound: {procs_cpu_bound}")
+        df: Optional[pd.DataFrame] = None
+        session_primitives = self._session.primitives
+        path = [path] if type(path) == str else path  # type: ignore
+        bounders = calculate_bounders(len(path), procs_cpu_bound)
+        logger.debug(f"len(bounders): {len(bounders)}")
+        if len(bounders) == 1:
+            df = Pandas._read_parquet_paths(session_primitives=session_primitives,
+                                            path=path,
+                                            columns=columns,
+                                            filters=filters,
+                                            procs_cpu_bound=procs_cpu_bound)
+        else:
+            procs = []
+            receive_pipes = []
+            for bounder in bounders:
+                receive_pipe, send_pipe = mp.Pipe()
+                logger.debug(f"bounder: {bounder}")
+                proc = mp.Process(
+                    target=self._read_parquet_paths_remote,
+                    args=(
+                        send_pipe,
+                        session_primitives,
+                        path[bounder[0]:bounder[1]],
+                        columns,
+                        filters,
+                        1  # procs_cpu_bound
+                    ),
+                )
+                proc.daemon = False
+                proc.start()
+                procs.append(proc)
+                receive_pipes.append(receive_pipe)
+            logger.debug(f"len(procs): {len(bounders)}")
+            for i in range(len(procs)):
+                logger.debug(f"Waiting pipe number: {i}")
+                df_received = receive_pipes[i].recv()
+                if df is None:
+                    df = df_received
+                else:
+                    df = pd.concat(objs=[df, df_received], ignore_index=True)
+                logger.debug(f"Waiting proc number: {i}")
+                procs[i].join()
+                logger.debug(f"Closing proc number: {i}")
+                receive_pipes[i].close()
+        return df
+
+    @staticmethod
+    def _read_parquet_paths_remote(send_pipe: mp.connection.Connection,
+                                   session_primitives: Any,
+                                   path: Union[str, List[str]],
+                                   columns: Optional[List[str]] = None,
+                                   filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None,
+                                   procs_cpu_bound: Optional[int] = None):
+        df: pd.DataFrame = Pandas._read_parquet_paths(session_primitives=session_primitives,
+                                                      path=path,
+                                                      columns=columns,
+                                                      filters=filters,
+                                                      procs_cpu_bound=procs_cpu_bound)
+        send_pipe.send(df)
+        send_pipe.close()
+
+    @staticmethod
+    def _read_parquet_paths(session_primitives: Any,
+                            path: Union[str, List[str]],
+                            columns: Optional[List[str]] = None,
+                            filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None,
+                            procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
+        """
+        Read parquet data from S3
+
+        :param session_primitives: SessionPrimitives()
+        :param path: AWS S3 path or List of paths (E.g. s3://bucket-name/folder_name/)
+        :param columns: Names of columns to read from the file
+        :param filters: List of filters to apply, like ``[[('x', '=', 0), ...], ...]``.
+        :param procs_cpu_bound: Number of cores used for CPU bound tasks
+        """
+        df: pd.DataFrame
+        if (type(path) == str) or (len(path) == 1):
+            path = path[0] if type(path) == list else path  # type: ignore
+            df = Pandas._read_parquet_path(
+                session_primitives=session_primitives,
+                path=path,  # type: ignore
+                columns=columns,
+                filters=filters,
+                procs_cpu_bound=procs_cpu_bound)
+        else:
+            df = Pandas._read_parquet_path(session_primitives=session_primitives,
+                                           path=path[0],
+                                           columns=columns,
+                                           filters=filters,
+                                           procs_cpu_bound=procs_cpu_bound)
+            for p in path[1:]:
+                df_aux = Pandas._read_parquet_path(session_primitives=session_primitives,
+                                                   path=p,
+                                                   columns=columns,
+                                                   filters=filters,
+                                                   procs_cpu_bound=procs_cpu_bound)
+                df = pd.concat(objs=[df, df_aux], ignore_index=True)
+        return df
+
+    @staticmethod
+    def _read_parquet_path(session_primitives: Any,
+                           path: str,
+                           columns: Optional[List[str]] = None,
+                           filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None,
+                           procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
+        """
+        Read parquet data from S3
+
+        :param session_primitives: SessionPrimitives()
         :param path: AWS S3 path (E.g. s3://bucket-name/folder_name/)
         :param columns: Names of columns to read from the file
         :param filters: List of filters to apply, like ``[[('x', '=', 0), ...], ...]``.
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         """
         path = path[:-1] if path[-1] == "/" else path
-        procs_cpu_bound = 1 if self._session.procs_cpu_bound is None else self._session.procs_cpu_bound if procs_cpu_bound is None else procs_cpu_bound
+        procs_cpu_bound = procs_cpu_bound if procs_cpu_bound is not None else session_primitives.procs_cpu_bound if session_primitives.procs_cpu_bound is not None else 1
         use_threads: bool = True if procs_cpu_bound > 1 else False
-        fs: S3FileSystem = s3.get_fs(session_primitives=self._session.primitives)
+        fs: S3FileSystem = s3.get_fs(session_primitives=session_primitives)
         fs = pa.filesystem._ensure_filesystem(fs)
+        logger.debug(f"Reading Parquet table: {path}")
         table = pq.read_table(source=path, columns=columns, filters=filters, filesystem=fs, use_threads=use_threads)
         # Check if we lose some integer during the conversion (Happens when has some null value)
         integers = [field.name for field in table.schema if str(field.type).startswith("int")]
+        logger.debug(f"Converting to Pandas: {path}")
         df = table.to_pandas(use_threads=use_threads, integer_object_nulls=True)
         for c in integers:
             if not str(df[c].dtype).startswith("int"):
                 df[c] = df[c].astype("Int64")
+        logger.debug(f"Done: {path}")
         return df
 
     def read_table(self,
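
The new read_parquet fan-out follows a simple pattern: split the list of paths into per-process chunks, have each worker build a partial DataFrame and send it back through a multiprocessing Pipe, then concatenate the pieces in order. A stripped-down, self-contained sketch of that pattern (chunk_bounds below is a stand-in for awswrangler's calculate_bounders, and the worker just fabricates a small frame instead of reading Parquet):

import multiprocessing as mp
from typing import List, Tuple

import pandas as pd


def chunk_bounds(total: int, chunks: int) -> List[Tuple[int, int]]:
    # Stand-in for calculate_bounders(): split range(total) into at most `chunks` slices.
    size, rest = divmod(total, chunks)
    bounds, start = [], 0
    for i in range(chunks):
        end = start + size + (1 if i < rest else 0)
        if end > start:
            bounds.append((start, end))
        start = end
    return bounds


def worker(send_pipe, paths: List[str]) -> None:
    # In the real code this would read the Parquet files; here we just echo the paths.
    send_pipe.send(pd.DataFrame({"path": paths}))
    send_pipe.close()


if __name__ == "__main__":
    paths = [f"s3://bucket/file_{i}.parquet" for i in range(10)]  # placeholder paths
    procs, pipes = [], []
    for start, end in chunk_bounds(len(paths), chunks=4):
        receive_pipe, send_pipe = mp.Pipe()
        proc = mp.Process(target=worker, args=(send_pipe, paths[start:end]))
        proc.start()
        procs.append(proc)
        pipes.append(receive_pipe)
    # Receive before joining, exactly as read_parquet does, to avoid blocking on full pipes.
    df = pd.concat([pipe.recv() for pipe in pipes], ignore_index=True)
    for proc in procs:
        proc.join()
    print(df)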

awswrangler/session.py

Lines changed: 34 additions & 20 deletions
@@ -48,7 +48,8 @@ def __init__(self,
                  athena_s3_output: Optional[str] = None,
                  athena_encryption: Optional[str] = "SSE_S3",
                  athena_kms_key: Optional[str] = None,
-                 athena_database: str = "default"):
+                 athena_database: str = "default",
+                 athena_ctas_approach: bool = False):
         """
         Most parameters inherit from Boto3 or Pyspark.
         https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
@@ -68,6 +69,7 @@ def __init__(self,
         :param procs_io_bound: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
         :param athena_workgroup: Default AWS Athena Workgroup (str)
         :param athena_database: AWS Glue/Athena database name
+        :param athena_ctas_approach: Wraps the query with a CTAS
         :param athena_s3_output: AWS S3 path
         :param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
@@ -92,6 +94,7 @@ def __init__(self,
         self._athena_encryption: Optional[str] = athena_encryption
         self._athena_kms_key: Optional[str] = athena_kms_key
         self._athena_database: str = athena_database
+        self._athena_ctas_approach: bool = athena_ctas_approach
         self._primitives = None
         self._load_new_primitives()
         if boto3_session:
@@ -131,23 +134,22 @@ def _load_new_primitives(self):
         Load or reload a new AWS Wrangler Session primitives
         :return: None
         """
-        self._primitives = SessionPrimitives(
-            profile_name=self._profile_name,
-            aws_access_key_id=self._aws_access_key_id,
-            aws_secret_access_key=self._aws_secret_access_key,
-            aws_session_token=self._aws_session_token,
-            region_name=self._region_name,
-            botocore_max_retries=self._botocore_max_retries,
-            s3_additional_kwargs=self._s3_additional_kwargs,
-            botocore_config=self._botocore_config,
-            procs_cpu_bound=self._procs_cpu_bound,
-            procs_io_bound=self._procs_io_bound,
-            athena_workgroup=self._athena_workgroup,
-            athena_s3_output=self._athena_s3_output,
-            athena_encryption=self._athena_encryption,
-            athena_kms_key=self._athena_kms_key,
-            athena_database=self._athena_database,
-        )
+        self._primitives = SessionPrimitives(profile_name=self._profile_name,
+                                             aws_access_key_id=self._aws_access_key_id,
+                                             aws_secret_access_key=self._aws_secret_access_key,
+                                             aws_session_token=self._aws_session_token,
+                                             region_name=self._region_name,
+                                             botocore_max_retries=self._botocore_max_retries,
+                                             s3_additional_kwargs=self._s3_additional_kwargs,
+                                             botocore_config=self._botocore_config,
+                                             procs_cpu_bound=self._procs_cpu_bound,
+                                             procs_io_bound=self._procs_io_bound,
+                                             athena_workgroup=self._athena_workgroup,
+                                             athena_s3_output=self._athena_s3_output,
+                                             athena_encryption=self._athena_encryption,
+                                             athena_kms_key=self._athena_kms_key,
+                                             athena_database=self._athena_database,
+                                             athena_ctas_approach=self._athena_ctas_approach)
 
     @property
     def profile_name(self):
@@ -217,6 +219,10 @@ def athena_kms_key(self) -> Optional[str]:
     def athena_database(self) -> str:
         return self._athena_database
 
+    @property
+    def athena_ctas_approach(self) -> bool:
+        return self._athena_ctas_approach
+
     @property
     def boto3_session(self):
         return self._boto3_session
@@ -297,7 +303,8 @@ def __init__(self,
                  athena_s3_output: Optional[str] = None,
                  athena_encryption: Optional[str] = None,
                  athena_kms_key: Optional[str] = None,
-                 athena_database: Optional[str] = None):
+                 athena_database: Optional[str] = None,
+                 athena_ctas_approach: bool = False):
         """
         Most parameters inherit from Boto3.
         https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
@@ -314,6 +321,7 @@ def __init__(self,
         :param procs_io_bound: number of processes that can be used in single node applications for I/O bound cases (Default: os.cpu_count() * PROCS_IO_BOUND_FACTOR)
         :param athena_workgroup: Default AWS Athena Workgroup (str)
         :param athena_database: AWS Glue/Athena database name
+        :param athena_ctas_approach: Wraps the query with a CTAS
         :param athena_s3_output: AWS S3 path
         :param athena_encryption: None|'SSE_S3'|'SSE_KMS'|'CSE_KMS'
         :param athena_kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
@@ -333,6 +341,7 @@ def __init__(self,
         self._athena_encryption: Optional[str] = athena_encryption
         self._athena_kms_key: Optional[str] = athena_kms_key
         self._athena_database: Optional[str] = athena_database
+        self._athena_ctas_approach: bool = athena_ctas_approach
 
     @property
     def profile_name(self):
@@ -394,6 +403,10 @@ def athena_kms_key(self) -> Optional[str]:
     def athena_database(self) -> Optional[str]:
         return self._athena_database
 
+    @property
+    def athena_ctas_approach(self) -> bool:
+        return self._athena_ctas_approach
+
     @property
     def session(self):
         """
@@ -413,4 +426,5 @@ def session(self):
                        athena_s3_output=self._athena_s3_output,
                        athena_encryption=self._athena_encryption,
                        athena_kms_key=self._athena_kms_key,
-                       athena_database=self._athena_database)
+                       athena_database=self._athena_database,
+                       athena_ctas_approach=self._athena_ctas_approach)
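
Taken together with the pandas.py change above, the resolution order for the flag is: an explicit ctas_approach argument wins, otherwise the session-level default applies, otherwise False. A brief sketch of how that plays out (the top-level Session import and the table/database names are placeholders):

import awswrangler

session = awswrangler.Session(athena_ctas_approach=True)  # new session-level default

# Uses the session default: the query is CTAS-wrapped and read via the manifest.
df1 = session.pandas.read_sql_athena(sql="SELECT * FROM my_table", database="my_db")

# An explicit argument overrides the session default for this call only; note that
# max_result_size is documented as valid only when ctas_approach is False.
df2 = session.pandas.read_sql_athena(sql="SELECT * FROM my_table",
                                     database="my_db",
                                     ctas_approach=False,
                                     max_result_size=128_000_000)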

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 numpy~=1.17.4
 pandas~=0.25.3
 pyarrow~=0.15.1
-botocore~=1.13.35
-boto3~=1.10.35
+botocore~=1.13.36
+boto3~=1.10.36
 s3fs~=0.4.0
 tenacity~=6.0.0
 pg8000~=1.13.2

testing/test_awswrangler/test_pandas.py

Lines changed: 24 additions & 0 deletions
@@ -1509,5 +1509,29 @@ def test_read_sql_athena_ctas(session, bucket, database):
                               procs_cpu_bound=4,
                               partition_cols=["partition"])
     df2 = session.pandas.read_sql_athena(ctas_approach=True, sql="select * from test", database=database)
+    session.s3.delete_objects(path=path)
     assert len(list(df.columns)) == len(list(df2.columns))
     assert len(df.index) == len(df2.index)
+
+
+def test_read_sql_athena_s3_output_ctas(session, bucket, database):
+    n: int = 1_000_000
+    df = pd.DataFrame({"id": list((range(n))), "partition": list(["foo" if i % 2 == 0 else "boo" for i in range(n)])})
+    path = f"s3://{bucket}/test_read_sql_athena_s3_output_ctas/"
+    session.pandas.to_parquet(dataframe=df,
+                              database=database,
+                              table="test",
+                              path=path,
+                              mode="overwrite",
+                              preserve_index=True,
+                              procs_cpu_bound=4,
+                              partition_cols=["partition"])
+    path_ctas = f"s3://{bucket}/test_read_sql_athena_s3_output_ctas_metadata/"
+    df2 = session.pandas.read_sql_athena(ctas_approach=True,
+                                         sql="select * from test",
+                                         database=database,
+                                         s3_output=path_ctas)
+    session.s3.delete_objects(path=path)
+    assert len(list(df.columns)) + 1 == len(list(df2.columns))
+    assert len(df.index) == len(df2.index)
+    print(df2)
