11"""Amazon Athena Module."""
22
33import csv
4+ import datetime
45import logging
56import pprint
67import re
78import time
8- from datetime import datetime , timezone
99from decimal import Decimal
1010from typing import Any , Dict , Iterator , List , Optional , Tuple , Union
1111
1818_logger : logging .Logger = logging .getLogger (__name__ )
1919
2020_QUERY_WAIT_POLLING_DELAY : float = 0.2 # SECONDS
21- _CACHE_PREVIOUS_QUERY_COUNT : int = 50 # number of past queries to scan for cached results
2221
2322
2423def get_query_columns_types (query_execution_id : str , boto3_session : Optional [boto3 .Session ] = None ) -> Dict [str , str ]:
@@ -391,6 +390,7 @@ def read_sql_query(
     use_threads: bool = True,
     boto3_session: Optional[boto3.Session] = None,
     max_cache_seconds: int = 0,
+    max_cache_query_inspections: int = 50,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
     """Execute any SQL query on AWS Athena and return the results as a Pandas DataFrame.
 
@@ -474,13 +474,17 @@ def read_sql_query(
         If enabled os.cpu_count() will be used as the max number of threads.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
-    max_cache_seconds: int
+    max_cache_seconds : int
         Wrangler can look up in Athena's history if this query has been run before.
         If so, and its completion time is less than `max_cache_seconds` before now, wrangler
         skips query execution and just returns the same results as last time.
         If cached results are valid, wrangler ignores the `ctas_approach`, `s3_output`, `encryption`, `kms_key`,
         `keep_files` and `ctas_temp_table_name` params.
         If reading cached data fails for any reason, execution falls back to the usual query run path.
+    max_cache_query_inspections : int
+        Max number of queries that will be inspected from the history to try to find some result to reuse.
+        The bigger the number of inspections, the bigger the latency for non-cached queries.
+        Only takes effect if max_cache_seconds > 0.
 
     Returns
     -------
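To make the two caching knobs documented above concrete, here is a minimal usage sketch. It assumes this module is exposed as `awswrangler.athena` (as the module docstring suggests); the SQL, database name, and numbers are placeholders, not values from this change.

```python
import awswrangler as wr

# Reuse the result of an identical query that completed within the last 15 minutes,
# scanning at most 100 executions from the query history before giving up on the
# cache and running the query normally.
df = wr.athena.read_sql_query(
    sql="SELECT col_a, col_b FROM my_table",  # placeholder query
    database="my_database",                   # placeholder database
    max_cache_seconds=900,
    max_cache_query_inspections=100,
)
```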
@@ -497,7 +501,11 @@ def read_sql_query(
 
     # check for cached results
     cache_info: Dict[str, Any] = _check_for_cached_results(
-        sql=sql, session=session, workgroup=workgroup, max_cache_seconds=max_cache_seconds
+        sql=sql,
+        session=session,
+        workgroup=workgroup,
+        max_cache_seconds=max_cache_seconds,
+        max_cache_query_inspections=max_cache_query_inspections,
     )
 
     if cache_info["has_valid_cache"] is True:
@@ -961,17 +969,17 @@ def _prepare_query_string_for_comparison(query_string: str) -> str:
 
 def _get_last_query_executions(
     boto3_session: Optional[boto3.Session] = None, workgroup: Optional[str] = None
-) -> List[Dict[str, Any]]:
-    """Return the last 50 `query_execution_info`s run by the workgroup in Athena."""
+) -> Iterator[List[Dict[str, Any]]]:
+    """Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
     client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
-
-    args: Dict[str, Any] = {"MaxResults": _CACHE_PREVIOUS_QUERY_COUNT}
-    if workgroup:
+    args: Dict[str, str] = {}
+    if workgroup is not None:
         args["WorkGroup"] = workgroup
-    query_execution_list: Dict[str, Any] = client_athena.list_query_executions(**args)
-    query_execution_id_list: List[str] = query_execution_list["QueryExecutionIds"]
-    execution_data = client_athena.batch_get_query_execution(QueryExecutionIds=query_execution_id_list)
-    return execution_data.get("QueryExecutions")
+    paginator = client_athena.get_paginator("list_query_executions")
+    for page in paginator.paginate(**args):
+        query_execution_id_list: List[str] = page["QueryExecutionIds"]
+        execution_data = client_athena.batch_get_query_execution(QueryExecutionIds=query_execution_id_list)
+        yield execution_data.get("QueryExecutions")
 
 
 def _sort_successful_executions_data(query_executions: List[Dict[str, Any]]):
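Since `_get_last_query_executions` is now a generator over history pages rather than a single 50-item list, callers only pay for as many `ListQueryExecutions`/`BatchGetQueryExecution` calls as they actually consume. A small consumption sketch (hypothetical workgroup name, with `islice` used purely for illustration):

```python
from itertools import islice

# Inspect at most two pages of the workgroup's history; the generator is lazy,
# so no further Athena API calls are made beyond the pages actually consumed.
for executions_page in islice(_get_last_query_executions(workgroup="primary"), 2):
    for execution in executions_page:
        print(execution["QueryExecutionId"], execution["Status"]["State"])
```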
@@ -1002,37 +1010,42 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> Optional[str]:
 
 
 def _check_for_cached_results(
-    sql: str, session: boto3.Session, workgroup: Optional[str], max_cache_seconds: int
+    sql: str, session: boto3.Session, workgroup: Optional[str], max_cache_seconds: int, max_cache_query_inspections: int
 ) -> Dict[str, Any]:
     """
-    Check wether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.
+    Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.
 
     If so, returns a dict with Athena's `query_execution_info` and the data format.
     """
-    if max_cache_seconds > 0:
-        last_query_executions = _get_last_query_executions(boto3_session=session, workgroup=workgroup)
-        cached_queries = _sort_successful_executions_data(query_executions=last_query_executions)
-        current_timestamp = datetime.now(timezone.utc)
-        comparable_sql: str = _prepare_query_string_for_comparison(sql)
-
-        # this could be mapreduced, but it is only 50 items long, tops
-        for query_info in cached_queries:
-            if (current_timestamp - query_info["Status"]["CompletionDateTime"]).total_seconds() > max_cache_seconds:
-                break  # pragma: no cover
-
-            comparison_query: Optional[str]
-            if query_info["StatementType"] == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
-                parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(query_info["Query"])
-                if parsed_query is not None:
-                    comparison_query = _prepare_query_string_for_comparison(query_string=parsed_query)
+    num_executions_inspected: int = 0
+    if max_cache_seconds > 0:  # pylint: disable=too-many-nested-blocks
+        for query_executions in _get_last_query_executions(boto3_session=session, workgroup=workgroup):
+            cached_queries: List[Dict[str, Any]] = _sort_successful_executions_data(query_executions=query_executions)
+            current_timestamp = datetime.datetime.now(datetime.timezone.utc)
+            comparable_sql: str = _prepare_query_string_for_comparison(sql)
+
+            # this could be mapreduced, but it is only 50 items long, tops
+            for query_info in cached_queries:
+                if (current_timestamp - query_info["Status"]["CompletionDateTime"]).total_seconds() > max_cache_seconds:
+                    break  # pragma: no cover
+
+                comparison_query: Optional[str]
+                if query_info["StatementType"] == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
+                    parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(query_info["Query"])
+                    if parsed_query is not None:
+                        comparison_query = _prepare_query_string_for_comparison(query_string=parsed_query)
+                        if comparison_query == comparable_sql:
+                            data_type = "parquet"
+                            return {"has_valid_cache": True, "data_type": data_type, "query_execution_info": query_info}
+
+                elif query_info["StatementType"] == "DML" and not query_info["Query"].startswith("INSERT"):
+                    comparison_query = _prepare_query_string_for_comparison(query_string=query_info["Query"])
                     if comparison_query == comparable_sql:
-                        data_type = "parquet"
+                        data_type = "csv"
                         return {"has_valid_cache": True, "data_type": data_type, "query_execution_info": query_info}
 
-            elif query_info["StatementType"] == "DML" and not query_info["Query"].startswith("INSERT"):
-                comparison_query = _prepare_query_string_for_comparison(query_string=query_info["Query"])
-                if comparison_query == comparable_sql:
-                    data_type = "csv"
-                    return {"has_valid_cache": True, "data_type": data_type, "query_execution_info": query_info}
+                num_executions_inspected += 1
+                if num_executions_inspected >= max_cache_query_inspections:
+                    return {"has_valid_cache": False}
 
     return {"has_valid_cache": False}
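For reference, a rough sketch of the contract this function establishes for its caller; the branching mirrors the keys set above, while the literal arguments are placeholders and the ellipses stand in for the module's actual result-fetching code, which is outside this diff:

```python
import boto3

cache_info = _check_for_cached_results(
    sql="SELECT col_a FROM my_table",  # placeholder query
    session=boto3.Session(),
    workgroup="primary",               # placeholder workgroup
    max_cache_seconds=900,
    max_cache_query_inspections=50,
)
if cache_info["has_valid_cache"] is True:
    query_info = cache_info["query_execution_info"]  # Athena's QueryExecution dict
    if cache_info["data_type"] == "parquet":
        ...  # cached CTAS result: read the Parquet data written by the previous run
    else:
        ...  # cached plain SELECT: read the CSV output of the previous execution
else:
    ...  # no reusable execution found: run the query as usual
```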