Skip to content

Commit 51ac909

Browse files
committed
Fix test_nested_run
1 parent a5191d9 commit 51ac909

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

flaml/fabric/mlflow.py

Lines changed: 29 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -896,6 +896,7 @@ def adopt_children(self, result=None):
896896
),
897897
)
898898
self.child_counter = 0
899+
num_infos = len(self.infos)
899900

900901
# From latest to earliest, remove duplicate cross-validation runs
901902
_exist_child_run_params = [] # for deduplication of cross-validation child runs
@@ -960,22 +961,37 @@ def adopt_children(self, result=None):
960961
)
961962
self.mlflow_client.set_tag(child_run_id, "flaml.child_counter", self.child_counter)
962963

963-
# merge autolog child run and corresponding manual run
964-
flaml_info = self.infos[self.child_counter]
965-
child_run = self.mlflow_client.get_run(child_run_id)
966-
self._log_info_to_run(flaml_info, child_run_id, log_params=False)
967-
968-
if self.experiment_type == "automl":
969-
if "learner" not in child_run.data.params:
970-
self.mlflow_client.log_param(child_run_id, "learner", flaml_info["params"]["learner"])
971-
if "sample_size" not in child_run.data.params:
972-
self.mlflow_client.log_param(
973-
child_run_id, "sample_size", flaml_info["params"]["sample_size"]
974-
)
964+
# Merge autolog child run and corresponding FLAML trial info (if available).
965+
# In nested scenarios (e.g., Tune -> AutoML -> MLflow autolog), MLflow can create
966+
# more child runs than the number of FLAML trials recorded in self.infos.
967+
# TODO: need more tests in nested scenarios.
968+
flaml_info = None
969+
child_run = None
970+
if self.child_counter < num_infos:
971+
flaml_info = self.infos[self.child_counter]
972+
child_run = self.mlflow_client.get_run(child_run_id)
973+
self._log_info_to_run(flaml_info, child_run_id, log_params=False)
974+
975+
if self.experiment_type == "automl":
976+
if "learner" not in child_run.data.params:
977+
self.mlflow_client.log_param(child_run_id, "learner", flaml_info["params"]["learner"])
978+
if "sample_size" not in child_run.data.params:
979+
self.mlflow_client.log_param(
980+
child_run_id, "sample_size", flaml_info["params"]["sample_size"]
981+
)
982+
else:
983+
logger.debug(
984+
"No corresponding FLAML info for MLflow child run %s (child_counter=%s, infos=%s); skipping merge.",
985+
child_run_id,
986+
self.child_counter,
987+
num_infos,
988+
)
975989

976-
if self.child_counter == best_iteration:
990+
if flaml_info is not None and self.child_counter == best_iteration:
977991
self.mlflow_client.set_tag(child_run_id, "flaml.best_run", True)
978992
if result is not None:
993+
if child_run is None:
994+
child_run = self.mlflow_client.get_run(child_run_id)
979995
result.best_run_id = child_run_id
980996
result.best_run_name = child_run.info.run_name
981997
self.best_run_id = child_run_id

test/tune/test_tune.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ def _easy_objective(config):
5353

5454

5555
def test_nested_run():
56+
"""
57+
nested tuning example: Tune -> AutoML -> MLflow autolog
58+
mlflow logging is complicated in nested tuning. It's better to turn off mlflow autologging to avoid
59+
potential issues in FLAML's mlflow_integration.adopt_children() function.
60+
"""
5661
from flaml import AutoML, tune
5762

5863
data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)

0 commit comments

Comments
 (0)