@@ -273,13 +273,61 @@ def test_spark_processor_base_extend_processing_args(
 serialized_configuration = BytesIO("test".encode("utf-8"))


+@pytest.mark.parametrize(
+    "config, expected",
+    [
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "configuration_location": "s3://configbucket/someprefix/",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": None,
+            },
+            "s3://bucket/None/input/conf/configuration.json",
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "configuration_location": "s3://configbucket/someprefix",
+            },
+            "s3://configbucket/someprefix/None/input/conf/configuration.json",
+        ),
+    ],
+)
 @patch("sagemaker.spark.processing.BytesIO")
 @patch("sagemaker.spark.processing.S3Uploader.upload_string_as_file_body")
-def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, sagemaker_session):
-    desired_s3_uri = "s3://bucket/None/input/conf/configuration.json"
+def test_stage_configuration(mock_s3_upload, mock_bytesIO, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        configuration_location=config["configuration_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
+    desired_s3_uri = expected
     mock_bytesIO.return_value = serialized_configuration

-    result = py_spark_processor._stage_configuration({})
+    result = spark_processor._stage_configuration({})

     mock_s3_upload.assert_called_with(
         body=serialized_configuration,
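The expected URIs in the new parametrization follow one pattern: the staged configuration lands under `configuration_location` when one is given, and under the session's default bucket otherwise, and the `None` path segment appears because the processor's current job name is unset when `_stage_configuration` is called directly in the test. A minimal sketch of that derivation (`expected_config_uri` is an illustrative helper invented here, not the SDK's implementation):

```python
# Illustrative only -- not the SDK's code. Shows how the expected URIs in
# the parametrization above can be derived from configuration_location.
def expected_config_uri(configuration_location, default_bucket="bucket", job_name=None):
    # Fall back to the session's default bucket when no explicit location is given.
    base = (configuration_location or f"s3://{default_bucket}").rstrip("/")
    # job_name is None in the test, which is why "None" shows up in the paths.
    return f"{base}/{job_name}/input/conf/configuration.json"

assert expected_config_uri(None) == "s3://bucket/None/input/conf/configuration.json"
assert (
    expected_config_uri("s3://configbucket/someprefix/")
    == "s3://configbucket/someprefix/None/input/conf/configuration.json"
)
```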
@@ -292,23 +340,121 @@ def test_stage_configuration(mock_s3_upload, mock_bytesIO, py_spark_processor, s
 @pytest.mark.parametrize(
     "config, expected",
     [
-        ({"submit_deps": None, "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["s3"], "input_channel_name": None}, ValueError),
-        ({"submit_deps": ["other"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
-        ({"submit_deps": ["file"], "input_channel_name": "channelName"}, ValueError),
         (
-            {"submit_deps": ["s3", "s3"], "input_channel_name": "channelName"},
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": None,
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3"],
+                "input_channel_name": None,
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["other"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["file"],
+                "input_channel_name": "channelName",
+            },
+            ValueError,
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["s3", "s3"],
+                "input_channel_name": "channelName",
+            },
             (None, "s3://bucket,s3://bucket"),
         ),
         (
-            {"submit_deps": ["jar"], "input_channel_name": "channelName"},
-            (processing_input, "s3://bucket"),
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "py_spark_processor",
+                "dependency_location": "s3://codebucket/someprefix/",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": None,
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            ("s3://bucket/None/input/channelName", "/opt/ml/processing/input/channelName"),
+        ),
+        (
+            {
+                "spark_processor_type": "spark_jar_processor",
+                "dependency_location": "s3://codebucket/someprefix",
+                "submit_deps": ["jar"],
+                "input_channel_name": "channelName",
+            },
+            (
+                "s3://codebucket/someprefix/None/input/channelName",
+                "/opt/ml/processing/input/channelName",
+            ),
         ),
     ],
 )
 @patch("sagemaker.spark.processing.S3Uploader")
-def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, config, expected):
+def test_stage_submit_deps(mock_s3_uploader, jar_file, config, expected, sagemaker_session):
+    spark_processor_type = {
+        "py_spark_processor": PySparkProcessor,
+        "spark_jar_processor": SparkJarProcessor,
+    }[config["spark_processor_type"]]
+    spark_processor = spark_processor_type(
+        base_job_name="sm-spark",
+        role="AmazonSageMaker-ExecutionRole",
+        framework_version="2.4",
+        instance_count=1,
+        instance_type="ml.c5.xlarge",
+        image_uri="790336243319.dkr.ecr.us-west-2.amazonaws.com/sagemaker-spark:0.1",
+        dependency_location=config["dependency_location"],
+        sagemaker_session=sagemaker_session,
+    )
+
     submit_deps_dict = {
         None: None,
         "s3": "s3://bucket",
@@ -322,21 +468,20 @@ def test_stage_submit_deps(mock_s3_uploader, py_spark_processor, jar_file, confi

     if expected is ValueError:
         with pytest.raises(expected) as e:
-            py_spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])
+            spark_processor._stage_submit_deps(submit_deps, config["input_channel_name"])

         assert isinstance(e.value, expected)
     else:
-        input_channel, spark_opt = py_spark_processor._stage_submit_deps(
+        input_channel, spark_opt = spark_processor._stage_submit_deps(
             submit_deps, config["input_channel_name"]
         )

         if expected[0] is None:
             assert input_channel is None
             assert spark_opt == expected[1]
         else:
-            expected_source = "s3://bucket/None/input/channelName"
-            assert input_channel.source == expected_source
-            assert spark_opt == "/opt/ml/processing/input/channelName"
+            assert input_channel.source == expected[0]
+            assert spark_opt == expected[1]


 @pytest.mark.parametrize(
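The reworked `test_stage_submit_deps` cases pin down two success shapes: dependencies that are already S3 URIs are comma-joined into the spark option with no input channel, while local files are staged under `dependency_location` (or the default bucket) and exposed to the container at `/opt/ml/processing/input/<channel>`. A hedged sketch of those expectations (semantics distilled from the test data, not the SDK's implementation; `expected_submit_deps` is an invented helper, and validation of which schemes raise `ValueError` is simplified):

```python
# Assumed semantics distilled from the parametrized cases -- not the SDK's code.
def expected_submit_deps(submit_deps, input_channel_name,
                         dependency_location=None, default_bucket="bucket",
                         job_name=None):
    if not submit_deps or not input_channel_name:
        raise ValueError("submit_deps and input_channel_name are required")
    s3_uris = [d for d in submit_deps if d.startswith("s3://")]
    if len(s3_uris) == len(submit_deps):
        # Pure-S3 dependencies: no input channel, comma-joined spark option.
        return None, ",".join(s3_uris)
    # Local files: staged to dependency_location (or the default bucket) and
    # mounted into the container under the named input channel.
    base = (dependency_location or f"s3://{default_bucket}").rstrip("/")
    source = f"{base}/{job_name}/input/{input_channel_name}"
    return source, f"/opt/ml/processing/input/{input_channel_name}"

assert expected_submit_deps(["s3://bucket", "s3://bucket"], "channelName") == (
    None,
    "s3://bucket,s3://bucket",
)
assert expected_submit_deps(["dep.jar"], "channelName") == (
    "s3://bucket/None/input/channelName",
    "/opt/ml/processing/input/channelName",
)
```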