Skip to content

Commit 0c9572c

Browse files
metrizableDan Choi
authored andcommitted
fix: change CreateModel to Model in step type (#536)
1 parent 6de4fca commit 0c9572c

File tree

4 files changed

+49
-9
lines changed

4 files changed

+49
-9
lines changed

src/sagemaker/workflow/steps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
4949
"""Enum of step types."""
5050

5151
CONDITION = "Condition"
52-
CREATE_MODEL = "CreateModel"
52+
CREATE_MODEL = "Model"
5353
FAIL = "Fail"
5454
PROCESSING = "Processing"
5555
REGISTER_MODEL = "RegisterModel"

tests/integ/test_workflow.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323
from botocore.config import Config
2424
from botocore.exceptions import WaiterError
25-
from sagemaker.inputs import TrainingInput
25+
from sagemaker.inputs import CreateModelInput, TrainingInput
26+
from sagemaker.model import Model
2627
from sagemaker.processing import ProcessingInput, ProcessingOutput
2728
from sagemaker.pytorch.estimator import PyTorch
2829
from sagemaker.session import get_execution_role, Session
@@ -35,6 +36,7 @@
3536
ParameterString,
3637
)
3738
from sagemaker.workflow.steps import (
39+
CreateModelStep,
3840
ProcessingStep,
3941
TrainingStep,
4042
)
@@ -95,7 +97,7 @@ def pipeline_name():
9597
return f"my-pipeline-{int(time.time() * 10**7)}"
9698

9799

98-
def test_two_step_definition(
100+
def test_three_step_definition(
99101
sagemaker_session, workflow_session, region_name, role, script_dir, pipeline_name
100102
):
101103
framework_version = "0.20.0"
@@ -140,10 +142,26 @@ def test_two_step_definition(
140142
),
141143
)
142144

145+
model = Model(
146+
image_uri=sklearn_train.image_uri,
147+
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
148+
sagemaker_session=sagemaker_session,
149+
role=role,
150+
)
151+
model_inputs = CreateModelInput(
152+
instance_type="ml.m5.large",
153+
accelerator_type="ml.eia1.medium",
154+
)
155+
step_model = CreateModelStep(
156+
name="my-model",
157+
model=model,
158+
inputs=model_inputs,
159+
)
160+
143161
pipeline = Pipeline(
144162
name=pipeline_name,
145163
parameters=[instance_type, instance_count],
146-
steps=[step_process, step_train],
164+
steps=[step_process, step_train, step_model],
147165
sagemaker_session=workflow_session,
148166
)
149167

@@ -160,7 +178,7 @@ def test_two_step_definition(
160178
)
161179

162180
steps = definition["Steps"]
163-
assert len(steps) == 2
181+
assert len(steps) == 3
164182

165183
names_and_types = []
166184
processing_args = {}
@@ -171,11 +189,14 @@ def test_two_step_definition(
171189
processing_args = step["Arguments"]
172190
if step["Type"] == "Training":
173191
training_args = step["Arguments"]
192+
if step["Type"] == "Model":
193+
model_args = step["Arguments"]
174194

175195
assert set(names_and_types) == set(
176196
[
177197
("my-process", "Processing"),
178198
("my-train", "Training"),
199+
("my-model", "Model"),
179200
]
180201
)
181202

@@ -193,6 +214,9 @@ def test_two_step_definition(
193214
assert training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == {
194215
"Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
195216
}
217+
assert model_args["PrimaryContainer"]["ModelDataUrl"] == {
218+
"Get": "Steps.my-train.ModelArtifacts.S3ModelArtifacts"
219+
}
196220

197221

198222
# TODO-reinvent-2020: Modify use of the workflow client
@@ -324,11 +348,27 @@ def test_conditional_pytorch_training_model_registration(
324348
transform_instances=["*"],
325349
)
326350

351+
model = Model(
352+
image_uri=pytorch_estimator.training_image_uri(),
353+
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
354+
sagemaker_session=sagemaker_session,
355+
role=role,
356+
)
357+
model_inputs = CreateModelInput(
358+
instance_type="ml.m5.large",
359+
accelerator_type="ml.eia1.medium",
360+
)
361+
step_model = CreateModelStep(
362+
name="pytorch-model",
363+
model=model,
364+
inputs=model_inputs,
365+
)
366+
327367
step_cond = ConditionStep(
328368
name="cond-good-enough",
329369
conditions=[ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1)],
330370
if_steps=[step_train, step_register],
331-
else_steps=[],
371+
else_steps=[step_model],
332372
)
333373

334374
pipeline = Pipeline(

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,10 @@ def test_estimator_transformer(estimator):
194194
request_dicts = estimator_transformer.request_dicts()
195195
assert len(request_dicts) == 2
196196
for request_dict in request_dicts:
197-
if request_dict["Type"] == "CreateModel":
197+
if request_dict["Type"] == "Model":
198198
assert request_dict == {
199199
"Name": "EstimatorTransformerStepCreateModelStep",
200-
"Type": "CreateModel",
200+
"Type": "Model",
201201
"Arguments": {
202202
"ExecutionRoleArn": "DummyRole",
203203
"PrimaryContainer": {

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def test_create_model_step(sagemaker_session):
220220

221221
assert step.to_request() == {
222222
"Name": "MyCreateModelStep",
223-
"Type": "CreateModel",
223+
"Type": "Model",
224224
"Arguments": {
225225
"ExecutionRoleArn": "DummyRole",
226226
"PrimaryContainer": {"Environment": {}, "Image": "fakeimage"},

0 commit comments

Comments
 (0)