1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
14+ from copy import deepcopy
1415
1516import logging
1617import json
@@ -3825,6 +3826,12 @@ def test_script_mode_estimator_same_calls_as_framework(
38253826
38263827 model_uri = "s3://someprefix2/models/model.tar.gz"
38273828 training_data_uri = "s3://bucket/mydata"
3829+ hyperparameters = {
3830+ "int_hyperparam" : 1 ,
3831+ "string_hyperparam" : "hello" ,
3832+ "stringified_numeric_hyperparam" : "44" ,
3833+ "float_hyperparam" : 1.234 ,
3834+ }
38283835
38293836 generic_estimator = Estimator (
38303837 entry_point = SCRIPT_PATH ,
@@ -3838,6 +3845,7 @@ def test_script_mode_estimator_same_calls_as_framework(
38383845 model_uri = model_uri ,
38393846 dependencies = [],
38403847 debugger_hook_config = {},
3848+ hyperparameters = deepcopy (hyperparameters ),
38413849 )
38423850 generic_estimator .fit (training_data_uri )
38433851
@@ -3858,6 +3866,7 @@ def test_script_mode_estimator_same_calls_as_framework(
38583866 model_uri = model_uri ,
38593867 dependencies = [],
38603868 debugger_hook_config = {},
3869+ hyperparameters = deepcopy (hyperparameters ),
38613870 )
38623871 framework_estimator .fit (training_data_uri )
38633872
@@ -4394,3 +4403,51 @@ def test_insert_invalid_source_code_args():
43944403 assert (
43954404 "The entry_point should not be a pipeline variable " "when source_dir is a local path"
43964405 ) in str (err .value )
4406+
4407+
4408+ @patch ("time.time" , return_value = TIME )
4409+ @patch ("sagemaker.estimator.tar_and_upload_dir" )
4410+ @patch ("sagemaker.model.Model._upload_code" )
4411+ def test_script_mode_estimator_escapes_hyperparameters_as_json (
4412+ patched_upload_code , patched_tar_and_upload_dir , sagemaker_session
4413+ ):
4414+ patched_tar_and_upload_dir .return_value = UploadedCode (
4415+ s3_prefix = "s3://%s/%s" % ("bucket" , "key" ), script_name = "script_name"
4416+ )
4417+ sagemaker_session .boto_region_name = REGION
4418+
4419+ instance_type = "ml.p2.xlarge"
4420+ instance_count = 1
4421+
4422+ training_data_uri = "s3://bucket/mydata"
4423+
4424+ jumpstart_source_dir = f"s3://{ list (JUMPSTART_BUCKET_NAME_SET )[0 ]} /source_dirs/source.tar.gz"
4425+
4426+ hyperparameters = {
4427+ "int_hyperparam" : 1 ,
4428+ "string_hyperparam" : "hello" ,
4429+ "stringified_numeric_hyperparam" : "44" ,
4430+ "float_hyperparam" : 1.234 ,
4431+ }
4432+
4433+ generic_estimator = Estimator (
4434+ entry_point = SCRIPT_PATH ,
4435+ role = ROLE ,
4436+ region = REGION ,
4437+ sagemaker_session = sagemaker_session ,
4438+ instance_count = instance_count ,
4439+ instance_type = instance_type ,
4440+ source_dir = jumpstart_source_dir ,
4441+ image_uri = IMAGE_URI ,
4442+ model_uri = MODEL_DATA ,
4443+ hyperparameters = hyperparameters ,
4444+ )
4445+ generic_estimator .fit (training_data_uri )
4446+
4447+ formatted_hyperparams = EstimatorBase ._json_encode_hyperparameters (hyperparameters )
4448+
4449+ assert (
4450+ set (formatted_hyperparams .items ())
4451+ - set (sagemaker_session .train .call_args_list [0 ][1 ]["hyperparameters" ].items ())
4452+ == set ()
4453+ )
0 commit comments