1313"""The step definitions for workflow."""
1414from __future__ import absolute_import
1515
16- from typing import List , Union , Optional
16+ from typing import Any , Dict , List , Union , Optional
1717
1818from sagemaker .workflow .entities import (
1919 RequestType ,
@@ -33,7 +33,8 @@ def __init__(
3333 ):
3434 """Create a definition for input data used by an EMR cluster(job flow) step.
3535
36- See AWS documentation on the ``StepConfig`` API for more details on the parameters.
36+ See AWS documentation for more information about the `StepConfig
37+ <https://docs.aws.amazon.com/emr/latest/APIReference/API_StepConfig.html>`_ API parameters.
3738
3839 Args:
3940 args(List[str]):
@@ -61,9 +62,89 @@ def to_request(self) -> RequestType:
6162 return config
6263
6364
# Keys of the cluster configuration that the EMRStep validation inspects.
INSTANCES = "Instances"
INSTANCEGROUPS = "InstanceGroups"
INSTANCEFLEETS = "InstanceFleets"

# Error message templates. Each one is rendered with ``.format(step_name=...)``.
ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS = (
    "In EMRStep {step_name}, cluster_config "
    "should not contain any of the Name, "
    "AutoTerminationPolicy and/or Steps."
)

ERR_STR_WITHOUT_INSTANCE = "In EMRStep {step_name}, cluster_config must contain " + INSTANCES + "."

ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED = (
    "In EMRStep {step_name}, " + INSTANCES + " should not contain "
    "KeepJobFlowAliveWhenNoSteps or "
    "TerminationProtected."
)

ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS = (
    "In EMRStep {step_name}, "
    + INSTANCES
    + " should contain either "
    + INSTANCEGROUPS
    + " or "
    + INSTANCEFLEETS
    + "."
)

# Adjacent string literals are concatenated verbatim, so every fragment that is
# followed by another word must end with an explicit space.
ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG = (
    "EMRStep {step_name} can not have both cluster_id "
    "and cluster_config. "
    "To use EMRStep with "
    "cluster_config, cluster_id "
    "must be explicitly set to None."
)

ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG = (
    "EMRStep {step_name} must have either cluster_id or cluster_config"
)
103+
104+
64105class EMRStep (Step ):
65106 """EMR step for workflow."""
66107
108+ def _validate_cluster_config (self , cluster_config , step_name ):
109+ """Validates user provided cluster_config.
110+
111+ Args:
112+ cluster_config(Union[Dict[str, Any], List[Dict[str, Any]]]):
113+ user provided cluster configuration.
114+ step_name: The name of the EMR step.
115+ """
116+
117+ if (
118+ "Name" in cluster_config
119+ or "AutoTerminationPolicy" in cluster_config
120+ or "Steps" in cluster_config
121+ ):
122+ raise ValueError (
123+ ERR_STR_WITH_NAME_AUTO_TERMINATION_OR_STEPS .format (step_name = step_name )
124+ )
125+
126+ if INSTANCES not in cluster_config :
127+ raise ValueError (ERR_STR_WITHOUT_INSTANCE .format (step_name = step_name ))
128+
129+ if (
130+ "KeepJobFlowAliveWhenNoSteps" in cluster_config [INSTANCES ]
131+ or "TerminationProtected" in cluster_config [INSTANCES ]
132+ ):
133+ raise ValueError (
134+ ERR_STR_WITH_KEEPJOBFLOW_OR_TERMINATIONPROTECTED .format (step_name = step_name )
135+ )
136+
137+ if (
138+ INSTANCEGROUPS in cluster_config [INSTANCES ]
139+ and INSTANCEFLEETS in cluster_config [INSTANCES ]
140+ ) or (
141+ INSTANCEGROUPS not in cluster_config [INSTANCES ]
142+ and INSTANCEFLEETS not in cluster_config [INSTANCES ]
143+ ):
144+ raise ValueError (
145+ ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS .format (step_name = step_name )
146+ )
147+
67148 def __init__ (
68149 self ,
69150 name : str ,
@@ -73,8 +154,9 @@ def __init__(
73154 step_config : EMRStepConfig ,
74155 depends_on : Optional [List [Union [str , Step , StepCollection ]]] = None ,
75156 cache_config : CacheConfig = None ,
157+ cluster_config : Dict [str , Any ] = None ,
76158 ):
77- """Constructs a EMRStep.
159+ """Constructs an ` EMRStep` .
78160
79161 Args:
80162 name(str): The name of the EMR step.
@@ -86,16 +168,46 @@ def __init__(
86168 names or `Step` instances or `StepCollection` instances that this `EMRStep`
87169 depends on.
88170 cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
171+ cluster_config(Dict[str, Any]): The recipe of the
172+ EMR cluster, passed as a dictionary.
173+ The elements are defined in the request syntax for `RunJobFlow`.
174+ However, the following elements are not recognized as part of the cluster
175+ configuration and you should not include them in the dictionary:
176+
177+ * ``cluster_config[Name]``
178+ * ``cluster_config[Steps]``
179+ * ``cluster_config[AutoTerminationPolicy]``
180+ * ``cluster_config[Instances][KeepJobFlowAliveWhenNoSteps]``
181+ * ``cluster_config[Instances][TerminationProtected]``
182+
183+ For more information about the fields you can include in your cluster
184+ configuration, see
185+ https://docs.aws.amazon.com/emr/latest/APIReference/API_RunJobFlow.html.
186+ Note that if you want to use ``cluster_config``, then you have to set
187+ ``cluster_id`` as None.
89188
90189 """
91190 super (EMRStep , self ).__init__ (name , display_name , description , StepTypeEnum .EMR , depends_on )
92191
93- emr_step_args = {"ClusterId" : cluster_id , "StepConfig" : step_config .to_request ()}
192+ emr_step_args = {"StepConfig" : step_config .to_request ()}
193+ root_property = Properties (step_name = name , shape_name = "Step" , service_name = "emr" )
194+
195+ if cluster_id is None and cluster_config is None :
196+ raise ValueError (ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG .format (step_name = name ))
197+
198+ if cluster_id is not None and cluster_config is not None :
199+ raise ValueError (ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG .format (step_name = name ))
200+
201+ if cluster_id is not None :
202+ emr_step_args ["ClusterId" ] = cluster_id
203+ root_property .__dict__ ["ClusterId" ] = cluster_id
204+ elif cluster_config is not None :
205+ self ._validate_cluster_config (cluster_config , name )
206+ emr_step_args ["ClusterConfig" ] = cluster_config
207+ root_property .__dict__ ["ClusterConfig" ] = cluster_config
208+
94209 self .args = emr_step_args
95210 self .cache_config = cache_config
96-
97- root_property = Properties (step_name = name , shape_name = "Step" , service_name = "emr" )
98- root_property .__dict__ ["ClusterId" ] = cluster_id
99211 self ._properties = root_property
100212
101213 @property
0 commit comments