Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions src/sagemaker/workflow/emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from sagemaker.workflow.properties import (
Properties,
)
from sagemaker.workflow.retry import StepRetryPolicy
from sagemaker.workflow.step_collections import StepCollection
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
from sagemaker.workflow.steps import ConfigurableRetryStep, Step, StepTypeEnum, CacheConfig


class EMRStepConfig:
Expand Down Expand Up @@ -110,8 +111,8 @@ def to_request(self) -> RequestType:
)


class EMRStep(Step):
"""EMR step for workflow."""
class EMRStep(ConfigurableRetryStep):
"""EMR step for workflow with configurable retry policies."""

def _validate_cluster_config(self, cluster_config, step_name):
"""Validates user provided cluster_config.
Expand Down Expand Up @@ -164,6 +165,7 @@ def __init__(
cache_config: Optional[CacheConfig] = None,
cluster_config: Optional[Dict[str, Any]] = None,
execution_role_arn: Optional[str] = None,
retry_policies: Optional[List[StepRetryPolicy]] = None,
):
"""Constructs an `EMRStep`.

Expand Down Expand Up @@ -200,7 +202,14 @@ def __init__(
called on the cluster specified by ``cluster_id``, so you can only include this
field if ``cluster_id`` is not None.
"""
super(EMRStep, self).__init__(name, display_name, description, StepTypeEnum.EMR, depends_on)
super().__init__(
name=name,
step_type=StepTypeEnum.EMR,
display_name=display_name,
description=description,
depends_on=depends_on,
retry_policies=retry_policies,
)

emr_step_args = {"StepConfig": step_config.to_request()}
root_property = Properties(step_name=name, step=self, shape_name="Step", service_name="emr")
Expand Down Expand Up @@ -248,7 +257,7 @@ def properties(self) -> RequestType:
return self._properties

def to_request(self) -> RequestType:
"""Updates the dictionary with cache configuration."""
"""Updates the dictionary with cache configuration and retry policies"""
request_dict = super().to_request()
if self.cache_config:
request_dict.update(self.cache_config.config)
Expand Down
213 changes: 213 additions & 0 deletions tests/integ/sagemaker/workflow/test_emr_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig
from sagemaker.workflow.parameters import ParameterInteger
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.retry import StepRetryPolicy, StepExceptionTypeEnum


@pytest.fixture
Expand Down Expand Up @@ -134,3 +135,215 @@ def test_emr_with_cluster_config(sagemaker_session, role, pipeline_name, region_
pipeline.delete()
except Exception:
pass


def test_emr_with_retry_policies(sagemaker_session, role, pipeline_name, region_name):
"""Test EMR steps with retry policies in both cluster_id and cluster_config scenarios."""
emr_step_config = EMRStepConfig(
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
args=["dummy_emr_script_path"],
)

retry_policies = [
StepRetryPolicy(
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
interval_seconds=1,
max_attempts=3,
backoff_rate=2.0,
)
]

# Step with existing cluster and retry policies
step_emr_1 = EMRStep(
name="emr-step-1",
cluster_id="j-1YONHTCP3YZKC",
display_name="emr_step_1",
description="EMR Step with retry policies",
step_config=emr_step_config,
retry_policies=retry_policies,
)

# Step with cluster config and retry policies
cluster_config = {
"Instances": {
"InstanceGroups": [
{
"Name": "Master Instance Group",
"InstanceRole": "MASTER",
"InstanceCount": 1,
"InstanceType": "m1.small",
"Market": "ON_DEMAND",
}
],
"InstanceCount": 1,
"HadoopVersion": "MyHadoopVersion",
},
"AmiVersion": "3.8.0",
"AdditionalInfo": "MyAdditionalInfo",
}

step_emr_2 = EMRStep(
name="emr-step-2",
display_name="emr_step_2",
description="EMR Step with cluster config and retry policies",
cluster_id=None,
step_config=emr_step_config,
cluster_config=cluster_config,
retry_policies=retry_policies,
)

pipeline = Pipeline(
name=pipeline_name,
steps=[step_emr_1, step_emr_2],
sagemaker_session=sagemaker_session,
)

try:
response = pipeline.create(role)
create_arn = response["PipelineArn"]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also assert here that the pipeline definition contains the expected retry policies?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nvm you have tested that in the unit tests already

assert re.match(
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
)
finally:
try:
pipeline.delete()
except Exception:
pass


def test_emr_with_expire_after_retry_policy(sagemaker_session, role, pipeline_name, region_name):
"""Test EMR step with retry policy using expire_after_mins."""
emr_step_config = EMRStepConfig(
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
args=["dummy_emr_script_path"],
)

retry_policies = [
StepRetryPolicy(
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
interval_seconds=1,
expire_after_mins=30,
backoff_rate=2.0,
)
]

step_emr = EMRStep(
name="emr-step-expire",
cluster_id="j-1YONHTCP3YZKC",
display_name="emr_step_expire",
description="EMR Step with expire after retry policy",
step_config=emr_step_config,
retry_policies=retry_policies,
)

pipeline = Pipeline(
name=pipeline_name,
steps=[step_emr],
sagemaker_session=sagemaker_session,
)

try:
response = pipeline.create(role)
create_arn = response["PipelineArn"]
assert re.match(
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
)
finally:
try:
pipeline.delete()
except Exception:
pass


def test_emr_with_multiple_exception_types(sagemaker_session, role, pipeline_name, region_name):
"""Test EMR step with multiple exception types in retry policy."""
retry_policies = [
StepRetryPolicy(
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT, StepExceptionTypeEnum.THROTTLING],
interval_seconds=1,
max_attempts=3,
backoff_rate=2.0,
)
]

step_emr = EMRStep(
name="emr-step-multi-except",
cluster_id="j-1YONHTCP3YZKC",
display_name="emr_step_multi_except",
description="EMR Step with multiple exception types",
step_config=EMRStepConfig(
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
args=["dummy_emr_script_path"],
),
retry_policies=retry_policies,
)

pipeline = Pipeline(
name=pipeline_name,
steps=[step_emr],
sagemaker_session=sagemaker_session,
)

try:
response = pipeline.create(role)
create_arn = response["PipelineArn"]
assert re.match(
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
)
finally:
try:
pipeline.delete()
except Exception:
pass


def test_emr_with_multiple_retry_policies(sagemaker_session, role, pipeline_name, region_name):
"""Test EMR step with multiple retry policies."""
retry_policies = [
StepRetryPolicy(
exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
interval_seconds=1,
max_attempts=3,
backoff_rate=2.0,
),
StepRetryPolicy(
exception_types=[StepExceptionTypeEnum.THROTTLING],
interval_seconds=5,
expire_after_mins=60,
backoff_rate=1.5,
),
]

step_emr = EMRStep(
name="emr-step-multi-policy",
cluster_id="j-1YONHTCP3YZKC",
display_name="emr_step_multi_policy",
description="EMR Step with multiple retry policies",
step_config=EMRStepConfig(
jar="s3://us-west-2.elasticmapreduce/libs/script-runner/script-runner.jar",
args=["dummy_emr_script_path"],
),
retry_policies=retry_policies,
)

pipeline = Pipeline(
name=pipeline_name,
steps=[step_emr],
sagemaker_session=sagemaker_session,
)

try:
response = pipeline.create(role)
create_arn = response["PipelineArn"]
assert re.match(
rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
create_arn,
)
finally:
try:
pipeline.delete()
except Exception:
pass
Loading
Loading