 from typing import List, Tuple, Dict
 import logging
 
-import pandas as pd
+import pandas as pd  # type: ignore
 
-from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id
-from pyspark.sql.types import TimestampType
-from pyspark.sql import DataFrame
+from pyspark import sql
 
 from awswrangler.exceptions import MissingBatchDetected, UnsupportedFileFormat
 
@@ -38,7 +36,7 @@ def date2timestamp(dataframe):
         for name, dtype in dataframe.dtypes:
             if dtype == "date":
                 dataframe = dataframe.withColumn(
-                    name, dataframe[name].cast(TimestampType()))
+                    name, dataframe[name].cast(sql.types.TimestampType()))
                 logger.warning(
                     f"Casting column {name} from date to timestamp!")
         return dataframe
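For reference, a minimal sketch of the cast performed above, using the `sql.types` namespace this change adopts (the sample DataFrame is illustrative only, not taken from this module):

import datetime

from pyspark import sql
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(datetime.date(2019, 1, 1),)], ["col0"])
print(df.dtypes)  # [('col0', 'date')]

# Cast the date column to timestamp, as date2timestamp() does for every
# column whose dtype reports as "date".
df = df.withColumn("col0", df["col0"].cast(sql.types.TimestampType()))
print(df.dtypes)  # [('col0', 'timestamp')]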
@@ -98,8 +96,9 @@ def to_redshift(
         spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         session_primitives = self._session.primitives
 
-        @pandas_udf(returnType="objects_paths string",
-                    functionType=PandasUDFType.GROUPED_MAP)
+        @sql.functions.pandas_udf(
+            returnType="objects_paths string",
+            functionType=sql.functions.PandasUDFType.GROUPED_MAP)
         def write(pandas_dataframe):
             del pandas_dataframe["aws_data_wrangler_internal_partition_id"]
             paths = session_primitives.session.pandas.to_parquet(
@@ -112,7 +111,7 @@ def write(pandas_dataframe):
             return pd.DataFrame.from_dict({"objects_paths": paths})
 
         df_objects_paths = dataframe.repartition(numPartitions=num_partitions) \
-            .withColumn("aws_data_wrangler_internal_partition_id", spark_partition_id()) \
+            .withColumn("aws_data_wrangler_internal_partition_id", sql.functions.spark_partition_id()) \
             .groupby("aws_data_wrangler_internal_partition_id") \
             .apply(write)
 
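The decorated write() above follows PySpark's grouped-map pandas UDF pattern: repartition, tag each partition, group by the tag, and apply a pandas function per group. A minimal, self-contained sketch of that pattern (PySpark 2.x API; the session, schema, and column names are illustrative assumptions, and the same objects are reached here by direct import rather than through the sql.functions namespace used in the diff):

import pandas as pd  # type: ignore
from pyspark.sql import SparkSession
from pyspark.sql.functions import PandasUDFType, pandas_udf

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "a"), (1, "b"), (2, "c")], ["part_id", "value"])


# Grouped-map pandas UDF: each group arrives as one pandas DataFrame, and the
# returned pandas DataFrame must match the declared return schema.
@pandas_udf(returnType="part_id long, n_rows long",
            functionType=PandasUDFType.GROUPED_MAP)
def summarize(pdf):
    return pd.DataFrame({"part_id": [int(pdf["part_id"].iloc[0])],
                         "n_rows": [len(pdf)]})


df.groupby("part_id").apply(summarize).show()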
@@ -227,7 +226,8 @@ def _is_map(dtype: str) -> bool:
 
     @staticmethod
     def _is_array_or_map(dtype: str) -> bool:
-        return True if (dtype.startswith("array") or dtype.startswith("map")) else False
+        return True if (dtype.startswith("array")
+                        or dtype.startswith("map")) else False
 
     @staticmethod
     def _parse_aux(path: str, aux: str) -> Tuple[str, str]:
@@ -242,19 +242,22 @@ def _parse_aux(path: str, aux: str) -> Tuple[str, str]:
 
     @staticmethod
     def _flatten_struct_column(path: str, dtype: str) -> List[Tuple[str, str]]:
-        dtype: str = dtype[7:-1]  # Cutting off "struct<" and ">"
+        dtype = dtype[7:-1]  # Cutting off "struct<" and ">"
         cols: List[Tuple[str, str]] = []
         struct_acc: int = 0
         path_child: str
         dtype_child: str
         aux: str = ""
-        for c, i in zip(dtype, range(len(dtype), 0, -1)):  # Zipping a descendant ID for each letter
+        for c, i in zip(dtype,
+                        range(len(dtype), 0,
+                              -1)):  # Zipping a descendant ID for each letter
             if ((c == ",") and (struct_acc == 0)) or (i == 1):
                 if i == 1:
                     aux += c
                 path_child, dtype_child = Spark._parse_aux(path=path, aux=aux)
                 if Spark._is_struct(dtype=dtype_child):
-                    cols += Spark._flatten_struct_column(path=path_child, dtype=dtype_child)  # Recursion
+                    cols += Spark._flatten_struct_column(
+                        path=path_child, dtype=dtype_child)  # Recursion
                 elif Spark._is_array(dtype=dtype):
                     cols.append((path, "array"))
                 else:
@@ -271,10 +274,10 @@ def _flatten_struct_column(path: str, dtype: str) -> List[Tuple[str, str]]:
         return cols
 
     @staticmethod
-    def _flatten_struct_dataframe(
-            df: DataFrame,
-            explode_outer: bool = True,
-            explode_pos: bool = True) -> List[Tuple[str, str, str]]:
+    def _flatten_struct_dataframe(df: sql.DataFrame,
+                                  explode_outer: bool = True,
+                                  explode_pos: bool = True
+                                  ) -> List[Tuple[str, str, str]]:
         explode: str = "EXPLODE_OUTER" if explode_outer is True else "EXPLODE"
         explode = f"POS{explode}" if explode_pos is True else explode
         cols: List[Tuple[str, str]] = []
@@ -308,26 +311,34 @@ def _flatten_struct_dataframe(
 
     @staticmethod
     def _build_name(name: str, expr: str) -> str:
-        suffix: str = expr[expr.find("(") + 1: expr.find(")")]
+        suffix: str = expr[expr.find("(") + 1:expr.find(")")]
         return f"{name}_{suffix}".replace(".", "_")
 
     @staticmethod
-    def flatten(
-            df: DataFrame,
-            explode_outer: bool = True,
-            explode_pos: bool = True,
-            name: str = "root") -> Dict[str, DataFrame]:
-        cols_exprs: List[Tuple[str, str, str]] = Spark._flatten_struct_dataframe(
-            df=df,
-            explode_outer=explode_outer,
-            explode_pos=explode_pos)
-        exprs_arr: List[str] = [x[2] for x in cols_exprs if Spark._is_array_or_map(x[1])]
-        exprs: List[str] = [x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1])]
-        dfs: Dict[str, DataFrame] = {name: df.selectExpr(exprs)}
-        exprs: List[str] = [x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1]) and not x[0].endswith("_pos")]
+    def flatten(df: sql.DataFrame,
+                explode_outer: bool = True,
+                explode_pos: bool = True,
+                name: str = "root") -> Dict[str, sql.DataFrame]:
+        cols_exprs: List[
+            Tuple[str, str, str]] = Spark._flatten_struct_dataframe(
+                df=df, explode_outer=explode_outer, explode_pos=explode_pos)
+        exprs_arr: List[str] = [
+            x[2] for x in cols_exprs if Spark._is_array_or_map(x[1])
+        ]
+        exprs: List[str] = [
+            x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1])
+        ]
+        dfs: Dict[str, sql.DataFrame] = {name: df.selectExpr(exprs)}
+        exprs = [
+            x[2] for x in cols_exprs
+            if not Spark._is_array_or_map(x[1]) and not x[0].endswith("_pos")
+        ]
         for expr in exprs_arr:
             df_arr = df.selectExpr(exprs + [expr])
             name_new: str = Spark._build_name(name=name, expr=expr)
-            dfs_new = Spark.flatten(df=df_arr, explode_outer=explode_outer, explode_pos=explode_pos, name=name_new)
+            dfs_new = Spark.flatten(df=df_arr,
+                                    explode_outer=explode_outer,
+                                    explode_pos=explode_pos,
+                                    name=name_new)
             dfs = {**dfs, **dfs_new}
         return dfs
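Since flatten() is the public entry point reshaped here, a hedged usage sketch may help; the nested sample data and the awswrangler.spark import path are assumptions for illustration, not taken from this diff:

from pyspark.sql import Row, SparkSession

from awswrangler.spark import Spark  # import path assumed

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [Row(id=1, items=[Row(a=1), Row(a=2)]), Row(id=2, items=[Row(a=3)])])

# flatten() returns a dict of DataFrames: the root frame holding the scalar
# columns, plus one child frame per exploded array/map column.
dfs = Spark.flatten(df=df, explode_outer=True, explode_pos=True, name="root")
for child_name, child_df in dfs.items():
    print(child_name)
    child_df.show()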