@@ -553,11 +553,15 @@ def test_validate_version_or_image_args_raises():
553553
554554def test_validate_smdistributed_not_raises ():
555555 smdataparallel_enabled = {"smdistributed" : {"dataparallel" : {"enabled" : True }}}
556+ smdataparallel_enabled_custom_mpi = {
557+ "smdistributed" : {"dataparallel" : {"enabled" : True , "custom_mpi_options" : "--verbose" }}
558+ }
556559 smdataparallel_disabled = {"smdistributed" : {"dataparallel" : {"enabled" : False }}}
557560 instance_types = list (fw_utils .SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES )
558561
559562 good_args = [
560563 (smdataparallel_enabled , "custom-container" ),
564+ (smdataparallel_enabled_custom_mpi , "custom-container" ),
561565 (smdataparallel_disabled , "custom-container" ),
562566 ]
563567 frameworks = ["tensorflow" , "pytorch" ]
@@ -576,17 +580,17 @@ def test_validate_smdistributed_not_raises():
576580
577581def test_validate_smdistributed_raises ():
578582 bad_args = [
579- {"smdistributed" : {"dataparallel" : {"enabled" : True }}},
580583 {"smdistributed" : "dummy" },
581584 {"smdistributed" : {"dummy" }},
582585 {"smdistributed" : {"dummy" : "val" }},
583586 {"smdistributed" : {"dummy" : {"enabled" : True }}},
584587 ]
588+ instance_types = list (fw_utils .SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES )
585589 frameworks = ["tensorflow" , "pytorch" ]
586- for framework , distribution in product (frameworks , bad_args ):
590+ for framework , distribution , instance_type in product (frameworks , bad_args , instance_types ):
587591 with pytest .raises (ValueError ):
588592 fw_utils .validate_smdistributed (
589- instance_type = None ,
593+ instance_type = instance_type ,
590594 framework_name = framework ,
591595 framework_version = None ,
592596 py_version = None ,
@@ -624,6 +628,9 @@ def test_validate_smdataparallel_args_raises():
624628
625629def test_validate_smdataparallel_args_not_raises ():
626630 smdataparallel_enabled = {"smdistributed" : {"dataparallel" : {"enabled" : True }}}
631+ smdataparallel_enabled_custom_mpi = {
632+ "smdistributed" : {"dataparallel" : {"enabled" : True , "custom_mpi_options" : "--verbose" }}
633+ }
627634 smdataparallel_disabled = {"smdistributed" : {"dataparallel" : {"enabled" : False }}}
628635
629636 # Cases {PT|TF2}
@@ -644,6 +651,8 @@ def test_validate_smdataparallel_args_not_raises():
644651 ("ml.p3.16xlarge" , "pytorch" , "1.8.0" , "py3" , smdataparallel_enabled ),
645652 ("ml.p3.16xlarge" , "pytorch" , "1.8.1" , "py3" , smdataparallel_enabled ),
646653 ("ml.p3.16xlarge" , "pytorch" , "1.8" , "py3" , smdataparallel_enabled ),
654+ ("ml.p3.16xlarge" , "tensorflow" , "2.4.1" , "py3" , smdataparallel_enabled_custom_mpi ),
655+ ("ml.p3.16xlarge" , "pytorch" , "1.8.0" , "py3" , smdataparallel_enabled_custom_mpi ),
647656 ]
648657 for instance_type , framework_name , framework_version , py_version , distribution in good_args :
649658 fw_utils ._validate_smdataparallel_args (
0 commit comments