
Commit 2e2b000

Add Spark.flatten() and mypy initial development
1 parent 49b7880 commit 2e2b000

4 files changed: +158 −34 lines changed

awswrangler/data_types.py

Lines changed: 32 additions & 30 deletions

@@ -1,14 +1,16 @@
+from typing import List, Tuple, Dict, Callable
import logging
from datetime import datetime, date

-import pyarrow
+import pyarrow as pa  # type: ignore
+import pandas as pd  # type: ignore

-from awswrangler.exceptions import UnsupportedType, UndetectedType
+from awswrangler.exceptions import UnsupportedType, UndetectedType  # type: ignore

logger = logging.getLogger(__name__)


-def athena2pandas(dtype):
+def athena2pandas(dtype: str) -> str:
    dtype = dtype.lower()
    if dtype in ["int", "integer", "bigint", "smallint", "tinyint"]:
        return "Int64"
@@ -28,7 +30,7 @@ def athena2pandas(dtype):
        raise UnsupportedType(f"Unsupported Athena type: {dtype}")


-def athena2pyarrow(dtype):
+def athena2pyarrow(dtype: str) -> str:
    dtype = dtype.lower()
    if dtype == "tinyint":
        return "int8"
@@ -54,7 +56,7 @@ def athena2pyarrow(dtype):
        raise UnsupportedType(f"Unsupported Athena type: {dtype}")


-def athena2python(dtype):
+def athena2python(dtype: str) -> type:
    dtype = dtype.lower()
    if dtype in ["int", "integer", "bigint", "smallint", "tinyint"]:
        return int
@@ -72,7 +74,7 @@ def athena2python(dtype):
        raise UnsupportedType(f"Unsupported Athena type: {dtype}")


-def athena2redshift(dtype):
+def athena2redshift(dtype: str) -> str:
    dtype = dtype.lower()
    if dtype == "smallint":
        return "SMALLINT"
@@ -96,7 +98,7 @@ def athena2redshift(dtype):
        raise UnsupportedType(f"Unsupported Athena type: {dtype}")


-def pandas2athena(dtype):
+def pandas2athena(dtype: str) -> str:
    dtype = dtype.lower()
    if dtype == "int32":
        return "int"
@@ -116,7 +118,7 @@ def pandas2athena(dtype):
        raise UnsupportedType(f"Unsupported Pandas type: {dtype}")


-def pandas2redshift(dtype):
+def pandas2redshift(dtype: str) -> str:
    dtype = dtype.lower()
    if dtype == "int32":
        return "INTEGER"
@@ -136,7 +138,7 @@ def pandas2redshift(dtype):
        raise UnsupportedType("Unsupported Pandas type: " + dtype)


-def pyarrow2athena(dtype):
+def pyarrow2athena(dtype: pa.types) -> str:
    dtype_str = str(dtype).lower()
    if dtype_str == "int8":
        return "tinyint"
@@ -167,7 +169,7 @@ def pyarrow2athena(dtype):
        raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")


-def pyarrow2redshift(dtype):
+def pyarrow2redshift(dtype: pa.types) -> str:
    dtype_str = str(dtype).lower()
    if dtype_str == "int16":
        return "SMALLINT"
@@ -191,25 +193,25 @@ def pyarrow2redshift(dtype):
        raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")


-def python2athena(python_type):
-    python_type = str(python_type)
-    if python_type == "<class 'int'>":
+def python2athena(python_type: type) -> str:
+    python_type_str: str = str(python_type)
+    if python_type_str == "<class 'int'>":
        return "bigint"
-    elif python_type == "<class 'float'>":
+    elif python_type_str == "<class 'float'>":
        return "double"
-    elif python_type == "<class 'boll'>":
+    elif python_type_str == "<class 'boll'>":
        return "boolean"
-    elif python_type == "<class 'str'>":
+    elif python_type_str == "<class 'str'>":
        return "string"
-    elif python_type == "<class 'datetime.datetime'>":
+    elif python_type_str == "<class 'datetime.datetime'>":
        return "timestamp"
-    elif python_type == "<class 'datetime.date'>":
+    elif python_type_str == "<class 'datetime.date'>":
        return "date"
    else:
-        raise UnsupportedType(f"Unsupported Python type: {python_type}")
+        raise UnsupportedType(f"Unsupported Python type: {python_type_str}")


-def redshift2athena(dtype):
+def redshift2athena(dtype: str) -> str:
    dtype_str = str(dtype)
    if dtype_str in ["SMALLINT", "INT2"]:
        return "smallint"
@@ -233,8 +235,8 @@ def redshift2athena(dtype):
        raise UnsupportedType(f"Unsupported Redshift type: {dtype_str}")


-def redshift2pyarrow(dtype):
-    dtype_str = str(dtype)
+def redshift2pyarrow(dtype: str) -> str:
+    dtype_str: str = str(dtype)
    if dtype_str in ["SMALLINT", "INT2"]:
        return "int16"
    elif dtype_str in ["INTEGER", "INT", "INT4"]:
@@ -257,7 +259,7 @@ def redshift2pyarrow(dtype):
        raise UnsupportedType(f"Unsupported Redshift type: {dtype_str}")


-def spark2redshift(dtype):
+def spark2redshift(dtype: str) -> str:
    dtype = dtype.lower()
    if dtype == "smallint":
        return "SMALLINT"
@@ -281,7 +283,7 @@ def spark2redshift(dtype):
        raise UnsupportedType("Unsupported Spark type: " + dtype)


-def convert_schema(func, schema):
+def convert_schema(func: Callable, schema: List[Tuple[str, str]]) -> Dict[str, str]:
    """
    Convert schema in the format of {"col name": "bigint", "col2 name": "int"}
    applying some data types conversion function (e.g. spark2redshift)
@@ -293,16 +295,16 @@ def convert_schema(func, schema):
    return {name: func(dtype) for name, dtype in schema}


-def extract_pyarrow_schema_from_pandas(dataframe,
-                                       preserve_index,
-                                       indexes_position="right"):
+def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
+                                       preserve_index: bool,
+                                       indexes_position: str = "right") -> List[Tuple[str, str]]:
    """
    Extract the related Pyarrow schema from any Pandas DataFrame

    :param dataframe: Pandas Dataframe
    :param preserve_index: True or False
    :param indexes_position: "right" or "left"
-    :return: Pyarrow schema (e.g. {"col name": "bigint", "col2 name": "int"})
+    :return: Pyarrow schema (e.g. [("col name": "bigint"), ("col2 name": "int")]
    """
    cols = []
    cols_dtypes = {}
@@ -319,8 +321,8 @@ def extract_pyarrow_schema_from_pandas(dataframe,

    # Filling cols_dtypes and indexes
    indexes = []
-    for field in pyarrow.Schema.from_pandas(df=dataframe[cols],
-                                            preserve_index=preserve_index):
+    for field in pa.Schema.from_pandas(df=dataframe[cols],
+                                       preserve_index=preserve_index):
        name = str(field.name)
        dtype = field.type
        cols_dtypes[name] = dtype
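
The annotated helpers above compose through convert_schema(), which applies a conversion function to every (column name, type) pair and returns a plain dict. A minimal sketch of assumed usage, importing the functions from awswrangler.data_types (the sample column name is hypothetical):

    # Minimal sketch of assumed usage; the sample schema is hypothetical.
    from awswrangler.data_types import athena2pandas, convert_schema, spark2redshift

    # convert_schema() maps a conversion function over (column, type) tuples.
    spark_schema = [("customer_id", "smallint")]
    print(convert_schema(func=spark2redshift, schema=spark_schema))  # {'customer_id': 'SMALLINT'}

    # With the new annotations (e.g. athena2pandas(dtype: str) -> str),
    # mypy can flag a call like athena2pandas(123) before runtime.
    print(athena2pandas("bigint"))  # "Int64"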

awswrangler/s3.py

Lines changed: 3 additions & 3 deletions

@@ -2,9 +2,9 @@
from math import ceil
import logging

-from botocore.exceptions import ClientError
-import s3fs
-import tenacity
+from botocore.exceptions import ClientError  # type: ignore
+import s3fs  # type: ignore
+import tenacity  # type: ignore

from awswrangler.utils import calculate_bounders, wait_process_release

awswrangler/spark.py

Lines changed: 121 additions & 0 deletions

@@ -1,9 +1,11 @@
+from typing import List, Tuple, Dict
import logging

import pandas as pd

from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id
from pyspark.sql.types import TimestampType
+from pyspark.sql import DataFrame

from awswrangler.exceptions import MissingBatchDetected, UnsupportedFileFormat

@@ -210,3 +212,122 @@ def create_glue_table(self,
            extra_args=extra_args)
        if load_partitions:
            self._session.athena.repair_table(database=database, table=table)
+
+    @staticmethod
+    def _is_struct(dtype: str) -> bool:
+        return True if dtype.startswith("struct") else False
+
+    @staticmethod
+    def _is_array(dtype: str) -> bool:
+        return True if dtype.startswith("array") else False
+
+    @staticmethod
+    def _is_map(dtype: str) -> bool:
+        return True if dtype.startswith("map") else False
+
+    @staticmethod
+    def _is_array_or_map(dtype: str) -> bool:
+        return True if (dtype.startswith("array") or dtype.startswith("map")) else False
+
+    @staticmethod
+    def _parse_aux(path: str, aux: str) -> Tuple[str, str]:
+        path_child: str
+        dtype: str
+        if ":" in aux:
+            path_child, dtype = aux.split(sep=":", maxsplit=1)
+        else:
+            path_child = "element"
+            dtype = aux
+        return f"{path}.{path_child}", dtype
+
+    @staticmethod
+    def _flatten_struct_column(path: str, dtype: str) -> List[Tuple[str, str]]:
+        dtype: str = dtype[7:-1]  # Cutting off "struct<" and ">"
+        cols: List[Tuple[str, str]] = []
+        struct_acc: int = 0
+        path_child: str
+        dtype_child: str
+        aux: str = ""
+        for c, i in zip(dtype, range(len(dtype), 0, -1)):  # Zipping a descendant ID for each letter
+            if ((c == ",") and (struct_acc == 0)) or (i == 1):
+                if i == 1:
+                    aux += c
+                path_child, dtype_child = Spark._parse_aux(path=path, aux=aux)
+                if Spark._is_struct(dtype=dtype_child):
+                    cols += Spark._flatten_struct_column(path=path_child, dtype=dtype_child)  # Recursion
+                elif Spark._is_array(dtype=dtype):
+                    cols.append((path, "array"))
+                else:
+                    cols.append((path_child, dtype_child))
+                aux = ""
+            elif c == "<":
+                aux += c
+                struct_acc += 1
+            elif c == ">":
+                aux += c
+                struct_acc -= 1
+            else:
+                aux += c
+        return cols
+
+    @staticmethod
+    def _flatten_struct_dataframe(
+            df: DataFrame,
+            explode_outer: bool = True,
+            explode_pos: bool = True) -> List[Tuple[str, str, str]]:
+        explode: str = "EXPLODE_OUTER" if explode_outer is True else "EXPLODE"
+        explode = f"POS{explode}" if explode_pos is True else explode
+        cols: List[Tuple[str, str]] = []
+        for path, dtype in df.dtypes:
+            if Spark._is_struct(dtype=dtype):
+                cols += Spark._flatten_struct_column(path=path, dtype=dtype)
+            elif Spark._is_array(dtype=dtype):
+                cols.append((path, "array"))
+            elif Spark._is_map(dtype=dtype):
+                cols.append((path, "map"))
+            else:
+                cols.append((path, dtype))
+        cols_exprs: List[Tuple[str, str, str]] = []
+        expr: str
+        for path, dtype in cols:
+            path_under = path.replace('.', '_')
+            if Spark._is_array(dtype):
+                if explode_pos:
+                    expr = f"{explode}({path}) AS ({path_under}_pos, {path_under})"
+                else:
+                    expr = f"{explode}({path}) AS {path_under}"
+            elif Spark._is_map(dtype):
+                if explode_pos:
+                    expr = f"{explode}({path}) AS ({path_under}_pos, {path_under}_key, {path_under}_value)"
+                else:
+                    expr = f"{explode}({path}) AS ({path_under}_key, {path_under}_value)"
+            else:
+                expr = f"{path} AS {path.replace('.', '_')}"
+            cols_exprs.append((path, dtype, expr))
+        return cols_exprs
+
+    @staticmethod
+    def _build_name(name: str, expr: str) -> str:
+        suffix: str = expr[expr.find("(") + 1: expr.find(")")]
+        return f"{name}_{suffix}".replace(".", "_")
+
+    @staticmethod
+    def flatten(
+            df: DataFrame,
+            explode_outer: bool = True,
+            explode_pos: bool = True,
+            name: str = "root") -> Dict[str, DataFrame]:
+        cols_exprs: List[Tuple[str, str, str]] = Spark._flatten_struct_dataframe(
+            df=df,
+            explode_outer=explode_outer,
+            explode_pos=explode_pos)
+        exprs_arr: List[str] = [x[2] for x in cols_exprs if Spark._is_array_or_map(x[1])]
+        exprs: List[str] = [x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1])]
+        dfs: Dict[str, DataFrame] = {name: df.selectExpr(exprs)}
+        exprs: List[str] = [x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1]) and not x[0].endswith("_pos")]
+        for expr in exprs_arr:
+            df_arr = df.selectExpr(exprs + [expr])
+            name_new: str = Spark._build_name(name=name, expr=expr)
+            dfs_new = Spark.flatten(df=df_arr, explode_outer=explode_outer, explode_pos=explode_pos, name=name_new)
+            dfs = {**dfs, **dfs_new}
+        return dfs
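
The new Spark.flatten() entry point recursively unnests a DataFrame: struct fields are promoted to top-level columns, and each array/map column is exploded (POSEXPLODE_OUTER by default) into its own child DataFrame keyed by a generated name. A minimal sketch of how it might be called, assuming Spark is importable from awswrangler.spark and a SparkSession is available (the nested sample data and names are hypothetical):

    # Minimal sketch of assumed usage; the sample data and names are hypothetical.
    from pyspark.sql import Row, SparkSession
    from awswrangler.spark import Spark

    spark = SparkSession.builder.appName("flatten-sketch").getOrCreate()

    # One struct column plus one array column.
    df = spark.createDataFrame(
        [(1, Row(city="Bonn", zip="53111"), [10, 20])],
        "order_id int, address struct<city:string,zip:string>, items array<int>")

    # Returns a dict of DataFrames: "root" holds order_id plus the flattened
    # struct fields (address_city, address_zip), while "root_items" holds the
    # exploded array with its position column (items_pos, items).
    dfs = Spark.flatten(df=df, explode_outer=True, explode_pos=True, name="root")
    for key, flat_df in dfs.items():
        print(key)
        flat_df.printSchema()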

requirements-dev.txt

Lines changed: 2 additions & 1 deletion

@@ -6,4 +6,5 @@ twine~=1.13.0
pyspark~=2.4.3
wheel~=0.33.6
sphinx~=2.1.2
-pyspark-stubs~=2.4.0
+pyspark-stubs~=2.4.0
+mypy~=0.730
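
With mypy pinned as a development dependency, the fresh annotations in data_types.py and spark.py can be checked locally. A small sketch using mypy's Python API; the target path and flag are assumptions, not the project's actual configuration:

    # Minimal sketch: invoke mypy programmatically (path and flag are assumed).
    from mypy import api

    stdout, stderr, exit_status = api.run(["--ignore-missing-imports", "awswrangler/"])
    print(stdout)
    print(f"mypy exited with status {exit_status}")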
