File tree Expand file tree Collapse file tree 3 files changed +34
-12
lines changed Expand file tree Collapse file tree 3 files changed +34
-12
lines changed Original file line number Diff line number Diff line change @@ -6,6 +6,7 @@ CHANGELOG
66======
77
88* feature: ``PipelineModel ``: Create a Transformer from a PipelineModel
9+ * bug-fix: ``AlgorithmEstimator ``: Make SupportedHyperParameters optional
910
10111.18.4
1112======
Original file line number Diff line number Diff line change @@ -375,20 +375,23 @@ def _validate_and_set_default_hyperparameters(self):
375375 raise ValueError ('Required hyperparameter: %s is not set' % name )
376376
377377 def _parse_hyperparameters (self ):
378- hyperparameters = self .algorithm_spec ['TrainingSpecification' ]['SupportedHyperParameters' ]
379378 definitions = {}
380- for h in hyperparameters :
381- parameter_type = h ['Type' ]
382- name = h ['Name' ]
383- parameter_class , parameter_range = self ._hyperparameter_range_and_class (
384- parameter_type , h
385- )
386379
387- definitions [name ] = {'spec' : h }
388- if parameter_range :
389- definitions [name ]['range' ] = parameter_range
390- if parameter_class :
391- definitions [name ]['class' ] = parameter_class
380+ training_spec = self .algorithm_spec ['TrainingSpecification' ]
381+ if 'SupportedHyperParameters' in training_spec :
382+ hyperparameters = training_spec ['SupportedHyperParameters' ]
383+ for h in hyperparameters :
384+ parameter_type = h ['Type' ]
385+ name = h ['Name' ]
386+ parameter_class , parameter_range = self ._hyperparameter_range_and_class (
387+ parameter_type , h
388+ )
389+
390+ definitions [name ] = {'spec' : h }
391+ if parameter_range :
392+ definitions [name ]['range' ] = parameter_range
393+ if parameter_class :
394+ definitions [name ]['class' ] = parameter_class
392395
393396 return definitions
394397
Original file line number Diff line number Diff line change @@ -913,3 +913,21 @@ def test_algorithm_encrypt_inter_container_traffic(sagemaker_session):
913913
914914 encrypt_inter_container_traffic = estimator .encrypt_inter_container_traffic
915915 assert encrypt_inter_container_traffic is True
916+
917+
918+ def test_algorithm_no_required_hyperparameters (sagemaker_session ):
919+ some_algo = copy .deepcopy (DESCRIBE_ALGORITHM_RESPONSE )
920+ del some_algo ['TrainingSpecification' ]['SupportedHyperParameters' ]
921+
922+ sagemaker_session .sagemaker_client .describe_algorithm = Mock (return_value = some_algo )
923+
924+ # Calling AlgorithmEstimator() with unset required hyperparameters
925+ # should fail if they are required.
926+ # Pass training and hyperparameters channels. This should work
927+ assert AlgorithmEstimator (
928+ algorithm_arn = 'arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees' ,
929+ role = 'SageMakerRole' ,
930+ train_instance_type = 'ml.m4.2xlarge' ,
931+ train_instance_count = 1 ,
932+ sagemaker_session = sagemaker_session ,
933+ )
You can’t perform that action at this time.
0 commit comments