55import time
66import warnings
77from decimal import Decimal
8- from typing import Any , Dict , List , NamedTuple , Optional , Union
8+ from typing import Any , Dict , Generator , List , NamedTuple , Optional , Union , cast
99
1010import boto3 # type: ignore
1111import pandas as pd # type: ignore
@@ -89,7 +89,7 @@ def _start_query_execution(
8989 client_athena : boto3 .client = _utils .client (service_name = "athena" , session = session )
9090 _logger .debug ("args: \n %s" , pprint .pformat (args ))
9191 response : Dict [str , Any ] = client_athena .start_query_execution (** args )
92- return response ["QueryExecutionId" ]
92+ return cast ( str , response ["QueryExecutionId" ])
9393
9494
9595def _get_workgroup_config (session : boto3 .Session , workgroup : Optional [str ] = None ) -> _WorkGroupConfig :
@@ -137,7 +137,7 @@ def _fetch_txt_result(query_metadata: _QueryMetadata, keep_files: bool, boto3_se
137137
138138def _parse_describe_table (df : pd .DataFrame ) -> pd .DataFrame :
139139 origin_df_dict = df .to_dict ()
140- target_df_dict : Dict [str , List ] = {"Column Name" : [], "Type" : [], "Partition" : [], "Comment" : []}
140+ target_df_dict : Dict [str , List [ Union [ str , bool ]] ] = {"Column Name" : [], "Type" : [], "Partition" : [], "Comment" : []}
141141 for index , col_name in origin_df_dict ["col_name" ].items ():
142142 col_name = col_name .strip ()
143143 if col_name .startswith ("#" ) or col_name == "" :
@@ -156,7 +156,7 @@ def _parse_describe_table(df: pd.DataFrame) -> pd.DataFrame:
156156def _get_query_metadata ( # pylint: disable=too-many-statements
157157 query_execution_id : str ,
158158 boto3_session : boto3 .Session ,
159- categories : List [str ] = None ,
159+ categories : Optional [ List [str ] ] = None ,
160160 query_execution_payload : Optional [Dict [str , Any ]] = None ,
161161) -> _QueryMetadata :
162162 """Get query metadata."""
@@ -226,7 +226,9 @@ def _get_query_metadata( # pylint: disable=too-many-statements
226226 return query_metadata
227227
228228
229- def _empty_dataframe_response (chunked : bool , query_metadata : _QueryMetadata ):
229+ def _empty_dataframe_response (
230+ chunked : bool , query_metadata : _QueryMetadata
231+ ) -> Union [pd .DataFrame , Generator [None , None , None ]]:
230232 """Generate an empty dataframe response."""
231233 if chunked is False :
232234 df = pd .DataFrame ()
@@ -425,7 +427,7 @@ def repair_table(
425427 boto3_session = session ,
426428 )
427429 response : Dict [str , Any ] = wait_query (query_execution_id = query_id , boto3_session = session )
428- return response ["Status" ]["State" ]
430+ return cast ( str , response ["Status" ]["State" ])
429431
430432
431433@apply_configs
@@ -556,7 +558,7 @@ def show_create_table(
556558 )
557559 query_metadata : _QueryMetadata = _get_query_metadata (query_execution_id = query_id , boto3_session = session )
558560 raw_result = _fetch_txt_result (query_metadata = query_metadata , keep_files = True , boto3_session = session ,)
559- return raw_result .createtab_stmt .str .strip ().str .cat (sep = " " )
561+ return cast ( str , raw_result .createtab_stmt .str .strip ().str .cat (sep = " " ) )
560562
561563
562564def get_work_group (workgroup : str , boto3_session : Optional [boto3 .Session ] = None ) -> Dict [str , Any ]:
@@ -581,7 +583,7 @@ def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None
581583
582584 """
583585 client_athena : boto3 .client = _utils .client (service_name = "athena" , session = boto3_session )
584- return client_athena .get_work_group (WorkGroup = workgroup )
586+ return cast ( Dict [ str , Any ], client_athena .get_work_group (WorkGroup = workgroup ) )
585587
586588
587589def stop_query_execution (query_execution_id : str , boto3_session : Optional [boto3 .Session ] = None ) -> None :
@@ -645,7 +647,7 @@ def wait_query(query_execution_id: str, boto3_session: Optional[boto3.Session] =
645647 raise exceptions .QueryFailed (response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
646648 if state == "CANCELLED" :
647649 raise exceptions .QueryCancelled (response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
648- return response ["QueryExecution" ]
650+ return cast ( Dict [ str , Any ], response ["QueryExecution" ])
649651
650652
651653def get_query_execution (query_execution_id : str , boto3_session : Optional [boto3 .Session ] = None ) -> Dict [str , Any ]:
@@ -673,4 +675,4 @@ def get_query_execution(query_execution_id: str, boto3_session: Optional[boto3.S
673675 """
674676 client_athena : boto3 .client = _utils .client (service_name = "athena" , session = boto3_session )
675677 response : Dict [str , Any ] = client_athena .get_query_execution (QueryExecutionId = query_execution_id )
676- return response ["QueryExecution" ]
678+ return cast ( Dict [ str , Any ], response ["QueryExecution" ])
0 commit comments