1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
1414
15+ from mock import Mock
16+
1517from sagemaker .model_monitor import DataCaptureConfig
1618
1719DEFAULT_ENABLE_CAPTURE = True
1820DEFAULT_SAMPLING_PERCENTAGE = 20
1921DEFAULT_BUCKET_NAME = "default-bucket"
20- DEFAULT_DESTINATION_S3_URI = "s3://" + DEFAULT_BUCKET_NAME + " /model-monitor/data-capture"
22+ DEFAULT_DESTINATION_S3_URI = "s3://{} /model-monitor/data-capture" . format ( DEFAULT_BUCKET_NAME )
2123DEFAULT_KMS_KEY_ID = None
2224DEFAULT_CAPTURE_MODES = ["REQUEST" , "RESPONSE" ]
2325DEFAULT_CSV_CONTENT_TYPES = ["text/csv" ]
3335NON_DEFAULT_JSON_CONTENT_TYPES = ["custom/json-format" ]
3436
3537
36- def test_to_request_dict_returns_correct_params_when_non_defaults_provided ():
38+ def test_init_when_non_defaults_provided ():
3739 data_capture_config = DataCaptureConfig (
3840 enable_capture = NON_DEFAULT_ENABLE_CAPTURE ,
3941 sampling_percentage = NON_DEFAULT_SAMPLING_PERCENTAGE ,
@@ -51,9 +53,12 @@ def test_to_request_dict_returns_correct_params_when_non_defaults_provided():
5153 assert data_capture_config .json_content_types == NON_DEFAULT_JSON_CONTENT_TYPES
5254
5355
54- def test_to_request_dict_returns_correct_default_params_when_optionals_not_provided ():
56+ def test_init_when_optionals_not_provided ():
57+ sagemaker_session = Mock ()
58+ sagemaker_session .default_bucket .return_value = DEFAULT_BUCKET_NAME
59+
5560 data_capture_config = DataCaptureConfig (
56- enable_capture = DEFAULT_ENABLE_CAPTURE , destination_s3_uri = DEFAULT_DESTINATION_S3_URI
61+ enable_capture = DEFAULT_ENABLE_CAPTURE , sagemaker_session = sagemaker_session
5762 )
5863
5964 assert data_capture_config .enable_capture == DEFAULT_ENABLE_CAPTURE
0 commit comments