@@ -943,3 +943,75 @@ def test_algorithm_no_required_hyperparameters(session):
943943 train_instance_count = 1 ,
944944 sagemaker_session = session ,
945945 )
946+
947+
948+ def test_algorithm_attach_from_hyperparameter_tuning ():
949+ session = Mock ()
950+ job_name = "training-job-that-is-part-of-a-tuning-job"
951+ algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees"
952+ role_arn = "arn:aws:iam::123412341234:role/SageMakerRole"
953+ instance_count = 1
954+ instance_type = "ml.m4.xlarge"
955+ train_volume_size = 30
956+ input_mode = "File"
957+
958+ session .sagemaker_client .list_tags .return_value = {"Tags" : []}
959+ session .sagemaker_client .describe_algorithm .return_value = DESCRIBE_ALGORITHM_RESPONSE
960+ session .sagemaker_client .describe_training_job .return_value = {
961+ "TrainingJobName" : job_name ,
962+ "TrainingJobArn" : "arn:aws:sagemaker:us-east-2:123412341234:training-job/%s" % job_name ,
963+ "TuningJobArn" : "arn:aws:sagemaker:us-east-2:123412341234:hyper-parameter-tuning-job/%s"
964+ % job_name ,
965+ "ModelArtifacts" : {
966+ "S3ModelArtifacts" : "s3://sagemaker-us-east-2-123412341234/output/model.tar.gz"
967+ },
968+ "TrainingJobOutput" : {
969+ "S3TrainingJobOutput" : "s3://sagemaker-us-east-2-123412341234/output/output.tar.gz"
970+ },
971+ "TrainingJobStatus" : "Succeeded" ,
972+ "HyperParameters" : {
973+ "_tuning_objective_metric" : "validation:accuracy" ,
974+ "max_leaf_nodes" : 1 ,
975+ "free_text_hp1" : "foo" ,
976+ },
977+ "AlgorithmSpecification" : {"AlgorithmName" : algo_arn , "TrainingInputMode" : input_mode },
978+ "MetricDefinitions" : [
979+ {"Name" : "validation:accuracy" , "Regex" : "validation-accuracy: (\\ S+)" }
980+ ],
981+ "RoleArn" : role_arn ,
982+ "InputDataConfig" : [
983+ {
984+ "ChannelName" : "training" ,
985+ "DataSource" : {
986+ "S3DataSource" : {
987+ "S3DataType" : "S3Prefix" ,
988+ "S3Uri" : "s3://sagemaker-us-east-2-123412341234/input/training.csv" ,
989+ "S3DataDistributionType" : "FullyReplicated" ,
990+ }
991+ },
992+ "CompressionType" : "None" ,
993+ "RecordWrapperType" : "None" ,
994+ }
995+ ],
996+ "OutputDataConfig" : {
997+ "KmsKeyId" : "" ,
998+ "S3OutputPath" : "s3://sagemaker-us-east-2-123412341234/output" ,
999+ "RemoveJobNameFromS3OutputPath" : False ,
1000+ },
1001+ "ResourceConfig" : {
1002+ "InstanceType" : instance_type ,
1003+ "InstanceCount" : instance_count ,
1004+ "VolumeSizeInGB" : train_volume_size ,
1005+ },
1006+ "StoppingCondition" : {"MaxRuntimeInSeconds" : 86400 },
1007+ }
1008+
1009+ estimator = AlgorithmEstimator .attach (job_name , sagemaker_session = session )
1010+ assert estimator .hyperparameters () == {"max_leaf_nodes" : 1 , "free_text_hp1" : "foo" }
1011+ assert estimator .algorithm_arn == algo_arn
1012+ assert estimator .role == role_arn
1013+ assert estimator .train_instance_count == instance_count
1014+ assert estimator .train_instance_type == instance_type
1015+ assert estimator .train_volume_size == train_volume_size
1016+ assert estimator .input_mode == input_mode
1017+ assert estimator .sagemaker_session == session
0 commit comments