@@ -402,6 +402,39 @@ def test_validate_smdistributed_backward_compat_p4_not_raises(sagemaker_session)
402402 f ._distribution_configuration (DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED )
403403
404404
405+ def test_validate_smdistributed_instance_groups_raises (sagemaker_session ):
406+ instance_group_1 = InstanceGroup ("train_group" , "ml.p4d.24xlarge" , 2 )
407+ instance_group_2 = InstanceGroup ("train_group" , "ml.p5.48xlarge" , 2 )
408+ f = DummyFramework (
409+ "some_script.py" ,
410+ role = "DummyRole" ,
411+ instance_groups = [instance_group_1 , instance_group_2 ],
412+ sagemaker_session = sagemaker_session ,
413+ output_path = "outputpath" ,
414+ image_uri = "some_acceptable_image" ,
415+ )
416+ # Testing instance_group with p5 raises exception
417+ with pytest .raises (ValueError ):
418+ f ._distribution_configuration (DISTRIBUTION_SM_DDP_ENABLED )
419+ with pytest .raises (ValueError ):
420+ f ._distribution_configuration (DISTRIBUTION_SM_DDP_DISABLED )
421+
422+
423+ def test_validate_smdistributed_instance_groups_not_raises (sagemaker_session ):
424+ instance_group_1 = InstanceGroup ("train_group" , "ml.p4d.24xlarge" , 2 )
425+ f = DummyFramework (
426+ "some_script.py" ,
427+ role = "DummyRole" ,
428+ instance_groups = [instance_group_1 ],
429+ sagemaker_session = sagemaker_session ,
430+ output_path = "outputpath" ,
431+ image_uri = "some_acceptable_image" ,
432+ )
433+ # Testing instance_group without p5 does not raise exception
434+ f ._distribution_configuration (DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED )
435+ f ._distribution_configuration (DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED )
436+
437+
405438def test_framework_all_init_args (sagemaker_session ):
406439 f = DummyFramework (
407440 "my_script.py" ,
0 commit comments