@@ -947,6 +947,18 @@ def _choose_randomized_parameters(
947947 return parameter_choices
948948
949949
950+ def _handle_param_grid_attribute (training_settings : dict [str , Any ]) -> dict [str , Any ]:
951+ """
952+ Given the training settings, handle the param_grid attribute and return the
953+ equivalent model_parameter_search table. This falls back to the default of
954+ param_grid = false (a.k.a. search strategy explicit) when param_grid is not
955+ present.
956+ """
957+ param_grid = training_settings .get ("param_grid" , False )
958+ strategy = "grid" if param_grid else "explicit"
959+ return {"strategy" : strategy }
960+
961+
950962def _get_model_parameters (training_settings : dict [str , Any ]) -> list [dict [str , Any ]]:
951963 if "param_grid" in training_settings :
952964 print (
@@ -965,53 +977,50 @@ def _get_model_parameters(training_settings: dict[str, Any]) -> list[dict[str, A
965977 )
966978
967979 model_parameters = training_settings ["model_parameters" ]
968- model_parameter_search = training_settings .get ("model_parameter_search" )
980+ fallback_parameter_search = _handle_param_grid_attribute (training_settings )
981+ model_parameter_search = training_settings .get (
982+ "model_parameter_search" , fallback_parameter_search
983+ )
969984 seed = training_settings .get ("seed" )
970- use_param_grid = training_settings .get ("param_grid" , False )
971985
972986 if model_parameters == []:
973987 raise ValueError (
974988 "model_parameters is empty, so there are no models to evaluate"
975989 )
976990
977- if model_parameter_search is not None :
978- strategy = model_parameter_search ["strategy" ]
979- if strategy == "explicit" :
980- return model_parameters
981- elif strategy == "grid" :
982- return _custom_param_grid_builder (model_parameters )
983- elif strategy == "randomized" :
984- num_samples = model_parameter_search ["num_samples" ]
985- rng = random .Random (seed )
986-
987- return_parameters = []
988- # These keys are special and should not be sampled or modified. All
989- # other keys are hyper-parameters to the model and should be sampled.
990- frozen_keys = {"type" , "threshold" , "threshold_ratio" }
991- for _ in range (num_samples ):
992- parameter_spec = rng .choice (model_parameters )
993- sample_parameters = {
994- key : value
995- for (key , value ) in parameter_spec .items ()
996- if key not in frozen_keys
997- }
998- frozen_parameters = {
999- key : value
1000- for (key , value ) in parameter_spec .items ()
1001- if key in frozen_keys
1002- }
1003-
1004- randomized = _choose_randomized_parameters (rng , sample_parameters )
1005- result = {** frozen_parameters , ** randomized }
1006- return_parameters .append (result )
1007-
1008- return return_parameters
1009- else :
1010- raise ValueError (
1011- f"Unknown model_parameter_search strategy '{ strategy } '. "
1012- "Please choose one of 'explicit', 'grid', or 'randomized'."
1013- )
1014- elif use_param_grid :
991+ strategy = model_parameter_search ["strategy" ]
992+ if strategy == "explicit" :
993+ return model_parameters
994+ elif strategy == "grid" :
1015995 return _custom_param_grid_builder (model_parameters )
1016-
1017- return model_parameters
996+ elif strategy == "randomized" :
997+ num_samples = model_parameter_search ["num_samples" ]
998+ rng = random .Random (seed )
999+
1000+ return_parameters = []
1001+ # These keys are special and should not be sampled or modified. All
1002+ # other keys are hyper-parameters to the model and should be sampled.
1003+ frozen_keys = {"type" , "threshold" , "threshold_ratio" }
1004+ for _ in range (num_samples ):
1005+ parameter_spec = rng .choice (model_parameters )
1006+ sample_parameters = {
1007+ key : value
1008+ for (key , value ) in parameter_spec .items ()
1009+ if key not in frozen_keys
1010+ }
1011+ frozen_parameters = {
1012+ key : value
1013+ for (key , value ) in parameter_spec .items ()
1014+ if key in frozen_keys
1015+ }
1016+
1017+ randomized = _choose_randomized_parameters (rng , sample_parameters )
1018+ result = {** frozen_parameters , ** randomized }
1019+ return_parameters .append (result )
1020+
1021+ return return_parameters
1022+ else :
1023+ raise ValueError (
1024+ f"Unknown model_parameter_search strategy '{ strategy } '. "
1025+ "Please choose one of 'explicit', 'grid', or 'randomized'."
1026+ )
0 commit comments