@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Dict, Optional, Any, Iterator, List, Union
+from typing import TYPE_CHECKING, Dict, Optional, Any, Iterator, List, Union, Tuple
 from math import ceil
 from itertools import islice
 import re
@@ -55,16 +55,16 @@ def get_table_python_types(self, database: str, table: str) -> Dict[str, Optiona
     def metadata_to_glue(self,
                          dataframe,
                          path: str,
-                         objects_paths,
-                         file_format,
-                         database=None,
-                         table=None,
-                         partition_cols=None,
-                         preserve_index=True,
+                         objects_paths: List[str],
+                         file_format: str,
+                         database: str,
+                         table: Optional[str],
+                         partition_cols: Optional[List[str]] = None,
+                         preserve_index: bool = True,
                          mode: str = "append",
-                         compression=None,
-                         cast_columns=None,
-                         extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None,
+                         compression: Optional[str] = None,
+                         cast_columns: Optional[Dict[str, str]] = None,
+                         extra_args: Optional[Dict[str, Optional[Union[str, int, List[str]]]]] = None,
                          description: Optional[str] = None,
                          parameters: Optional[Dict[str, str]] = None,
                          columns_comments: Optional[Dict[str, str]] = None) -> None:
@@ -88,6 +88,8 @@ def metadata_to_glue(self,
         :return: None
         """
         indexes_position = "left" if file_format == "csv" else "right"
+        schema: List[Tuple[str, str]]
+        partition_cols_schema: List[Tuple[str, str]]
         schema, partition_cols_schema = Glue._build_schema(dataframe=dataframe,
                                                            partition_cols=partition_cols,
                                                            preserve_index=preserve_index,
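An aside on the two bare annotations added in the hunk above: Python does not allow annotating the targets of a tuple-unpacking assignment inline, so each name must be declared on its own line first. A minimal self-contained sketch of the pattern (the function and data are illustrative, not from this commit):

    from typing import List, Tuple

    def split_schema(pairs: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
        # Toy stand-in for Glue._build_schema: first entry vs. the rest.
        return pairs[:1], pairs[1:]

    # "a, b: T = ..." is a SyntaxError on an unpacking assignment,
    # hence the separate declarations:
    schema: List[Tuple[str, str]]
    partition_cols_schema: List[Tuple[str, str]]
    schema, partition_cols_schema = split_schema([("id", "bigint"), ("year", "int")])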
@@ -138,14 +140,14 @@ def does_table_exists(self, database, table):
             return False

     def create_table(self,
-                     database,
-                     table,
-                     schema,
-                     path,
-                     file_format,
-                     compression,
-                     partition_cols_schema=None,
-                     extra_args=None,
+                     database: str,
+                     table: str,
+                     schema: List[Tuple[str, str]],
+                     path: str,
+                     file_format: str,
+                     compression: Optional[str],
+                     partition_cols_schema: List[Tuple[str, str]],
+                     extra_args: Optional[Dict[str, Union[str, int, List[str], None]]] = None,
                      description: Optional[str] = None,
                      parameters: Optional[Dict[str, str]] = None,
                      columns_comments: Optional[Dict[str, str]] = None) -> None:
@@ -166,13 +168,17 @@ def create_table(self,
         :return: None
         """
         if file_format == "parquet":
-            table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path, compression)
+            table_input: Dict[str, Any] = Glue.parquet_table_definition(table=table,
+                                                                        partition_cols_schema=partition_cols_schema,
+                                                                        schema=schema,
+                                                                        path=path,
+                                                                        compression=compression)
         elif file_format == "csv":
-            table_input = Glue.csv_table_definition(table,
-                                                    partition_cols_schema,
-                                                    schema,
-                                                    path,
-                                                    compression,
+            table_input = Glue.csv_table_definition(table=table,
+                                                    partition_cols_schema=partition_cols_schema,
+                                                    schema=schema,
+                                                    path=path,
+                                                    compression=compression,
                                                     extra_args=extra_args)
         else:
             raise UnsupportedFileFormat(file_format)
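Only the parquet branch above annotates table_input; that is sufficient, since mypy carries a variable's declared type over to the assignments in the other branches. A runnable sketch of the same dispatch shape (the dict contents are placeholders; only the UnsupportedFileFormat name comes from this module):

    from typing import Any, Dict

    class UnsupportedFileFormat(Exception):
        """Placeholder mirroring the module's exception."""

    def table_definition(file_format: str) -> Dict[str, Any]:
        if file_format == "parquet":
            # The single annotation here also types the "csv" branch below.
            table_input: Dict[str, Any] = {"serde": "parquet"}
        elif file_format == "csv":
            table_input = {"serde": "csv"}
        else:
            raise UnsupportedFileFormat(file_format)
        return table_input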
@@ -223,19 +229,23 @@ def get_connection_details(self, name):
         return self._client_glue.get_connection(Name=name, HidePassword=False)["Connection"]

     @staticmethod
-    def _build_schema(dataframe, partition_cols, preserve_index, indexes_position, cast_columns=None):
+    def _build_schema(
+            dataframe,
+            partition_cols: Optional[List[str]],
+            preserve_index: bool,
+            indexes_position: str,
+            cast_columns: Optional[Dict[str, str]] = None) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
         if cast_columns is None:
             cast_columns = {}
         logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
-        if not partition_cols:
+        if partition_cols is None:
             partition_cols = []

-        pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
-                                                                       preserve_index=preserve_index,
-                                                                       indexes_position=indexes_position)
+        pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas(
+            dataframe=dataframe, preserve_index=preserve_index, indexes_position=indexes_position)

-        schema_built = []
-        partition_cols_types = {}
+        schema_built: List[Tuple[str, str]] = []
+        partition_cols_types: Dict[str, str] = {}
         for name, dtype in pyarrow_schema:
             if (cast_columns is not None) and (name in cast_columns.keys()):
                 if name in partition_cols:
@@ -256,7 +266,7 @@ def _build_schema(dataframe, partition_cols, preserve_index, indexes_position, c
             else:
                 schema_built.append((name, athena_type))

-        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+        partition_cols_schema_built: List = [(name, partition_cols_types[name]) for name in partition_cols]

         logger.debug(f"schema_built:\n{schema_built}")
         logger.debug(f"partition_cols_schema_built:\n{partition_cols_schema_built}")
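Worth noting in the _build_schema hunk above: once partition_cols is typed Optional[List[str]], `if partition_cols is None` states the intent precisely, whereas the old truthiness test `if not partition_cols` also matched a caller's explicit empty list. The outcome is the same for an empty list, but the identity check is the idiomatic guard for an Optional default, e.g.:

    from typing import List, Optional

    def normalize(partition_cols: Optional[List[str]]) -> List[str]:
        # Replace only the None default; an explicit [] passes through.
        if partition_cols is None:
            partition_cols = []
        return partition_cols

    assert normalize(None) == []
    assert normalize(["year", "month"]) == ["year", "month"]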
@@ -269,12 +279,12 @@ def parse_table_name(path):
         return path.rpartition("/")[2]

     @staticmethod
-    def csv_table_definition(table,
-                             partition_cols_schema,
-                             schema,
-                             path,
-                             compression,
-                             extra_args: Optional[Dict[str, Optional[Union[str, int]]]] = None):
+    def csv_table_definition(table: str,
+                             partition_cols_schema: List[Tuple[str, str]],
+                             schema: List[Tuple[str, str]],
+                             path: str,
+                             compression: Optional[str],
+                             extra_args: Optional[Dict[str, Optional[Union[str, int, List[str]]]]] = None):
         if extra_args is None:
             extra_args = {"sep": ","}
         if partition_cols_schema is None:
@@ -301,6 +311,9 @@ def csv_table_definition(table,
             refined_schema = [(name, dtype) if dtype in dtypes_allowed else (name, "string") for name, dtype in schema]
         else:
             raise InvalidSerDe(f"{serde} is not in the valid SerDe list.")
+        if "columns" in extra_args:
+            refined_schema = [(name, dtype) for name, dtype in refined_schema
+                              if name in extra_args["columns"]]  # type: ignore
         return {
             "Name": table,
             "PartitionKeys": [{
@@ -378,7 +391,8 @@ def csv_partition_definition(partition, compression, extra_args=None):
         }

     @staticmethod
-    def parquet_table_definition(table, partition_cols_schema, schema, path, compression):
+    def parquet_table_definition(table: str, partition_cols_schema: List[Tuple[str, str]],
+                                 schema: List[Tuple[str, str]], path: str, compression: Optional[str]):
         if not partition_cols_schema:
             partition_cols_schema = []
         compressed = False if compression is None else True
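Finally, with compression now typed Optional[str] in parquet_table_definition, the ternary on the last line is just a None check; an equivalent, slightly more direct spelling (a sketch, not a change this commit makes):

    from typing import Optional

    def is_compressed(compression: Optional[str]) -> bool:
        # Same truth table as: False if compression is None else True
        return compression is not None

    assert is_compressed(None) is False
    assert is_compressed("snappy") is True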