77import re
88import threading
99from heapq import heappop , heappush
10- from typing import TYPE_CHECKING , Any , Match , NamedTuple
10+ from typing import TYPE_CHECKING , Match , NamedTuple
1111
1212import boto3
1313
@@ -23,23 +23,23 @@ class _CacheInfo(NamedTuple):
2323 has_valid_cache : bool
2424 file_format : str | None = None
2525 query_execution_id : str | None = None
26- query_execution_payload : dict [ str , Any ] | None = None
26+ query_execution_payload : "QueryExecutionTypeDef" | None = None
2727
2828
2929class _LocalMetadataCacheManager :
3030 def __init__ (self ) -> None :
3131 self ._lock : threading .Lock = threading .Lock ()
32- self ._cache : dict [str , Any ] = {}
32+ self ._cache : dict [str , "QueryExecutionTypeDef" ] = {}
3333 self ._pqueue : list [tuple [datetime .datetime , str ]] = []
3434 self ._max_cache_size = 100
3535
36- def update_cache (self , items : list [dict [ str , Any ] ]) -> None :
36+ def update_cache (self , items : list ["QueryExecutionTypeDef" ]) -> None :
3737 """
3838 Update the local metadata cache with new query metadata.
3939
4040 Parameters
4141 ----------
42- items : List[Dict[str, Any]]
42+ items
4343 List of query execution metadata which is returned by boto3 `batch_get_query_execution()`.
4444 """
4545 with self ._lock :
@@ -62,18 +62,17 @@ def update_cache(self, items: list[dict[str, Any]]) -> None:
6262 heappush (self ._pqueue , (item ["Status" ]["SubmissionDateTime" ], item ["QueryExecutionId" ]))
6363 self ._cache [item ["QueryExecutionId" ]] = item
6464
65- def sorted_successful_generator (self ) -> list [dict [ str , Any ] ]:
65+ def sorted_successful_generator (self ) -> list ["QueryExecutionTypeDef" ]:
6666 """
6767 Sorts the entries in the local cache based on query Completion DateTime.
6868
6969 This is useful to guarantee LRU caching rules.
7070
7171 Returns
7272 -------
73- List[Dict[str, Any]]
7473 Returns successful DDL and DML queries sorted by query completion time.
7574 """
76- filtered : list [dict [ str , Any ] ] = []
75+ filtered : list ["QueryExecutionTypeDef" ] = []
7776 for query in self ._cache .values ():
7877 if (query ["Status" ].get ("State" ) == "SUCCEEDED" ) and (query .get ("StatementType" ) in ["DDL" , "DML" ]):
7978 filtered .append (query )
@@ -111,13 +110,13 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> str | None:
111110 return None
112111
113112
def _compare_query_string(
    sql: str, other: str, sql_params: list[str] | None = None, other_params: list[str] | None = None
) -> bool:
    """
    Check whether a previously executed query matches the query being looked up.

    Parameters
    ----------
    sql
        Query string to look up. It is compared as given; only ``other`` is
        passed through ``_prepare_query_string_for_comparison`` here.
    other
        Query string of the earlier execution.
    sql_params
        Execution parameters supplied with ``sql``, if any.
    other_params
        Execution parameters recorded for the earlier execution, if any.

    Returns
    -------
    bool
        ``True`` when both the query text and the execution parameters match.
    """
    comparison_query = _prepare_query_string_for_comparison(query_string=other)
    _logger.debug("sql: %s", sql)
    _logger.debug("comparison_query: %s", comparison_query)
    # An absent parameter list (None) and an empty one both mean "no parameters",
    # so normalize them before comparing; otherwise `[]` would never match `None`.
    return sql == comparison_query and (sql_params or []) == (other_params or [])
121120
122121
123122def _prepare_query_string_for_comparison (query_string : str ) -> str :
@@ -135,7 +134,7 @@ def _get_last_query_infos(
135134 max_remote_cache_entries : int ,
136135 boto3_session : boto3 .Session | None = None ,
137136 workgroup : str | None = None ,
138- ) -> list [dict [ str , Any ] ]:
137+ ) -> list ["QueryExecutionTypeDef" ]:
139138 """Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
140139 client_athena = _utils .client (service_name = "athena" , session = boto3_session )
141140 page_size = 50
@@ -160,14 +159,15 @@ def _get_last_query_infos(
160159 QueryExecutionIds = uncached_ids [i : i + page_size ],
161160 ).get ("QueryExecutions" )
162161 )
163- _cache_manager .update_cache (new_execution_data ) # type: ignore[arg-type]
162+ _cache_manager .update_cache (new_execution_data )
164163 return _cache_manager .sorted_successful_generator ()
165164
166165
167166def _check_for_cached_results (
168167 sql : str ,
169168 boto3_session : boto3 .Session | None ,
170169 workgroup : str | None ,
170+ params : list [str ] | None = None ,
171171 athena_cache_settings : typing .AthenaCacheSettings | None = None ,
172172) -> _CacheInfo :
173173 """
@@ -207,15 +207,25 @@ def _check_for_cached_results(
207207 if statement_type == "DDL" and query_info ["Query" ].startswith ("CREATE TABLE" ):
208208 parsed_query : str | None = _parse_select_query_from_possible_ctas (possible_ctas = query_info ["Query" ])
209209 if parsed_query is not None :
210- if _compare_query_string (sql = comparable_sql , other = parsed_query ):
210+ if _compare_query_string (
211+ sql = comparable_sql ,
212+ other = parsed_query ,
213+ sql_params = params ,
214+ other_params = query_info .get ("ExecutionParameters" ),
215+ ):
211216 return _CacheInfo (
212217 has_valid_cache = True ,
213218 file_format = "parquet" ,
214219 query_execution_id = query_execution_id ,
215220 query_execution_payload = query_info ,
216221 )
217222 elif statement_type == "DML" and not query_info ["Query" ].startswith ("INSERT" ):
218- if _compare_query_string (sql = comparable_sql , other = query_info ["Query" ]):
223+ if _compare_query_string (
224+ sql = comparable_sql ,
225+ other = query_info ["Query" ],
226+ sql_params = params ,
227+ other_params = query_info .get ("ExecutionParameters" ),
228+ ):
219229 return _CacheInfo (
220230 has_valid_cache = True ,
221231 file_format = "csv" ,
0 commit comments