2020 _get_query_metadata ,
2121 _get_s3_output ,
2222 _get_workgroup_config ,
23+ _LocalMetadataCacheManager ,
2324 _QueryMetadata ,
2425 _start_query_execution ,
2526 _WorkGroupConfig ,
@@ -96,33 +97,37 @@ def _compare_query_string(sql: str, other: str) -> bool:
9697 return False
9798
9899
def _get_last_query_infos(
    max_remote_cache_entries: int,
    boto3_session: Optional[boto3.Session] = None,
    workgroup: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Return the most recent successful `query_execution_info`s run by the workgroup in Athena.

    Lists up to ``max_remote_cache_entries`` query execution ids from Athena's
    history, fetches execution details only for ids not already held by the
    local metadata cache, folds them into the cache, and returns the cached
    successful executions sorted by completion time (LRU-friendly order).

    Parameters
    ----------
    max_remote_cache_entries : int
        Maximum number of query executions retrieved from AWS for inspection.
    boto3_session : boto3.Session, optional
        Boto3 session; the default session is used when None.
    workgroup : str, optional
        Athena workgroup whose history is inspected; account default when None.

    Returns
    -------
    List[Dict[str, Any]]
        Successful query execution payloads, newest first.
    """
    client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
    # BatchGetQueryExecution accepts at most 50 ids per call, hence the batching below.
    page_size = 50
    args: Dict[str, Union[str, Dict[str, int]]] = {
        "PaginationConfig": {"MaxItems": max_remote_cache_entries, "PageSize": page_size}
    }
    if workgroup is not None:
        args["WorkGroup"] = workgroup
    paginator = client_athena.get_paginator("list_query_executions")
    uncached_ids: List[str] = []
    for page in paginator.paginate(**args):
        _logger.debug("paginating Athena's queries history...")
        query_execution_id_list: List[str] = page["QueryExecutionIds"]
        for query_execution_id in query_execution_id_list:
            if query_execution_id not in _cache_manager:
                uncached_ids.append(query_execution_id)
    if uncached_ids:
        new_execution_data: List[Dict[str, Any]] = []
        for i in range(0, len(uncached_ids), page_size):
            new_execution_data.extend(
                # Default to [] so a response without "QueryExecutions" cannot crash extend().
                client_athena.batch_get_query_execution(
                    QueryExecutionIds=uncached_ids[i : i + page_size]
                ).get("QueryExecutions", [])
            )
        _cache_manager.update_cache(new_execution_data)
    return _cache_manager.sorted_successful_generator()
126131
127132
128133def _parse_select_query_from_possible_ctas (possible_ctas : str ) -> Optional [str ]:
@@ -150,6 +155,7 @@ def _check_for_cached_results(
150155 workgroup : Optional [str ],
151156 max_cache_seconds : int ,
152157 max_cache_query_inspections : int ,
158+ max_remote_cache_entries : int ,
153159) -> _CacheInfo :
154160 """
155161 Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.
@@ -162,45 +168,41 @@ def _check_for_cached_results(
162168 comparable_sql : str = _prepare_query_string_for_comparison (sql )
163169 current_timestamp : datetime .datetime = datetime .datetime .now (datetime .timezone .utc )
164170 _logger .debug ("current_timestamp: %s" , current_timestamp )
165- for query_executions in _get_last_query_executions (boto3_session = boto3_session , workgroup = workgroup ):
166- _logger .debug ("len(query_executions): %s" , len (query_executions ))
167- cached_queries : List [Dict [str , Any ]] = _sort_successful_executions_data (query_executions = query_executions )
168- _logger .debug ("len(cached_queries): %s" , len (cached_queries ))
169- for query_info in cached_queries :
170- query_execution_id : str = query_info ["QueryExecutionId" ]
171- query_timestamp : datetime .datetime = query_info ["Status" ]["CompletionDateTime" ]
172- _logger .debug ("query_timestamp: %s" , query_timestamp )
173-
174- if (current_timestamp - query_timestamp ).total_seconds () > max_cache_seconds :
175- return _CacheInfo (
176- has_valid_cache = False , query_execution_id = query_execution_id , query_execution_payload = query_info
177- )
178-
179- statement_type : Optional [str ] = query_info .get ("StatementType" )
180- if statement_type == "DDL" and query_info ["Query" ].startswith ("CREATE TABLE" ):
181- parsed_query : Optional [str ] = _parse_select_query_from_possible_ctas (possible_ctas = query_info ["Query" ])
182- if parsed_query is not None :
183- if _compare_query_string (sql = comparable_sql , other = parsed_query ):
184- return _CacheInfo (
185- has_valid_cache = True ,
186- file_format = "parquet" ,
187- query_execution_id = query_execution_id ,
188- query_execution_payload = query_info ,
189- )
190- elif statement_type == "DML" and not query_info ["Query" ].startswith ("INSERT" ):
191- if _compare_query_string (sql = comparable_sql , other = query_info ["Query" ]):
171+ for query_info in _get_last_query_infos (
172+ max_remote_cache_entries = max_remote_cache_entries ,
173+ boto3_session = boto3_session ,
174+ workgroup = workgroup ,
175+ ):
176+ query_execution_id : str = query_info ["QueryExecutionId" ]
177+ query_timestamp : datetime .datetime = query_info ["Status" ]["CompletionDateTime" ]
178+ _logger .debug ("query_timestamp: %s" , query_timestamp )
179+ if (current_timestamp - query_timestamp ).total_seconds () > max_cache_seconds :
180+ return _CacheInfo (
181+ has_valid_cache = False , query_execution_id = query_execution_id , query_execution_payload = query_info
182+ )
183+ statement_type : Optional [str ] = query_info .get ("StatementType" )
184+ if statement_type == "DDL" and query_info ["Query" ].startswith ("CREATE TABLE" ):
185+ parsed_query : Optional [str ] = _parse_select_query_from_possible_ctas (possible_ctas = query_info ["Query" ])
186+ if parsed_query is not None :
187+ if _compare_query_string (sql = comparable_sql , other = parsed_query ):
192188 return _CacheInfo (
193189 has_valid_cache = True ,
194- file_format = "csv " ,
190+ file_format = "parquet " ,
195191 query_execution_id = query_execution_id ,
196192 query_execution_payload = query_info ,
197193 )
198-
199- num_executions_inspected += 1
200- _logger .debug ("num_executions_inspected: %s" , num_executions_inspected )
201- if num_executions_inspected >= max_cache_query_inspections :
202- return _CacheInfo (has_valid_cache = False )
203-
194+ elif statement_type == "DML" and not query_info ["Query" ].startswith ("INSERT" ):
195+ if _compare_query_string (sql = comparable_sql , other = query_info ["Query" ]):
196+ return _CacheInfo (
197+ has_valid_cache = True ,
198+ file_format = "csv" ,
199+ query_execution_id = query_execution_id ,
200+ query_execution_payload = query_info ,
201+ )
202+ num_executions_inspected += 1
203+ _logger .debug ("num_executions_inspected: %s" , num_executions_inspected )
204+ if num_executions_inspected >= max_cache_query_inspections :
205+ return _CacheInfo (has_valid_cache = False )
204206 return _CacheInfo (has_valid_cache = False )
205207
206208
@@ -302,6 +304,7 @@ def _resolve_query_with_cache(
302304 boto3_session = session ,
303305 categories = categories ,
304306 query_execution_payload = cache_info .query_execution_payload ,
307+ metadata_cache_manager = _cache_manager ,
305308 )
306309 if cache_info .file_format == "parquet" :
307310 return _fetch_parquet_result (
@@ -380,6 +383,7 @@ def _resolve_query_without_cache_ctas(
380383 query_execution_id = query_id ,
381384 boto3_session = boto3_session ,
382385 categories = categories ,
386+ metadata_cache_manager = _cache_manager ,
383387 )
384388 except exceptions .QueryFailed as ex :
385389 msg : str = str (ex )
@@ -439,6 +443,7 @@ def _resolve_query_without_cache_regular(
439443 query_execution_id = query_id ,
440444 boto3_session = boto3_session ,
441445 categories = categories ,
446+ metadata_cache_manager = _cache_manager ,
442447 )
443448 return _fetch_csv_result (
444449 query_metadata = query_metadata ,
@@ -532,6 +537,8 @@ def read_sql_query(
532537 boto3_session : Optional [boto3 .Session ] = None ,
533538 max_cache_seconds : int = 0 ,
534539 max_cache_query_inspections : int = 50 ,
540+ max_remote_cache_entries : int = 50 ,
541+ max_local_cache_entries : int = 100 ,
535542 data_source : Optional [str ] = None ,
536543 params : Optional [Dict [str , Any ]] = None ,
537544) -> Union [pd .DataFrame , Iterator [pd .DataFrame ]]:
@@ -678,6 +685,15 @@ def read_sql_query(
678685 Max number of queries that will be inspected from the history to try to find some result to reuse.
679686 The bigger the number of inspection, the bigger will be the latency for not cached queries.
680687 Only takes effect if max_cache_seconds > 0.
688+ max_remote_cache_entries : int
689+ Max number of queries that will be retrieved from AWS for cache inspection.
690+ The more queries that are retrieved, the higher the latency for non-cached queries.
691+ Only takes effect if max_cache_seconds > 0. Defaults to 50.
692+ max_local_cache_entries : int
693+ Max number of queries for which metadata is cached locally. Local caching reduces latency
694+ and lets more than `max_remote_cache_entries` queries remain usable by the cache. This value should not be
695+ smaller than max_remote_cache_entries.
696+ Only takes effect if max_cache_seconds > 0. Defaults to 100.
681697 data_source : str, optional
682698 Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
683699 params: Dict[str, any], optional
@@ -718,12 +734,17 @@ def read_sql_query(
718734 for key , value in params .items ():
719735 sql = sql .replace (f":{ key } ;" , str (value ))
720736
737+ if max_remote_cache_entries > max_local_cache_entries :
738+ max_remote_cache_entries = max_local_cache_entries
739+
740+ _cache_manager .max_cache_size = max_local_cache_entries
721741 cache_info : _CacheInfo = _check_for_cached_results (
722742 sql = sql ,
723743 boto3_session = session ,
724744 workgroup = workgroup ,
725745 max_cache_seconds = max_cache_seconds ,
726746 max_cache_query_inspections = max_cache_query_inspections ,
747+ max_remote_cache_entries = max_remote_cache_entries ,
727748 )
728749 _logger .debug ("cache_info:\n %s" , cache_info )
729750 if cache_info .has_valid_cache is True :
@@ -774,6 +795,8 @@ def read_sql_table(
774795 boto3_session : Optional [boto3 .Session ] = None ,
775796 max_cache_seconds : int = 0 ,
776797 max_cache_query_inspections : int = 50 ,
798+ max_remote_cache_entries : int = 50 ,
799+ max_local_cache_entries : int = 100 ,
777800 data_source : Optional [str ] = None ,
778801) -> Union [pd .DataFrame , Iterator [pd .DataFrame ]]:
779802 """Extract the full table AWS Athena and return the results as a Pandas DataFrame.
@@ -914,6 +937,15 @@ def read_sql_table(
914937 Max number of queries that will be inspected from the history to try to find some result to reuse.
915938 The bigger the number of inspection, the bigger will be the latency for not cached queries.
916939 Only takes effect if max_cache_seconds > 0.
940+ max_remote_cache_entries : int
941+ Max number of queries that will be retrieved from AWS for cache inspection.
942+ The more queries that are retrieved, the higher the latency for non-cached queries.
943+ Only takes effect if max_cache_seconds > 0. Defaults to 50.
944+ max_local_cache_entries : int
945+ Max number of queries for which metadata is cached locally. Local caching reduces latency
946+ and lets more than `max_remote_cache_entries` queries remain usable by the cache. This value should not be
947+ smaller than max_remote_cache_entries.
948+ Only takes effect if max_cache_seconds > 0. Defaults to 100.
917949 data_source : str, optional
918950 Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
919951
@@ -947,4 +979,9 @@ def read_sql_table(
947979 boto3_session = boto3_session ,
948980 max_cache_seconds = max_cache_seconds ,
949981 max_cache_query_inspections = max_cache_query_inspections ,
982+ max_remote_cache_entries = max_remote_cache_entries ,
983+ max_local_cache_entries = max_local_cache_entries ,
950984 )
985+
986+
# Module-level singleton holding locally cached Athena query metadata; its
# maximum size is set per-call (via max_local_cache_entries) before cache checks.
_cache_manager = _LocalMetadataCacheManager()
0 commit comments