@@ -336,6 +336,148 @@ def training_job_description(sagemaker_session):
336336 return returned_job_description
337337
338338
339+ def test_set_accept_eula_for_input_data_config_no_input_data_config ():
340+ """Test when InputDataConfig is not in train_args."""
341+ train_args = {}
342+ accept_eula = True
343+
344+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
345+
346+ # Verify train_args remains unchanged
347+ assert train_args == {}
348+
349+
350+ def test_set_accept_eula_for_input_data_config_none_accept_eula ():
351+ """Test when accept_eula is None."""
352+ train_args = {"InputDataConfig" : [{"DataSource" : {"S3DataSource" : {}}}]}
353+ accept_eula = None
354+
355+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
356+
357+ # Verify train_args remains unchanged
358+ assert train_args == {"InputDataConfig" : [{"DataSource" : {"S3DataSource" : {}}}]}
359+
360+
361+ def test_set_accept_eula_for_input_data_config_single_data_source ():
362+ """Test with a single S3DataSource."""
363+ with patch ("sagemaker.estimator.logger" ) as logger :
364+ train_args = {
365+ "InputDataConfig" : [{"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model" }}}]
366+ }
367+ accept_eula = True
368+
369+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
370+
371+ # Verify ModelAccessConfig and AcceptEula are set correctly
372+ assert train_args ["InputDataConfig" ][0 ]["DataSource" ]["S3DataSource" ][
373+ "ModelAccessConfig"
374+ ] == {"AcceptEula" : True }
375+
376+ # Verify no logging occurred since there's only one data source
377+ logger .info .assert_not_called ()
378+
379+
380+ def test_set_accept_eula_for_input_data_config_multiple_data_sources ():
381+ """Test with multiple S3DataSources."""
382+ with patch ("sagemaker.estimator.logger" ) as logger :
383+ train_args = {
384+ "InputDataConfig" : [
385+ {"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model1" }}},
386+ {"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model2" }}},
387+ ]
388+ }
389+ accept_eula = True
390+
391+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
392+
393+ # Verify ModelAccessConfig and AcceptEula are set correctly for both data sources
394+ assert train_args ["InputDataConfig" ][0 ]["DataSource" ]["S3DataSource" ][
395+ "ModelAccessConfig"
396+ ] == {"AcceptEula" : True }
397+ assert train_args ["InputDataConfig" ][1 ]["DataSource" ]["S3DataSource" ][
398+ "ModelAccessConfig"
399+ ] == {"AcceptEula" : True }
400+
401+ # Verify logging occurred with correct information
402+ logger .info .assert_called_once ()
403+ args = logger .info .call_args [0 ]
404+ assert args [0 ] == "Accepting EULA for %d S3 data sources: %s"
405+ assert args [1 ] == 2
406+ assert args [2 ] == "s3://bucket/model1, s3://bucket/model2"
407+
408+
409+ def test_set_accept_eula_for_input_data_config_existing_model_access_config ():
410+ """Test when ModelAccessConfig already exists."""
411+ train_args = {
412+ "InputDataConfig" : [
413+ {
414+ "DataSource" : {
415+ "S3DataSource" : {
416+ "S3Uri" : "s3://bucket/model" ,
417+ "ModelAccessConfig" : {"OtherSetting" : "value" },
418+ }
419+ }
420+ }
421+ ]
422+ }
423+ accept_eula = True
424+
425+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
426+
427+ # Verify AcceptEula is added to existing ModelAccessConfig
428+ assert train_args ["InputDataConfig" ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
429+ "OtherSetting" : "value" ,
430+ "AcceptEula" : True ,
431+ }
432+
433+
434+ def test_set_accept_eula_for_input_data_config_missing_s3_data_source ():
435+ """Test when S3DataSource is missing."""
436+ train_args = {"InputDataConfig" : [{"DataSource" : {"OtherDataSource" : {}}}]}
437+ accept_eula = True
438+
439+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
440+
441+ # Verify train_args remains unchanged
442+ assert train_args == {"InputDataConfig" : [{"DataSource" : {"OtherDataSource" : {}}}]}
443+
444+
445+ def test_set_accept_eula_for_input_data_config_missing_data_source ():
446+ """Test when DataSource is missing."""
447+ train_args = {"InputDataConfig" : [{"OtherKey" : {}}]}
448+ accept_eula = True
449+
450+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
451+
452+ # Verify train_args remains unchanged
453+ assert train_args == {"InputDataConfig" : [{"OtherKey" : {}}]}
454+
455+
456+ def test_set_accept_eula_for_input_data_config_mixed_data_sources ():
457+ """Test with a mix of S3DataSource and other data sources."""
458+ with patch ("sagemaker.estimator.logger" ) as logger :
459+ train_args = {
460+ "InputDataConfig" : [
461+ {"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model" }}},
462+ {"DataSource" : {"OtherDataSource" : {}}},
463+ ]
464+ }
465+ accept_eula = True
466+
467+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
468+
469+ # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only
470+ assert train_args ["InputDataConfig" ][0 ]["DataSource" ]["S3DataSource" ][
471+ "ModelAccessConfig"
472+ ] == {"AcceptEula" : True }
473+ assert "ModelAccessConfig" not in train_args ["InputDataConfig" ][1 ]["DataSource" ].get (
474+ "OtherDataSource" , {}
475+ )
476+
477+ # Verify no logging occurred since there's only one S3 data source
478+ logger .info .assert_not_called ()
479+
480+
339481def test_validate_smdistributed_unsupported_image_raises (sagemaker_session ):
340482 # Test unsupported image raises error.
341483 for unsupported_image in DummyFramework .UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM :
0 commit comments