@@ -22,7 +22,8 @@

 from botocore.config import Config
 from botocore.exceptions import WaiterError
-from sagemaker.inputs import TrainingInput
+from sagemaker.inputs import CreateModelInput, TrainingInput
+from sagemaker.model import Model
 from sagemaker.processing import ProcessingInput, ProcessingOutput
 from sagemaker.pytorch.estimator import PyTorch
 from sagemaker.session import get_execution_role, Session
@@ -35,6 +36,7 @@
     ParameterString,
 )
 from sagemaker.workflow.steps import (
+    CreateModelStep,
     ProcessingStep,
     TrainingStep,
 )
@@ -95,7 +97,7 @@ def pipeline_name():
     return f"my-pipeline-{int(time.time() * 10**7)}"


-def test_two_step_definition(
+def test_three_step_definition(
     sagemaker_session, workflow_session, region_name, role, script_dir, pipeline_name
 ):
     framework_version = "0.20.0"
@@ -140,10 +142,26 @@ def test_two_step_definition(
         ),
     )

+    model = Model(
+        image_uri=sklearn_train.image_uri,
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    model_inputs = CreateModelInput(
+        instance_type="ml.m5.large",
+        accelerator_type="ml.eia1.medium",
+    )
+    step_model = CreateModelStep(
+        name="my-model",
+        model=model,
+        inputs=model_inputs,
+    )
+
     pipeline = Pipeline(
         name=pipeline_name,
         parameters=[instance_type, instance_count],
-        steps=[step_process, step_train],
+        steps=[step_process, step_train, step_model],
         sagemaker_session=workflow_session,
     )

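
Aside, not part of the diff: a minimal sketch of how the renamed test can produce the `definition` dictionary that the assertions in the next hunks inspect. It assumes the `pipeline` object built above and the SDK's `Pipeline.definition()` serializer.

    import json

    # Render the pipeline to its definition JSON; "Steps" lists each step's
    # Name, Type, and the service Arguments asserted on below.
    definition = json.loads(pipeline.definition())
    assert set((s["Name"], s["Type"]) for s in definition["Steps"]) == {
        ("my-process", "Processing"),
        ("my-train", "Training"),
        ("my-model", "Model"),
    }
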
@@ -160,7 +178,7 @@ def test_two_step_definition(
     )

     steps = definition["Steps"]
-    assert len(steps) == 2
+    assert len(steps) == 3

     names_and_types = []
     processing_args = {}
@@ -171,11 +189,14 @@ def test_two_step_definition(
             processing_args = step["Arguments"]
         if step["Type"] == "Training":
             training_args = step["Arguments"]
+        if step["Type"] == "Model":
+            model_args = step["Arguments"]

     assert set(names_and_types) == set(
         [
             ("my-process", "Processing"),
             ("my-train", "Training"),
+            ("my-model", "Model"),
         ]
     )

@@ -193,6 +214,9 @@ def test_two_step_definition(
     assert training_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == {
         "Get": "Steps.my-process.ProcessingOutputConfig.Outputs['train_data'].S3Output.S3Uri"
     }
+    assert model_args["PrimaryContainer"]["ModelDataUrl"] == {
+        "Get": "Steps.my-train.ModelArtifacts.S3ModelArtifacts"
+    }


 # TODO-reinvent-2020: Modify use of the workflow client
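
The `{"Get": ...}` dictionaries asserted above are how step property references serialize: at definition time, `step_train.properties.ModelArtifacts.S3ModelArtifacts` is a placeholder that SageMaker resolves only when the pipeline executes. A short sketch, assuming the `step_train` TrainingStep named "my-train" from this file and the SDK's `.expr` attribute as the serialized form of a property reference:

    # Property references carry a path, not a value; `.expr` shows the
    # {"Get": ...} dict that lands in the rendered pipeline definition.
    ref = step_train.properties.ModelArtifacts.S3ModelArtifacts
    assert ref.expr == {"Get": "Steps.my-train.ModelArtifacts.S3ModelArtifacts"}
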
@@ -324,11 +348,27 @@ def test_conditional_pytorch_training_model_registration(
         transform_instances=["*"],
     )

+    model = Model(
+        image_uri=pytorch_estimator.training_image_uri(),
+        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
+        sagemaker_session=sagemaker_session,
+        role=role,
+    )
+    model_inputs = CreateModelInput(
+        instance_type="ml.m5.large",
+        accelerator_type="ml.eia1.medium",
+    )
+    step_model = CreateModelStep(
+        name="pytorch-model",
+        model=model,
+        inputs=model_inputs,
+    )
+
     step_cond = ConditionStep(
         name="cond-good-enough",
         conditions=[ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1)],
         if_steps=[step_train, step_register],
-        else_steps=[],
+        else_steps=[step_model],
     )

     pipeline = Pipeline(
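
For context, a hedged sketch of how a test typically drives this conditional pipeline; the upsert/start calls and the `GoodEnoughInput` parameter name are assumptions, not shown in the diff. With the new `else_steps`, a failing condition now runs the "pytorch-model" CreateModel step instead of doing nothing.

    # Hypothetical driver code; the parameter name and value are assumed.
    pipeline.upsert(role_arn=role)  # create or update the pipeline in SageMaker
    execution = pipeline.start(parameters={"GoodEnoughInput": 0})
    execution.wait()  # 0 < 1, so the else branch ("pytorch-model") executes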