diff --git a/src/sagemaker/workflow/emr_step.py b/src/sagemaker/workflow/emr_step.py index 293c45bc6c..b03fe6b96f 100644 --- a/src/sagemaker/workflow/emr_step.py +++ b/src/sagemaker/workflow/emr_step.py @@ -21,8 +21,9 @@ from sagemaker.workflow.properties import ( Properties, ) +from sagemaker.workflow.retry import StepRetryPolicy from sagemaker.workflow.step_collections import StepCollection -from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig +from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum, CacheConfig class EMRStepConfig: @@ -110,8 +111,8 @@ def to_request(self) -> RequestType: ) -class EMRStep(Step): - """EMR step for workflow.""" +class EMRStep(ConfigurableRetryStep): + """EMR step for workflow with configurable retry policies.""" def _validate_cluster_config(self, cluster_config, step_name): """Validates user provided cluster_config. @@ -164,6 +165,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, cluster_config: Optional[Dict[str, Any]] = None, execution_role_arn: Optional[str] = None, + retry_policies: Optional[List[StepRetryPolicy]] = None, ): """Constructs an `EMRStep`. @@ -200,7 +202,14 @@ def __init__( called on the cluster specified by ``cluster_id``, so you can only include this field if ``cluster_id`` is not None. """ - super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on) + super().__init__( + name=name, + step_type=StepTypeEnum.EMR, + display_name=display_name, + description=description, + depends_on=depends_on, + retry_policies=retry_policies, + ) emr_step_args = {"StepConfig": step_config.to_request()} root_property = Properties(step_name=name, step=self, shape_name="Step", service_name="emr") @@ -248,7 +257,7 @@ def properties(self) -> RequestType: return self._properties def to_request(self) -> RequestType: - """Updates the dictionary with cache configuration.""" + """Updates the dictionary with cache configuration and retry policies""" request_dict = super().to_request() if self.cache_config: request_dict.update(self.cache_config.config) diff --git a/tests/integ/sagemaker/workflow/test_emr_steps.py b/tests/integ/sagemaker/workflow/test_emr_steps.py index b757742ddc..d5c8928229 100644 --- a/tests/integ/sagemaker/workflow/test_emr_steps.py +++ b/tests/integ/sagemaker/workflow/test_emr_steps.py @@ -20,6 +20,7 @@ from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig from sagemaker.workflow.parameters import ParameterInteger from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.retry import StepRetryPolicy, StepExceptionTypeEnum @pytest.fixture @@ -134,3 +135,215 @@ def test_emr_with_cluster_config(sagemaker_session, role, pipeline_name, region_ pipeline.delete() except Exception: pass + + +def test_emr_with_retry_policies(sagemaker_session, role, pipeline_name, region_name): + """Test EMR steps with retry policies in both cluster_id and cluster_config scenarios.""" + emr_step_config = EMRStepConfig( + jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar", + args=["dummy_emr_script_path"], + ) + + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], + interval_seconds=1, + max_attempts=3, + backoff_rate=2.0, + ) + ] + + # Step with existing cluster and retry policies + step_emr_1 = EMRStep( + name="emr-step-1", + cluster_id="j-1YONHTCP3YZKC", + display_name="emr_step_1", + description="EMR Step with retry policies", + step_config=emr_step_config, + retry_policies=retry_policies, + ) + + # Step with cluster config and retry policies + cluster_config = { + "Instances": { + "InstanceGroups": [ + { + "Name": "Master Instance Group", + "InstanceRole": "MASTER", + "InstanceCount": 1, + "InstanceType": "m1.small", + "Market": "ON_DEMAND", + } + ], + "InstanceCount": 1, + "HadoopVersion": "MyHadoopVersion", + }, + "AmiVersion": "3.8.0", + "AdditionalInfo": "MyAdditionalInfo", + } + + step_emr_2 = EMRStep( + name="emr-step-2", + display_name="emr_step_2", + description="EMR Step with cluster config and retry policies", + cluster_id=None, + step_config=emr_step_config, + cluster_config=cluster_config, + retry_policies=retry_policies, + ) + + pipeline = Pipeline( + name=pipeline_name, + steps=[step_emr_1, step_emr_2], + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + assert re.match( + rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + create_arn, + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_emr_with_expire_after_retry_policy(sagemaker_session, role, pipeline_name, region_name): + """Test EMR step with retry policy using expire_after_mins.""" + emr_step_config = EMRStepConfig( + jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar", + args=["dummy_emr_script_path"], + ) + + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], + interval_seconds=1, + expire_after_mins=30, + backoff_rate=2.0, + ) + ] + + step_emr = EMRStep( + name="emr-step-expire", + cluster_id="j-1YONHTCP3YZKC", + display_name="emr_step_expire", + description="EMR Step with expire after retry policy", + step_config=emr_step_config, + retry_policies=retry_policies, + ) + + pipeline = Pipeline( + name=pipeline_name, + steps=[step_emr], + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + assert re.match( + rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + create_arn, + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_emr_with_multiple_exception_types(sagemaker_session, role, pipeline_name, region_name): + """Test EMR step with multiple exception types in retry policy.""" + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT, StepExceptionTypeEnum.THROTTLING], + interval_seconds=1, + max_attempts=3, + backoff_rate=2.0, + ) + ] + + step_emr = EMRStep( + name="emr-step-multi-except", + cluster_id="j-1YONHTCP3YZKC", + display_name="emr_step_multi_except", + description="EMR Step with multiple exception types", + step_config=EMRStepConfig( + jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar", + args=["dummy_emr_script_path"], + ), + retry_policies=retry_policies, + ) + + pipeline = Pipeline( + name=pipeline_name, + steps=[step_emr], + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + assert re.match( + rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + create_arn, + ) + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_emr_with_multiple_retry_policies(sagemaker_session, role, pipeline_name, region_name): + """Test EMR step with multiple retry policies.""" + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], + interval_seconds=1, + max_attempts=3, + backoff_rate=2.0, + ), + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.THROTTLING], + interval_seconds=5, + expire_after_mins=60, + backoff_rate=1.5, + ), + ] + + step_emr = EMRStep( + name="emr-step-multi-policy", + cluster_id="j-1YONHTCP3YZKC", + display_name="emr_step_multi_policy", + description="EMR Step with multiple retry policies", + step_config=EMRStepConfig( + jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar", + args=["dummy_emr_script_path"], + ), + retry_policies=retry_policies, + ) + + pipeline = Pipeline( + name=pipeline_name, + steps=[step_emr], + sagemaker_session=sagemaker_session, + ) + + try: + response = pipeline.create(role) + create_arn = response["PipelineArn"] + assert re.match( + rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", + create_arn, + ) + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/unit/sagemaker/workflow/test_emr_step.py b/tests/unit/sagemaker/workflow/test_emr_step.py index 9c78b7675e..cc732cee51 100644 --- a/tests/unit/sagemaker/workflow/test_emr_step.py +++ b/tests/unit/sagemaker/workflow/test_emr_step.py @@ -29,6 +29,7 @@ from sagemaker.workflow.steps import CacheConfig from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.retry import StepRetryPolicy, StepExceptionTypeEnum from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered @@ -476,3 +477,325 @@ def test_emr_step_throws_exception_when_cluster_config_contains_restricted_entit actual_error_msg = exceptionInfo.value.args[0] assert actual_error_msg == expected_error_msg + + +def test_emr_step_with_retry_policies(sagemaker_session): + """Test EMRStep with retry policies.""" + emr_step_config = EMRStepConfig( + jar="s3:/script-runner/script-runner.jar", + args=["--arg_0", "arg_0_value"], + main_class="com.my.main", + properties=[{"Key": "Foo", "Value": "Foo_value"}, {"Key": "Bar", "Value": "Bar_value"}], + ) + + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], + interval_seconds=1, + max_attempts=3, + backoff_rate=2.0, + ), + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.THROTTLING], + interval_seconds=5, + max_attempts=5, + backoff_rate=1.5, + ), + ] + + emr_step = EMRStep( + name="MyEMRStep", + display_name="MyEMRStep", + description="MyEMRStepDescription", + cluster_id="MyClusterID", + step_config=emr_step_config, + depends_on=["TestStep"], + cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"), + retry_policies=retry_policies, + ) + + expected_request = { + "Name": "MyEMRStep", + "Type": "EMR", + "Arguments": { + "ClusterId": "MyClusterID", + "StepConfig": { + "HadoopJarStep": { + "Args": ["--arg_0", "arg_0_value"], + "Jar": "s3:/script-runner/script-runner.jar", + "MainClass": "com.my.main", + "Properties": [ + {"Key": "Foo", "Value": "Foo_value"}, + {"Key": "Bar", "Value": "Bar_value"}, + ], + } + }, + }, + "DependsOn": ["TestStep"], + "DisplayName": "MyEMRStep", + "Description": "MyEMRStepDescription", + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, + "RetryPolicies": [ + { + "ExceptionType": ["Step.SERVICE_FAULT"], + "IntervalSeconds": 1, + "MaxAttempts": 3, + "BackoffRate": 2.0, + }, + { + "ExceptionType": ["Step.THROTTLING"], + "IntervalSeconds": 5, + "MaxAttempts": 5, + "BackoffRate": 1.5, + }, + ], + } + + assert emr_step.to_request() == expected_request + + +def test_emr_step_with_retry_policies_and_cluster_config(): + """Test EMRStep with both retry policies and cluster configuration.""" + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], + interval_seconds=1, + max_attempts=3, + backoff_rate=2.0, + ) + ] + + emr_step = EMRStep( + name=g_emr_step_name, + display_name="MyEMRStep", + description="MyEMRStepDescription", + cluster_id=None, + cluster_config=g_cluster_config, + step_config=g_emr_step_config, + cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"), + retry_policies=retry_policies, + ) + + expected_request = { + "Name": "MyEMRStep", + "Type": "EMR", + "Arguments": { + "StepConfig": {"HadoopJarStep": {"Jar": "s3:/script-runner/script-runner.jar"}}, + "ClusterConfig": { + "AdditionalInfo": "MyAdditionalInfo", + "AmiVersion": "3.8.0", + "Instances": { + "HadoopVersion": "MyHadoopVersion", + "InstanceCount": 1, + "InstanceGroups": [ + { + "InstanceCount": 1, + "InstanceRole": "MASTER", + "InstanceType": "m1.small", + "Market": "ON_DEMAND", + "Name": "Master Instance Group", + } + ], + }, + }, + }, + "DisplayName": "MyEMRStep", + "Description": "MyEMRStepDescription", + "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, + "RetryPolicies": [ + { + "ExceptionType": ["Step.SERVICE_FAULT"], + "IntervalSeconds": 1, + "MaxAttempts": 3, + "BackoffRate": 2.0, + } + ], + } + + assert emr_step.to_request() == expected_request + + +def test_emr_step_with_retry_policy_expire_after(): + """Test EMRStep with retry policy using expire_after_mins.""" + emr_step_config = EMRStepConfig( + jar="s3:/script-runner/script-runner.jar", + args=["--arg_0", "arg_0_value"], + ) + + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], + interval_seconds=1, + expire_after_mins=30, + backoff_rate=2.0, + ) + ] + + emr_step = EMRStep( + name="MyEMRStep", + display_name="MyEMRStep", + description="MyEMRStepDescription", + cluster_id="MyClusterID", + step_config=emr_step_config, + retry_policies=retry_policies, + ) + + expected_request = { + "Name": "MyEMRStep", + "Type": "EMR", + "Arguments": { + "ClusterId": "MyClusterID", + "StepConfig": { + "HadoopJarStep": { + "Args": ["--arg_0", "arg_0_value"], + "Jar": "s3:/script-runner/script-runner.jar", + } + }, + }, + "DisplayName": "MyEMRStep", + "Description": "MyEMRStepDescription", + "RetryPolicies": [ + { + "ExceptionType": ["Step.SERVICE_FAULT"], + "IntervalSeconds": 1, + "ExpireAfterMin": 30, + "BackoffRate": 2.0, + } + ], + } + + assert emr_step.to_request() == expected_request + + +def test_emr_step_with_all_exception_types(): + """Test EMRStep with all available exception types.""" + emr_step_config = EMRStepConfig(jar="s3:/script-runner/script-runner.jar") + + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT, StepExceptionTypeEnum.THROTTLING], + interval_seconds=1, + max_attempts=3, + backoff_rate=2.0, + ) + ] + + emr_step = EMRStep( + name="MyEMRStep", + display_name="MyEMRStep", + description="MyEMRStepDescription", + cluster_id="MyClusterID", + step_config=emr_step_config, + retry_policies=retry_policies, + ) + + expected_request = { + "Name": "MyEMRStep", + "Type": "EMR", + "Arguments": { + "ClusterId": "MyClusterID", + "StepConfig": { + "HadoopJarStep": { + "Jar": "s3:/script-runner/script-runner.jar", + } + }, + }, + "DisplayName": "MyEMRStep", + "Description": "MyEMRStepDescription", + "RetryPolicies": [ + { + "ExceptionType": ["Step.SERVICE_FAULT", "Step.THROTTLING"], + "IntervalSeconds": 1, + "MaxAttempts": 3, + "BackoffRate": 2.0, + } + ], + } + + assert emr_step.to_request() == expected_request + + +def test_pipeline_interpolates_emr_outputs_with_retry_policies(sagemaker_session): + """Test pipeline definition with EMR steps that have retry policies.""" + custom_step = CustomStep("TestStep") + parameter = ParameterString("MyStr") + + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], + interval_seconds=1, + max_attempts=3, + backoff_rate=2.0, + ) + ] + + step_emr = EMRStep( + name="emr_step_1", + cluster_id="MyClusterID", + display_name="emr_step_1", + description="MyEMRStepDescription", + depends_on=[custom_step], + step_config=EMRStepConfig(jar="s3:/script-runner/script-runner.jar"), + retry_policies=retry_policies, + ) + + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[step_emr, custom_step], + sagemaker_session=sagemaker_session, + ) + + pipeline_def = json.loads(pipeline.definition()) + assert "RetryPolicies" in pipeline_def["Steps"][0] + + +def test_emr_step_with_retry_policies_and_execution_role(): + """Test EMRStep with both retry policies and execution role.""" + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], + interval_seconds=1, + max_attempts=3, + backoff_rate=2.0, + ) + ] + + emr_step = EMRStep( + name="MyEMRStep", + display_name="MyEMRStep", + description="MyEMRStepDescription", + cluster_id="MyClusterID", + step_config=g_emr_step_config, + execution_role_arn="arn:aws:iam:000000000000:role/role", + retry_policies=retry_policies, + ) + + request = emr_step.to_request() + assert "RetryPolicies" in request + assert "ExecutionRoleArn" in request["Arguments"] + + +def test_emr_step_properties_with_retry_policies(): + """Test EMRStep properties when retry policies are provided.""" + retry_policies = [ + StepRetryPolicy( + exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], + interval_seconds=1, + max_attempts=3, + backoff_rate=2.0, + ) + ] + + emr_step = EMRStep( + name="MyEMRStep", + display_name="MyEMRStep", + description="MyEMRStepDescription", + cluster_id="MyClusterID", + step_config=g_emr_step_config, + retry_policies=retry_policies, + ) + + # Verify properties still work with retry policies + assert emr_step.properties.ClusterId == "MyClusterID" + assert emr_step.properties.Status.State.expr == {"Get": "Steps.MyEMRStep.Status.State"}