Skip to content

Commit b952eb7

Browse files
committed
Fix general Athena cache bugs.
1 parent 1d0bbec commit b952eb7

File tree

3 files changed

+93
-6
lines changed

3 files changed

+93
-6
lines changed

awswrangler/_data_types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,9 +439,13 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd
439439
)
440440
elif pandas_type == "string":
441441
curr_type: str = str(df[col].dtypes)
442-
if curr_type.startswith("int") or curr_type.startswith("float"):
442+
print(curr_type)
443+
if (curr_type.lower().startswith("int") is True) or (curr_type.startswith("float") is True):
443444
df[col] = df[col].astype(str).astype("string")
445+
elif curr_type.startswith("object") is True:
446+
df[col] = df[col].astype(str)
444447
else:
448+
print(col)
445449
df[col] = df[col].astype("string")
446450
else:
447451
try:

awswrangler/athena.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ def read_sql_query( # pylint: disable=too-many-branches,too-many-locals,too-man
506506
max_cache_seconds=max_cache_seconds,
507507
max_cache_query_inspections=max_cache_query_inspections,
508508
)
509+
_logger.debug("cache_info: %s", cache_info)
509510

510511
if cache_info["has_valid_cache"] is True:
511512
_logger.debug("Valid cache found. Retrieving...")
@@ -687,6 +688,7 @@ def _resolve_query_with_cache( # pylint: disable=too-many-return-statements
687688
session: Optional[boto3.Session],
688689
):
689690
"""Fetch cached data and return it as a pandas Dataframe (or list of Dataframes)."""
691+
_logger.debug("cache_info: %s", cache_info)
690692
if cache_info["data_type"] == "parquet":
691693
manifest_path = cache_info["query_execution_info"]["Statistics"]["DataManifestLocation"]
692694
# this is needed just so we can access boto's modeled exceptions
@@ -970,7 +972,9 @@ def _prepare_query_string_for_comparison(query_string: str) -> str:
970972
"""To use cached data, we need to compare queries. Returns a query string in canonical form."""
971973
# for now this is a simple complete strip, but it could grow into much more sophisticated
972974
# query comparison data structures
973-
return "".join(query_string.split()).strip("()").lower()
975+
query_string = "".join(query_string.split()).strip("()").lower()
976+
query_string = query_string[:-1] if query_string.endswith(";") is True else query_string
977+
return query_string
974978

975979

976980
def _get_last_query_executions(
@@ -983,6 +987,7 @@ def _get_last_query_executions(
983987
args["WorkGroup"] = workgroup
984988
paginator = client_athena.get_paginator("list_query_executions")
985989
for page in paginator.paginate(**args):
990+
_logger.debug("paginating Athena's queries history...")
986991
query_execution_id_list: List[str] = page["QueryExecutionIds"]
987992
execution_data = client_athena.batch_get_query_execution(QueryExecutionIds=query_execution_id_list)
988993
yield execution_data.get("QueryExecutions")
@@ -1026,33 +1031,45 @@ def _check_for_cached_results(
10261031
num_executions_inspected: int = 0
10271032
if max_cache_seconds > 0: # pylint: disable=too-many-nested-blocks
10281033
current_timestamp = datetime.datetime.now(datetime.timezone.utc)
1029-
print(current_timestamp)
10301034
for query_executions in _get_last_query_executions(boto3_session=session, workgroup=workgroup):
1035+
1036+
_logger.debug("len(query_executions): %s", len(query_executions))
10311037
cached_queries: List[Dict[str, Any]] = _sort_successful_executions_data(query_executions=query_executions)
10321038
comparable_sql: str = _prepare_query_string_for_comparison(sql)
1039+
_logger.debug("len(cached_queries): %s", len(cached_queries))
10331040

10341041
# this could be mapreduced, but it is only 50 items long, tops
10351042
for query_info in cached_queries:
1036-
if (current_timestamp - query_info["Status"]["CompletionDateTime"]).total_seconds() > max_cache_seconds:
1037-
break # pragma: no cover
1043+
1044+
query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
1045+
_logger.debug("current_timestamp: %s", current_timestamp)
1046+
_logger.debug("query_timestamp: %s", query_timestamp)
1047+
if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
1048+
return {"has_valid_cache": False} # pragma: no cover
10381049

10391050
comparison_query: Optional[str]
10401051
if query_info["StatementType"] == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
10411052
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(query_info["Query"])
10421053
if parsed_query is not None:
10431054
comparison_query = _prepare_query_string_for_comparison(query_string=parsed_query)
1055+
_logger.debug("DDL - comparison_query: %s", comparison_query)
1056+
_logger.debug("DDL - comparable_sql: %s", comparable_sql)
10441057
if comparison_query == comparable_sql:
10451058
data_type = "parquet"
10461059
return {"has_valid_cache": True, "data_type": data_type, "query_execution_info": query_info}
10471060

10481061
elif query_info["StatementType"] == "DML" and not query_info["Query"].startswith("INSERT"):
10491062
comparison_query = _prepare_query_string_for_comparison(query_string=query_info["Query"])
1063+
_logger.debug("DML - comparison_query: %s", comparison_query)
1064+
_logger.debug("DML - comparable_sql: %s", comparable_sql)
10501065
if comparison_query == comparable_sql:
10511066
data_type = "csv"
10521067
return {"has_valid_cache": True, "data_type": data_type, "query_execution_info": query_info}
10531068

10541069
num_executions_inspected += 1
1070+
_logger.debug("num_executions_inspected: %s", num_executions_inspected)
1071+
_logger.debug("max_cache_query_inspections: %s", max_cache_query_inspections)
10551072
if num_executions_inspected >= max_cache_query_inspections:
1056-
break # pragma: no cover
1073+
return {"has_valid_cache": False} # pragma: no cover
10571074

10581075
return {"has_valid_cache": False}

tests/test_s3_athena.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2058,6 +2058,33 @@ def test_cache_query_ctas_approach_false(path, database, table):
20582058
assert df.c0.sum() == df3.c0.sum()
20592059

20602060

2061+
def test_cache_query_semicolon(path, database, table):
    """A query that differs from a cached one only by a trailing ";" must hit the cache."""
    expected_df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
    written_paths = wr.s3.to_parquet(
        df=expected_df,
        path=path,
        dataset=True,
        mode="overwrite",
        database=database,
        table=table,
    )["paths"]
    wr.s3.wait_objects_exist(paths=written_paths)

    # First read: force a cache miss so the query actually executes on Athena.
    with patch(
        "awswrangler.athena._check_for_cached_results", return_value={"has_valid_cache": False}
    ) as cache_probe:
        cold_df = wr.athena.read_sql_query(
            f"SELECT * FROM {table}", database=database, ctas_approach=True, max_cache_seconds=0
        )
        cache_probe.assert_called()
        assert expected_df.shape == cold_df.shape
        assert expected_df.c0.sum() == cold_df.c0.sum()

    # Second read: same statement plus a trailing semicolon — must be served
    # from cache, i.e. the uncached resolver is never invoked.
    with patch("awswrangler.athena._resolve_query_without_cache") as uncached_resolver:
        warm_df = wr.athena.read_sql_query(
            f"SELECT * FROM {table};", database=database, ctas_approach=True, max_cache_seconds=900
        )
        uncached_resolver.assert_not_called()
        assert expected_df.shape == warm_df.shape
        assert expected_df.c0.sum() == warm_df.c0.sum()
2086+
2087+
20612088
@pytest.mark.parametrize("partition_cols", [None, ["c2"], ["c1", "c2"]])
20622089
def test_metadata_partitions_dataset(path, partition_cols):
20632090
df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5], "c2": [6, 7, 8]})
@@ -2483,3 +2510,42 @@ def test_sanitize_columns(path, sanitize_columns, col):
24832510
assert len(df.index) == 3
24842511
assert len(df.columns) == 1
24852512
assert df.columns == [col]
2513+
2514+
2515+
def test_parquet_catalog_casting_to_string(path, table, database):
    """Write with (almost) every column cast to "string" and read it back three ways."""
    # Every column is cast to "string" except "double", which keeps its type.
    cast_columns = [
        "iint8",
        "iint16",
        "iint32",
        "iint64",
        "float",
        "double",
        "decimal",
        "string",
        "date",
        "timestamp",
        "bool",
        "binary",
        "category",
        "par0",
        "par1",
    ]
    casts = {column: "string" for column in cast_columns}
    casts["double"] = "double"
    written_paths = wr.s3.to_parquet(
        df=get_df_cast(),
        path=path,
        index=False,
        dataset=True,
        mode="overwrite",
        database=database,
        table=table,
        dtype=casts,
    )["paths"]
    wr.s3.wait_objects_exist(paths=written_paths)

    # All three read paths (raw parquet, CTAS, non-CTAS) must agree on shape.
    for frame in (
        wr.s3.read_parquet(path=path),
        wr.athena.read_sql_table(table=table, database=database, ctas_approach=True),
        wr.athena.read_sql_table(table=table, database=database, ctas_approach=False),
    ):
        assert len(frame.index) == 3
        assert len(frame.columns) == 15

0 commit comments

Comments
 (0)