Skip to content

Commit a21fdfe

Browse files
committed
Simplify how we handle the deprecated training.param_grid attribute
This converts from training.param_grid to the equivalent training.model_parameter_search representation, which reduces the complexity of the surrounding function a bit.
1 parent 5152468 commit a21fdfe

File tree

1 file changed

+51
-42
lines changed

1 file changed

+51
-42
lines changed

hlink/linking/model_exploration/link_step_train_test_models.py

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,18 @@ def _choose_randomized_parameters(
947947
return parameter_choices
948948

949949

950+
def _handle_param_grid_attribute(training_settings: dict[str, Any]) -> dict[str, Any]:
951+
"""
952+
Given the training settings, handle the param_grid attribute and return the
953+
equivalent model_parameter_search table. This falls back to the default of
954+
param_grid = false (a.k.a. search strategy explicit) when param_grid is not
955+
present.
956+
"""
957+
param_grid = training_settings.get("param_grid", False)
958+
strategy = "grid" if param_grid else "explicit"
959+
return {"strategy": strategy}
960+
961+
950962
def _get_model_parameters(training_settings: dict[str, Any]) -> list[dict[str, Any]]:
951963
if "param_grid" in training_settings:
952964
print(
@@ -965,53 +977,50 @@ def _get_model_parameters(training_settings: dict[str, Any]) -> list[dict[str, A
965977
)
966978

967979
model_parameters = training_settings["model_parameters"]
968-
model_parameter_search = training_settings.get("model_parameter_search")
980+
fallback_parameter_search = _handle_param_grid_attribute(training_settings)
981+
model_parameter_search = training_settings.get(
982+
"model_parameter_search", fallback_parameter_search
983+
)
969984
seed = training_settings.get("seed")
970-
use_param_grid = training_settings.get("param_grid", False)
971985

972986
if model_parameters == []:
973987
raise ValueError(
974988
"model_parameters is empty, so there are no models to evaluate"
975989
)
976990

977-
if model_parameter_search is not None:
978-
strategy = model_parameter_search["strategy"]
979-
if strategy == "explicit":
980-
return model_parameters
981-
elif strategy == "grid":
982-
return _custom_param_grid_builder(model_parameters)
983-
elif strategy == "randomized":
984-
num_samples = model_parameter_search["num_samples"]
985-
rng = random.Random(seed)
986-
987-
return_parameters = []
988-
# These keys are special and should not be sampled or modified. All
989-
# other keys are hyper-parameters to the model and should be sampled.
990-
frozen_keys = {"type", "threshold", "threshold_ratio"}
991-
for _ in range(num_samples):
992-
parameter_spec = rng.choice(model_parameters)
993-
sample_parameters = {
994-
key: value
995-
for (key, value) in parameter_spec.items()
996-
if key not in frozen_keys
997-
}
998-
frozen_parameters = {
999-
key: value
1000-
for (key, value) in parameter_spec.items()
1001-
if key in frozen_keys
1002-
}
1003-
1004-
randomized = _choose_randomized_parameters(rng, sample_parameters)
1005-
result = {**frozen_parameters, **randomized}
1006-
return_parameters.append(result)
1007-
1008-
return return_parameters
1009-
else:
1010-
raise ValueError(
1011-
f"Unknown model_parameter_search strategy '{strategy}'. "
1012-
"Please choose one of 'explicit', 'grid', or 'randomized'."
1013-
)
1014-
elif use_param_grid:
991+
strategy = model_parameter_search["strategy"]
992+
if strategy == "explicit":
993+
return model_parameters
994+
elif strategy == "grid":
1015995
return _custom_param_grid_builder(model_parameters)
1016-
1017-
return model_parameters
996+
elif strategy == "randomized":
997+
num_samples = model_parameter_search["num_samples"]
998+
rng = random.Random(seed)
999+
1000+
return_parameters = []
1001+
# These keys are special and should not be sampled or modified. All
1002+
# other keys are hyper-parameters to the model and should be sampled.
1003+
frozen_keys = {"type", "threshold", "threshold_ratio"}
1004+
for _ in range(num_samples):
1005+
parameter_spec = rng.choice(model_parameters)
1006+
sample_parameters = {
1007+
key: value
1008+
for (key, value) in parameter_spec.items()
1009+
if key not in frozen_keys
1010+
}
1011+
frozen_parameters = {
1012+
key: value
1013+
for (key, value) in parameter_spec.items()
1014+
if key in frozen_keys
1015+
}
1016+
1017+
randomized = _choose_randomized_parameters(rng, sample_parameters)
1018+
result = {**frozen_parameters, **randomized}
1019+
return_parameters.append(result)
1020+
1021+
return return_parameters
1022+
else:
1023+
raise ValueError(
1024+
f"Unknown model_parameter_search strategy '{strategy}'. "
1025+
"Please choose one of 'explicit', 'grid', or 'randomized'."
1026+
)

0 commit comments

Comments
 (0)