2323from botocore .config import Config
2424from botocore .exceptions import WaiterError
25- from sagemaker .inputs import TrainingInput
25+ from sagemaker .inputs import CreateModelInput , TrainingInput
26+ from sagemaker .model import Model
2627from sagemaker .processing import ProcessingInput , ProcessingOutput
2728from sagemaker .pytorch .estimator import PyTorch
2829from sagemaker .session import get_execution_role , Session
3536 ParameterString ,
3637)
3738from sagemaker .workflow .steps import (
39+ CreateModelStep ,
3840 ProcessingStep ,
3941 TrainingStep ,
4042)
@@ -95,7 +97,7 @@ def pipeline_name():
9597 return f"my-pipeline-{ int (time .time () * 10 ** 7 )} "
9698
9799
98- def test_two_step_definition (
100+ def test_three_step_definition (
99101 sagemaker_session , workflow_session , region_name , role , script_dir , pipeline_name
100102):
101103 framework_version = "0.20.0"
@@ -140,10 +142,26 @@ def test_two_step_definition(
140142 ),
141143 )
142144
145+ model = Model (
146+ image_uri = sklearn_train .image_uri ,
147+ model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
148+ sagemaker_session = sagemaker_session ,
149+ role = role ,
150+ )
151+ model_inputs = CreateModelInput (
152+ instance_type = "ml.m5.large" ,
153+ accelerator_type = "ml.eia1.medium" ,
154+ )
155+ step_model = CreateModelStep (
156+ name = "my-model" ,
157+ model = model ,
158+ inputs = model_inputs ,
159+ )
160+
143161 pipeline = Pipeline (
144162 name = pipeline_name ,
145163 parameters = [instance_type , instance_count ],
146- steps = [step_process , step_train ],
164+ steps = [step_process , step_train , step_model ],
147165 sagemaker_session = workflow_session ,
148166 )
149167
@@ -160,7 +178,7 @@ def test_two_step_definition(
160178 )
161179
162180 steps = definition ["Steps" ]
163- assert len (steps ) == 2
181+ assert len (steps ) == 3
164182
165183 names_and_types = []
166184 processing_args = {}
@@ -171,11 +189,14 @@ def test_two_step_definition(
171189 processing_args = step ["Arguments" ]
172190 if step ["Type" ] == "Training" :
173191 training_args = step ["Arguments" ]
192+ if step ["Type" ] == "Model" :
193+ model_args = step ["Arguments" ]
174194
175195 assert set (names_and_types ) == set (
176196 [
177197 ("my-process" , "Processing" ),
178198 ("my-train" , "Training" ),
199+ ("my-model" , "Model" ),
179200 ]
180201 )
181202
@@ -193,6 +214,9 @@ def test_two_step_definition(
193214 assert training_args ["InputDataConfig" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] == {
194215 "Get" : "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
195216 }
217+ assert model_args ["PrimaryContainer" ]["ModelDataUrl" ] == {
218+ "Get" : "Steps.my-train.ModelArtifacts.S3ModelArtifacts"
219+ }
196220
197221
198222# TODO-reinvent-2020: Modify use of the workflow client
@@ -324,11 +348,27 @@ def test_conditional_pytorch_training_model_registration(
324348 transform_instances = ["*" ],
325349 )
326350
351+ model = Model (
352+ image_uri = pytorch_estimator .training_image_uri (),
353+ model_data = step_train .properties .ModelArtifacts .S3ModelArtifacts ,
354+ sagemaker_session = sagemaker_session ,
355+ role = role ,
356+ )
357+ model_inputs = CreateModelInput (
358+ instance_type = "ml.m5.large" ,
359+ accelerator_type = "ml.eia1.medium" ,
360+ )
361+ step_model = CreateModelStep (
362+ name = "pytorch-model" ,
363+ model = model ,
364+ inputs = model_inputs ,
365+ )
366+
327367 step_cond = ConditionStep (
328368 name = "cond-good-enough" ,
329369 conditions = [ConditionGreaterThanOrEqualTo (left = good_enough_input , right = 1 )],
330370 if_steps = [step_train , step_register ],
331- else_steps = [],
371+ else_steps = [step_model ],
332372 )
333373
334374 pipeline = Pipeline (