
Commit 7d68fc4

Add first spark flatten test (Simple struct)
1 parent 9b8d5f1 commit 7d68fc4

File tree

4 files changed: +51 -3 lines changed


awswrangler/spark.py

Lines changed: 10 additions & 0 deletions
@@ -319,6 +319,16 @@ def flatten(df: sql.DataFrame,
                 explode_outer: bool = True,
                 explode_pos: bool = True,
                 name: str = "root") -> Dict[str, sql.DataFrame]:
+        """
+        Convert a complex nested DataFrame into one (or many) flat DataFrames.
+        If a column is a struct, it is flattened directly.
+        If a column is an array or map, child DataFrames are created at different granularities.
+        :param df: Spark DataFrame
+        :param explode_outer: Should we preserve the null values on arrays?
+        :param explode_pos: Create columns with the index of the exploded array
+        :param name: The name of the root DataFrame
+        :return: A dictionary with the DataFrame names as keys and the DataFrames as values
+        """
         cols_exprs: List[
             Tuple[str, str, str]] = Spark._flatten_struct_dataframe(
                 df=df, explode_outer=explode_outer, explode_pos=explode_pos)
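For context on the two flags documented above, here is a minimal sketch (not part of this commit) of the underlying pyspark behavior they refer to; the SparkSession setup and the id/tags column names are illustrative only:

from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, posexplode_outer

spark = SparkSession.builder.master("local[1]").getOrCreate()
# Row id=2 carries a NULL array.
df = spark.createDataFrame([(1, ["x", "y"]), (2, None)], ["id", "tags"])

# Plain explode() silently drops row id=2.
df.select("id", explode("tags")).show()

# posexplode_outer() keeps the NULL row and also emits the element index --
# roughly the behavior selected by explode_outer=True and explode_pos=True
# in flatten().
df.select("id", posexplode_outer("tags")).show()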

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ flake8~=3.7.8
 pytest-cov~=2.8.1
 cfn-lint~=0.23.3
 twine~=1.13.0
-pyspark~=2.4.4
 wheel~=0.33.6
 sphinx~=2.1.2
+pyspark~=2.4.4
 pyspark-stubs~=2.4.0

testing/run-tests.sh

Lines changed: 0 additions & 2 deletions
@@ -3,11 +3,9 @@
 set -e

 cd ..
-rm -rf .pytest_cache .mypy_cache
 pip install -e .
 yapf --in-place --recursive setup.py awswrangler testing/test_awswrangler
 mypy awswrangler
 flake8 setup.py awswrangler testing/test_awswrangler
 pytest --cov=awswrangler testing/test_awswrangler
-rm -rf .pytest_cache .mypy_cache
 cd testing

testing/test_awswrangler/test_spark.py

Lines changed: 40 additions & 0 deletions
@@ -2,8 +2,10 @@

 import pytest
 import boto3
+import pandas as pd
 from pyspark.sql import SparkSession
 from pyspark.sql.functions import lit, array, create_map, struct
+from pyspark.sql.types import StructType, StructField, IntegerType

 from awswrangler import Session

@@ -164,3 +166,41 @@ def test_create_glue_table_csv(session, bucket, database, compression,
     assert int(pandas_df.iloc[0]["id"]) == 4
     assert pandas_df.iloc[0]["name"] == "four"
     assert float(pandas_df.iloc[0]["value"]) == 4.0
+
+
+def test_flatten_simple_struct(session):
+    print()
+    pdf = pd.DataFrame({
+        "a": [1, 2],
+        "b": [
+            {
+                "bb1": 1,
+                "bb2": 2
+            },
+            {
+                "bb1": 1,
+                "bb2": 2
+            },
+        ],
+    })
+    schema = StructType([
+        StructField(name="a", dataType=IntegerType(), nullable=True),
+        StructField(name="b",
+                    dataType=StructType([
+                        StructField(name="bb1",
+                                    dataType=IntegerType(),
+                                    nullable=True),
+                        StructField(name="bb2",
+                                    dataType=IntegerType(),
+                                    nullable=True),
+                    ]),
+                    nullable=True),
+    ])
+    df = session.spark_session.createDataFrame(data=pdf, schema=schema)
+    df.printSchema()
+    dfs = session.spark.flatten(df=df)
+    assert len(dfs) == 1
+    dfs["root"].printSchema()
+    assert str(dfs["root"].dtypes
+               ) == "[('a', 'int'), ('b_bb1', 'int'), ('b_bb2', 'int')]"
+    assert df.count() == dfs["root"].count()
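To close the loop, a quick usage sketch of how the dictionary returned by flatten() would be consumed outside the test harness. This is hypothetical wiring, not part of the commit: it assumes Session accepts a spark_session keyword (as the test suite's session fixture suggests) and uses a made-up struct DataFrame:

from pyspark.sql import SparkSession
from awswrangler import Session

spark = SparkSession.builder.master("local[1]").getOrCreate()
session = Session(spark_session=spark)  # assumption: kwarg as wired by the fixtures

df = spark.createDataFrame([(1, (1, 2))], "a int, b struct<bb1:int,bb2:int>")
dfs = session.spark.flatten(df=df, name="root")
for key, frame in dfs.items():  # a pure-struct input yields a single entry
    print(key)                  # "root"
    frame.printSchema()         # columns: a, b_bb1, b_bb2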
