Skip to content

Commit 98a5e0d

Browse files
committed
Add security group parameters for EMR
1 parent b32e37e commit 98a5e0d

File tree

5 files changed

+38
-13
lines changed

5 files changed

+38
-13
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ session.spark.create_glue_table(dataframe=dataframe,
205205

206206
```py3
207207
session = awswrangler.Session(spark_session=spark)
208-
dfs = session.spark.flatten(df=df_nested)
208+
dfs = session.spark.flatten(dataframe=df_nested)
209209
for name, df_flat in dfs:
210210
print(name)
211211
df_flat.show()

awswrangler/emr.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,18 @@ def _build_cluster_args(**pars):
4040
if pars["key_pair_name"] is not None:
4141
args["Instances"]["Ec2KeyName"] = pars["key_pair_name"]
4242

43+
# Security groups
44+
if pars["security_group_master"] is not None:
45+
args["Instances"]["EmrManagedMasterSecurityGroup"] = pars["security_group_master"]
46+
if pars["security_groups_master_additional"] is not None:
47+
args["Instances"]["AdditionalMasterSecurityGroups"] = pars["security_groups_master_additional"]
48+
if pars["security_group_slave"] is not None:
49+
args["Instances"]["EmrManagedSlaveSecurityGroup"] = pars["security_group_slave"]
50+
if pars["security_groups_slave_additional"] is not None:
51+
args["Instances"]["AdditionalSlaveSecurityGroups"] = pars["security_groups_slave_additional"]
52+
if pars["security_group_service_access"] is not None:
53+
args["Instances"]["ServiceAccessSecurityGroup"] = pars["security_group_service_access"]
54+
4355
# Configurations
4456
if pars["python3"] or pars["spark_glue_catalog"] or pars["hive_glue_catalog"] or pars["presto_glue_catalog"]:
4557
args["Configurations"]: List = []
@@ -265,7 +277,12 @@ def create_cluster(self,
265277
debugging: bool = True,
266278
applications: Optional[List[str]] = None,
267279
visible_to_all_users: bool = True,
268-
key_pair_name: Optional[str] = None):
280+
key_pair_name: Optional[str] = None,
281+
security_group_master: Optional[str] = None,
282+
security_groups_master_additional: Optional[List[str]] = None,
283+
security_group_slave: Optional[str] = None,
284+
security_groups_slave_additional: Optional[List[str]] = None,
285+
security_group_service_access: Optional[str] = None):
269286
"""
270287
Create an EMR cluster with instance fleets configuration
271288
https://docs.aws.amazon.com/emr/latest/ManagementGuide/emr-instance-fleet.html
@@ -305,6 +322,11 @@ def create_cluster(self,
305322
:param applications: List of applications (e.g ["Hadoop", "Spark", "Ganglia", "Hive"])
306323
:param visible_to_all_users: True or False
307324
:param key_pair_name: Key pair name (string)
325+
:param security_group_master: The identifier of the Amazon EC2 security group for the master node.
326+
:param security_groups_master_additional: A list of additional Amazon EC2 security group IDs for the master node.
327+
:param security_group_slave: The identifier of the Amazon EC2 security group for the core and task nodes.
328+
:param security_groups_slave_additional: A list of additional Amazon EC2 security group IDs for the core and task nodes.
329+
:param security_group_service_access: The identifier of the Amazon EC2 security group for the Amazon EMR service to access clusters in VPC private subnets.
308330
:return: Cluster ID (string)
309331
"""
310332
args = EMR._build_cluster_args(**locals())

awswrangler/spark.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -294,28 +294,31 @@ def _build_name(name: str, expr: str) -> str:
294294
return f"{name}_{suffix}".replace(".", "_")
295295

296296
@staticmethod
297-
def flatten(df: sql.DataFrame, explode_outer: bool = True, explode_pos: bool = True,
297+
def flatten(dataframe: sql.DataFrame, explode_outer: bool = True, explode_pos: bool = True,
298298
name: str = "root") -> Dict[str, sql.DataFrame]:
299299
"""
300300
Convert a complex nested DataFrame into one (or many) flat DataFrames
301301
If a column is a struct, it is flattened directly.
302302
If a column is an array or map, then child DataFrames are created in different granularities.
303-
:param df: Spark DataFrame
303+
:param dataframe: Spark DataFrame
304304
:param explode_outer: Should we preserve the null values on arrays?
305305
:param explode_pos: Create columns with the index of the ex-array
306306
:param name: The name of the root Dataframe
307307
:return: A dictionary with the names as Keys and the DataFrames as Values
308308
"""
309-
cols_exprs: List[Tuple[str, str, str]] = Spark._flatten_struct_dataframe(df=df,
309+
cols_exprs: List[Tuple[str, str, str]] = Spark._flatten_struct_dataframe(df=dataframe,
310310
explode_outer=explode_outer,
311311
explode_pos=explode_pos)
312312
exprs_arr: List[str] = [x[2] for x in cols_exprs if Spark._is_array_or_map(x[1])]
313313
exprs: List[str] = [x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1])]
314-
dfs: Dict[str, sql.DataFrame] = {name: df.selectExpr(exprs)}
314+
dfs: Dict[str, sql.DataFrame] = {name: dataframe.selectExpr(exprs)}
315315
exprs = [x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1]) and not x[0].endswith("_pos")]
316316
for expr in exprs_arr:
317-
df_arr = df.selectExpr(exprs + [expr])
317+
df_arr = dataframe.selectExpr(exprs + [expr])
318318
name_new: str = Spark._build_name(name=name, expr=expr)
319-
dfs_new = Spark.flatten(df=df_arr, explode_outer=explode_outer, explode_pos=explode_pos, name=name_new)
319+
dfs_new = Spark.flatten(dataframe=df_arr,
320+
explode_outer=explode_outer,
321+
explode_pos=explode_pos,
322+
name=name_new)
320323
dfs = {**dfs, **dfs_new}
321324
return dfs

docs/source/examples.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ Flatten nested PySpark DataFrame
163163
.. code-block:: python
164164
165165
session = awswrangler.Session(spark_session=spark)
166-
dfs = session.spark.flatten(df=df_nested)
166+
dfs = session.spark.flatten(dataframe=df_nested)
167167
for name, df_flat in dfs:
168168
print(name)
169169
df_flat.show()

testing/test_awswrangler/test_spark.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def test_flatten_simple_struct(session):
190190
])
191191
df = session.spark_session.createDataFrame(data=pdf, schema=schema)
192192
df.printSchema()
193-
dfs = session.spark.flatten(df=df)
193+
dfs = session.spark.flatten(dataframe=df)
194194
assert len(dfs) == 1
195195
dfs["root"].printSchema()
196196
dtypes = str(dfs["root"].dtypes)
@@ -261,7 +261,7 @@ def test_flatten_complex_struct(session):
261261
])
262262
df = session.spark_session.createDataFrame(data=pdf, schema=schema)
263263
df.printSchema()
264-
dfs = session.spark.flatten(df=df)
264+
dfs = session.spark.flatten(dataframe=df)
265265
assert len(dfs) == 1
266266
dfs["root"].printSchema()
267267
dtypes = str(dfs["root"].dtypes)
@@ -294,7 +294,7 @@ def test_flatten_simple_map(session):
294294
])
295295
df = session.spark_session.createDataFrame(data=pdf, schema=schema)
296296
df.printSchema()
297-
dfs = session.spark.flatten(df=df)
297+
dfs = session.spark.flatten(dataframe=df)
298298
assert len(dfs) == 2
299299

300300
# root
@@ -329,7 +329,7 @@ def test_flatten_simple_array(session):
329329
])
330330
df = session.spark_session.createDataFrame(data=pdf, schema=schema)
331331
df.printSchema()
332-
dfs = session.spark.flatten(df=df)
332+
dfs = session.spark.flatten(dataframe=df)
333333
assert len(dfs) == 2
334334

335335
# root

0 commit comments

Comments
 (0)