@@ -262,7 +262,7 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
262262 py_version = "py3" ,
263263 framework_version = "1.15.2" ,
264264 role = "{{ role }}" ,
265- instance_count = "{{ instance_count }}" ,
265+ instance_count = 1 ,
266266 instance_type = "ml.c4.2xlarge" ,
267267 volume_size = "{{ volume_size }}" ,
268268 volume_kms_key = "{{ volume_kms_key }}" ,
@@ -276,6 +276,8 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
276276 security_group_ids = ["{{ security_group_ids }}" ],
277277 metric_definitions = [{"Name" : "{{ name }}" , "Regex" : "{{ regex }}" }],
278278 sagemaker_session = sagemaker_session ,
279+ checkpoint_local_path = "{{ checkpoint_local_path }}" ,
280+ checkpoint_s3_uri = "{{ checkpoint_s3_uri }}" ,
279281 )
280282
281283 data = "{{ training_data }}"
@@ -294,7 +296,7 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
294296 "TrainingJobName" : "{{ base_job_name }}-%s" % TIME_STAMP ,
295297 "StoppingCondition" : {"MaxRuntimeInSeconds" : "{{ max_run }}" },
296298 "ResourceConfig" : {
297- "InstanceCount" : "{{ instance_count }}" ,
299+ "InstanceCount" : 1 ,
298300 "InstanceType" : "ml.c4.2xlarge" ,
299301 "VolumeSizeInGB" : "{{ volume_size }}" ,
300302 "VolumeKmsKeyId" : "{{ volume_kms_key }}" ,
@@ -338,6 +340,10 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio
338340 }
339341 ]
340342 },
343+ "CheckpointConfig" : {
344+ "LocalPath" : "{{ checkpoint_local_path }}" ,
345+ "S3Uri" : "{{ checkpoint_s3_uri }}" ,
346+ },
341347 }
342348 assert config == expected_config
343349
0 commit comments