Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/sagemaker/workflow/emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from sagemaker.workflow.properties import (
Properties,
)
from sagemaker.workflow.retry import StepRetryPolicy
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum, CacheConfig


class EMRStepConfig:
Expand Down Expand Up @@ -110,8 +111,8 @@ def to_request(self) -> RequestType:
)


class EMRStep(Step):
"""EMR step for workflow."""
class EMRStep(ConfigurableRetryStep):
"""EMR step for workflow with configurable retry policies."""

def _validate_cluster_config(self, cluster_config, step_name):
"""Validates user provided cluster_config.
Expand Down Expand Up @@ -164,6 +165,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
cluster_config: Optional[Dict[str, Any]] = None,
execution_role_arn: Optional[str] = None,
retry_policies: Optional[List[StepRetryPolicy]] = None,
):
"""Constructs an `EMRStep`.

Expand Down Expand Up @@ -200,7 +202,14 @@ def __init__(
called on the cluster specified by ``cluster_id``, so you can only include this
field if ``cluster_id`` is not None.
"""
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
super().__init__(
name=name,
step_type=StepTypeEnum.EMR,
display_name=display_name,
description=description,
depends_on=depends_on,
retry_policies=retry_policies,
)

emr_step_args = {"StepConfig": step_config.to_request()}
root_property = Properties(step_name=name, step=self, shape_name="Step", service_name="emr")
Expand Down Expand Up @@ -248,7 +257,7 @@ def properties(self) -> RequestType:
return self._properties

def to_request(self) -> RequestType:
"""Updates the dictionary with cache configuration."""
"""Updates the dictionary with cache configuration and retry policies"""
request_dict = super().to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)
Expand Down
213 changes: 213 additions & 0 deletions tests/integ/sagemaker/workflow/test_emr_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
from sagemaker.workflow.parameters import ParameterInteger
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.retry import StepRetryPolicy, StepExceptionTypeEnum


@pytest.fixture
Expand Down Expand Up @@ -134,3 +135,215 @@ def test_emr_with_cluster_config(sagemaker_session, role, pipeline_name, region_
pipeline.delete()
except Exception:
pass


def test_emr_with_retry_policies(sagemaker_session, role, pipeline_name, region_name):
    """Create a pipeline whose EMR steps carry retry policies.

    Two steps share the same step config and retry policy list: one targets
    an existing cluster through ``cluster_id``, the other supplies a
    ``cluster_config`` instead. The test passes when the pipeline is created
    and the returned ARN matches the expected region and pipeline name.
    """
    shared_step_config = EMRStepConfig(
        jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
        args=["dummy_emr_script_path"],
    )
    service_fault_policy = StepRetryPolicy(
        exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
        interval_seconds=1,
        max_attempts=3,
        backoff_rate=2.0,
    )
    retry_policies = [service_fault_policy]

    # Scenario 1: run on an already existing cluster.
    existing_cluster_step = EMRStep(
        name="emr-step-1",
        cluster_id="j-1YONHTCP3YZKC",
        display_name="emr_step_1",
        description="EMR Step with retry policies",
        step_config=shared_step_config,
        retry_policies=retry_policies,
    )

    # Scenario 2: no cluster id, the cluster is described by a config dict.
    master_instance_group = {
        "Name": "Master Instance Group",
        "InstanceRole": "MASTER",
        "InstanceCount": 1,
        "InstanceType": "m1.small",
        "Market": "ON_DEMAND",
    }
    cluster_config = {
        "Instances": {
            "InstanceGroups": [master_instance_group],
            "InstanceCount": 1,
            "HadoopVersion": "MyHadoopVersion",
        },
        "AmiVersion": "3.8.0",
        "AdditionalInfo": "MyAdditionalInfo",
    }
    configured_cluster_step = EMRStep(
        name="emr-step-2",
        display_name="emr_step_2",
        description="EMR Step with cluster config and retry policies",
        cluster_id=None,
        step_config=shared_step_config,
        cluster_config=cluster_config,
        retry_policies=retry_policies,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        steps=[existing_cluster_step, configured_cluster_step],
        sagemaker_session=sagemaker_session,
    )

    try:
        pipeline_arn = pipeline.create(role)["PipelineArn"]
        expected_arn_pattern = (
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}"
        )
        assert re.match(expected_arn_pattern, pipeline_arn)
    finally:
        # Best-effort cleanup: a failed delete must not mask the test outcome.
        try:
            pipeline.delete()
        except Exception:
            pass


def test_emr_with_expire_after_retry_policy(sagemaker_session, role, pipeline_name, region_name):
    """Create a pipeline with an EMR step whose retry policy uses ``expire_after_mins``.

    The retry policy is bounded by an expiry time rather than an attempt
    count. The test passes when the pipeline is created and the returned ARN
    matches the expected region and pipeline name.
    """
    expiring_policy = StepRetryPolicy(
        exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
        interval_seconds=1,
        expire_after_mins=30,
        backoff_rate=2.0,
    )
    script_runner_config = EMRStepConfig(
        jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
        args=["dummy_emr_script_path"],
    )

    expiring_step = EMRStep(
        name="emr-step-expire",
        cluster_id="j-1YONHTCP3YZKC",
        display_name="emr_step_expire",
        description="EMR Step with expire after retry policy",
        step_config=script_runner_config,
        retry_policies=[expiring_policy],
    )

    pipeline = Pipeline(
        name=pipeline_name,
        steps=[expiring_step],
        sagemaker_session=sagemaker_session,
    )

    try:
        pipeline_arn = pipeline.create(role)["PipelineArn"]
        expected_arn_pattern = (
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}"
        )
        assert re.match(expected_arn_pattern, pipeline_arn)
    finally:
        # Best-effort cleanup: a failed delete must not mask the test outcome.
        try:
            pipeline.delete()
        except Exception:
            pass


def test_emr_with_multiple_exception_types(sagemaker_session, role, pipeline_name, region_name):
    """Create a pipeline with an EMR step whose single retry policy covers two exception types.

    One policy lists both ``SERVICE_FAULT`` and ``THROTTLING``. The test
    passes when the pipeline is created and the returned ARN matches the
    expected region and pipeline name.
    """
    handled_exceptions = [
        StepExceptionTypeEnum.SERVICE_FAULT,
        StepExceptionTypeEnum.THROTTLING,
    ]
    combined_policy = StepRetryPolicy(
        exception_types=handled_exceptions,
        interval_seconds=1,
        max_attempts=3,
        backoff_rate=2.0,
    )

    script_runner_config = EMRStepConfig(
        jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
        args=["dummy_emr_script_path"],
    )
    multi_exception_step = EMRStep(
        name="emr-step-multi-except",
        cluster_id="j-1YONHTCP3YZKC",
        display_name="emr_step_multi_except",
        description="EMR Step with multiple exception types",
        step_config=script_runner_config,
        retry_policies=[combined_policy],
    )

    pipeline = Pipeline(
        name=pipeline_name,
        steps=[multi_exception_step],
        sagemaker_session=sagemaker_session,
    )

    try:
        pipeline_arn = pipeline.create(role)["PipelineArn"]
        expected_arn_pattern = (
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}"
        )
        assert re.match(expected_arn_pattern, pipeline_arn)
    finally:
        # Best-effort cleanup: a failed delete must not mask the test outcome.
        try:
            pipeline.delete()
        except Exception:
            pass


def test_emr_with_multiple_retry_policies(sagemaker_session, role, pipeline_name, region_name):
    """Create a pipeline with an EMR step that carries two separate retry policies.

    The first policy handles ``SERVICE_FAULT`` with an attempt limit; the
    second handles ``THROTTLING`` with an expiry time. The test passes when
    the pipeline is created and the returned ARN matches the expected region
    and pipeline name.
    """
    service_fault_policy = StepRetryPolicy(
        exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
        interval_seconds=1,
        max_attempts=3,
        backoff_rate=2.0,
    )
    throttling_policy = StepRetryPolicy(
        exception_types=[StepExceptionTypeEnum.THROTTLING],
        interval_seconds=5,
        expire_after_mins=60,
        backoff_rate=1.5,
    )

    script_runner_config = EMRStepConfig(
        jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
        args=["dummy_emr_script_path"],
    )
    multi_policy_step = EMRStep(
        name="emr-step-multi-policy",
        cluster_id="j-1YONHTCP3YZKC",
        display_name="emr_step_multi_policy",
        description="EMR Step with multiple retry policies",
        step_config=script_runner_config,
        retry_policies=[service_fault_policy, throttling_policy],
    )

    pipeline = Pipeline(
        name=pipeline_name,
        steps=[multi_policy_step],
        sagemaker_session=sagemaker_session,
    )

    try:
        pipeline_arn = pipeline.create(role)["PipelineArn"]
        expected_arn_pattern = (
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}"
        )
        assert re.match(expected_arn_pattern, pipeline_arn)
    finally:
        # Best-effort cleanup: a failed delete must not mask the test outcome.
        try:
            pipeline.delete()
        except Exception:
            pass
Loading
Loading