from sagemaker.tuner import HyperparameterTuner
from sagemaker.workflow.pipeline_context import PipelineSession

-from sagemaker.processing import Processor, ScriptProcessor, FrameworkProcessor
+from sagemaker.processing import (
+    Processor,
+    ScriptProcessor,
+    FrameworkProcessor,
+    ProcessingOutput,
+    ProcessingInput,
+)
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.pytorch.processing import PyTorchProcessor
from sagemaker.tensorflow.processing import TensorFlowProcessor
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.spark.processing import SparkJarProcessor, PySparkProcessor

-from sagemaker.processing import ProcessingInput

from sagemaker.workflow.steps import CacheConfig, ProcessingStep
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.properties import PropertyFile
+from sagemaker.workflow.parameters import ParameterString
+from sagemaker.workflow.functions import Join

from sagemaker.network import NetworkConfig
from sagemaker.pytorch.estimator import PyTorch
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
INSTANCE_TYPE = "ml.m4.xlarge"

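+# Each FRAMEWORK_PROCESSOR entry pairs a processor instance with the kwargs its
+# run() call expects.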
+FRAMEWORK_PROCESSOR = [
+    (
+        FrameworkProcessor(
+            framework_version="1.8",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+            estimator_cls=PyTorch,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        SKLearnProcessor(
+            framework_version="0.23-1",
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            role=ROLE,
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        PyTorchProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="1.8.0",
+            py_version="py3",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        TensorFlowProcessor(
+            role=ROLE,
+            instance_type=INSTANCE_TYPE,
+            instance_count=1,
+            framework_version="2.0",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        HuggingFaceProcessor(
+            transformers_version="4.6",
+            pytorch_version="1.7",
+            role=ROLE,
+            instance_count=1,
+            instance_type="ml.p3.2xlarge",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        XGBoostProcessor(
+            framework_version="1.3-1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-xgboost",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        MXNetProcessor(
+            framework_version="1.4.1",
+            py_version="py3",
+            role=ROLE,
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+            base_job_name="test-mxnet",
+        ),
+        {"code": DUMMY_S3_SCRIPT_PATH},
+    ),
+    (
+        DataWranglerProcessor(
+            role=ROLE,
+            data_wrangler_flow_source="s3://my-bucket/dw.flow",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {},
+    ),
+    (
+        SparkJarProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+    (
+        PySparkProcessor(
+            role=ROLE,
+            framework_version="2.4",
+            instance_count=1,
+            instance_type=INSTANCE_TYPE,
+        ),
+        {
+            "submit_app": "s3://my-jar",
+            "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
+        },
+    ),
+]
+
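+# Input sources cover a static S3 URI plus the pipeline-variable cases
+# (ParameterString with and without a default, and a Join expression).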
+PROCESSING_INPUT = [
+    ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
+    ProcessingInput(
+        source=ParameterString(name="my-processing-input"),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=ParameterString(
+            name="my-processing-input", default_value="s3://my-bucket/my-processing"
+        ),
+        destination="processing-input",
+    ),
+    ProcessingInput(
+        source=Join(on="/", values=["s3://my-bucket", "my-input"]),
+        destination="processing-input",
+    ),
+]
+
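+# Output destinations mirror the same four cases as the inputs above.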
+PROCESSING_OUTPUT = [
+    ProcessingOutput(source="/opt/ml/output", destination="s3://my-bucket/my-output"),
+    ProcessingOutput(source="/opt/ml/output", destination=ParameterString(name="my-output")),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=ParameterString(name="my-output", default_value="s3://my-bucket/my-output"),
+    ),
+    ProcessingOutput(
+        source="/opt/ml/output",
+        destination=Join(on="/", values=["s3://my-bucket", "my-output"]),
+    ),
+]
+

@pytest.fixture
def client():
@@ -253,117 +398,11 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
    }


-@pytest.mark.parametrize(
-    "framework_processor",
-    [
-        (
-            FrameworkProcessor(
-                framework_version="1.8",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-                estimator_cls=PyTorch,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            SKLearnProcessor(
-                framework_version="0.23-1",
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                role=ROLE,
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            PyTorchProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="1.8.0",
-                py_version="py3",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            TensorFlowProcessor(
-                role=ROLE,
-                instance_type=INSTANCE_TYPE,
-                instance_count=1,
-                framework_version="2.0",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            HuggingFaceProcessor(
-                transformers_version="4.6",
-                pytorch_version="1.7",
-                role=ROLE,
-                instance_count=1,
-                instance_type="ml.p3.2xlarge",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            XGBoostProcessor(
-                framework_version="1.3-1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-xgboost",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            MXNetProcessor(
-                framework_version="1.4.1",
-                py_version="py3",
-                role=ROLE,
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-                base_job_name="test-mxnet",
-            ),
-            {"code": DUMMY_S3_SCRIPT_PATH},
-        ),
-        (
-            DataWranglerProcessor(
-                role=ROLE,
-                data_wrangler_flow_source=f"s3://{BUCKET}/dw.flow",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {},
-        ),
-        (
-            SparkJarProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-        (
-            PySparkProcessor(
-                role=ROLE,
-                framework_version="2.4",
-                instance_count=1,
-                instance_type=INSTANCE_TYPE,
-            ),
-            {
-                "submit_app": "s3://my-jar",
-                "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"],
-            },
-        ),
-    ],
-)
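+# Stacked parametrize decorators run the test over the full cross-product of
+# processors, inputs, and outputs.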
+@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
+@pytest.mark.parametrize("processing_input", PROCESSING_INPUT)
+@pytest.mark.parametrize("processing_output", PROCESSING_OUTPUT)
def test_processing_step_with_framework_processor(
-    framework_processor, pipeline_session, processing_input, network_config
+    framework_processor, pipeline_session, processing_input, processing_output, network_config
):

    processor, run_inputs = framework_processor
@@ -373,7 +412,8 @@ def test_processing_step_with_framework_processor(
    processor.volume_kms_key = "volume-kms-key"
    processor.network_config = network_config

-    run_inputs["inputs"] = processing_input
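+    # run() takes lists of inputs/outputs; each parametrized case passes exactly one of each.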
+    run_inputs["inputs"] = [processing_input]
+    run_inputs["outputs"] = [processing_output]

    step_args = processor.run(**run_inputs)

@@ -387,10 +427,25 @@ def test_processing_step_with_framework_processor(
        sagemaker_session=pipeline_session,
    )

-    assert json.loads(pipeline.definition())["Steps"][0] == {
+    step_args = step_args.args
+    step_def = json.loads(pipeline.definition())["Steps"][0]
+
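+    # The S3Uri fields should round-trip the parametrized source/destination,
+    # whether it is a plain string or a pipeline variable.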
+    assert step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"] == processing_input.source
+    assert (
+        step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+        == processing_output.destination
+    )
+
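+    # A pipeline variable renders as an expression (e.g. {"Get": ...}) in the
+    # definition, so strip S3Uri from both sides before comparing the rest verbatim.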
+    del step_args["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+
+    del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+    del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+    assert step_def == {
        "Name": "MyProcessingStep",
        "Type": "Processing",
-        "Arguments": step_args.args,
+        "Arguments": step_args,
    }