Skip to content

Commit 4ea9650

Browse files
Copilotthinkall
andauthored
Fix nested dictionary merge in SearchThread losing sampled hyperparameters (#1494)
* Initial plan * Add recursive dict update to fix nested config merge Co-authored-by: thinkall <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: thinkall <[email protected]> Co-authored-by: Li Jiang <[email protected]>
1 parent fa1a32a commit 4ea9650

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

flaml/tune/searcher/search_thread.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,31 @@
2525
logger = logging.getLogger(__name__)
2626

2727

28+
def _recursive_dict_update(target: Dict, source: Dict) -> None:
29+
"""Recursively update target dictionary with source dictionary.
30+
31+
Unlike dict.update(), this function merges nested dictionaries instead of
32+
replacing them entirely. This is crucial for configurations with nested
33+
structures (e.g., XGBoost params).
34+
35+
Args:
36+
target: The dictionary to be updated (modified in place).
37+
source: The dictionary containing values to merge into target.
38+
39+
Example:
40+
>>> target = {'params': {'eta': 0.1, 'max_depth': 3}}
41+
>>> source = {'params': {'verbosity': 0}}
42+
>>> _recursive_dict_update(target, source)
43+
>>> target
44+
{'params': {'eta': 0.1, 'max_depth': 3, 'verbosity': 0}}
45+
"""
46+
for key, value in source.items():
47+
if isinstance(value, dict) and key in target and isinstance(target[key], dict):
48+
_recursive_dict_update(target[key], value)
49+
else:
50+
target[key] = value
51+
52+
2853
class SearchThread:
2954
"""Class of global or local search thread."""
3055

@@ -65,7 +90,7 @@ def suggest(self, trial_id: str) -> Optional[Dict]:
6590
try:
6691
config = self._search_alg.suggest(trial_id)
6792
if isinstance(self._search_alg._space, dict):
68-
config.update(self._const)
93+
_recursive_dict_update(config, self._const)
6994
else:
7095
# define by run
7196
config, self.space = unflatten_hierarchical(config, self._space)

test/tune/test_search_thread.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Tests for SearchThread nested dictionary update fix."""
2+
3+
import pytest
4+
5+
from flaml.tune.searcher.search_thread import _recursive_dict_update
6+
7+
8+
def test_recursive_dict_update_simple():
9+
"""Test simple non-nested dictionary update."""
10+
target = {"a": 1, "b": 2}
11+
source = {"c": 3}
12+
_recursive_dict_update(target, source)
13+
assert target == {"a": 1, "b": 2, "c": 3}
14+
15+
16+
def test_recursive_dict_update_override():
17+
"""Test that source values override target values for non-dict values."""
18+
target = {"a": 1, "b": 2}
19+
source = {"b": 3}
20+
_recursive_dict_update(target, source)
21+
assert target == {"a": 1, "b": 3}
22+
23+
24+
def test_recursive_dict_update_nested():
25+
"""Test nested dictionary merge (the main use case for XGBoost params)."""
26+
target = {
27+
"num_boost_round": 10,
28+
"params": {
29+
"max_depth": 12,
30+
"eta": 0.020168455186106736,
31+
"min_child_weight": 1.4504723523894132,
32+
"scale_pos_weight": 3.794258636185337,
33+
"gamma": 0.4985070123025904,
34+
},
35+
}
36+
source = {
37+
"params": {
38+
"verbosity": 3,
39+
"booster": "gbtree",
40+
"eval_metric": "auc",
41+
"tree_method": "hist",
42+
"objective": "binary:logistic",
43+
}
44+
}
45+
_recursive_dict_update(target, source)
46+
47+
# Check that sampled params are preserved
48+
assert target["params"]["max_depth"] == 12
49+
assert target["params"]["eta"] == 0.020168455186106736
50+
assert target["params"]["min_child_weight"] == 1.4504723523894132
51+
assert target["params"]["scale_pos_weight"] == 3.794258636185337
52+
assert target["params"]["gamma"] == 0.4985070123025904
53+
54+
# Check that const params are added
55+
assert target["params"]["verbosity"] == 3
56+
assert target["params"]["booster"] == "gbtree"
57+
assert target["params"]["eval_metric"] == "auc"
58+
assert target["params"]["tree_method"] == "hist"
59+
assert target["params"]["objective"] == "binary:logistic"
60+
61+
# Check top-level param is preserved
62+
assert target["num_boost_round"] == 10
63+
64+
65+
def test_recursive_dict_update_deeply_nested():
66+
"""Test deeply nested dictionary merge."""
67+
target = {"a": {"b": {"c": 1, "d": 2}}}
68+
source = {"a": {"b": {"e": 3}}}
69+
_recursive_dict_update(target, source)
70+
assert target == {"a": {"b": {"c": 1, "d": 2, "e": 3}}}
71+
72+
73+
def test_recursive_dict_update_mixed_types():
74+
"""Test that non-dict values in source replace dict values in target."""
75+
target = {"a": {"b": 1}}
76+
source = {"a": 2}
77+
_recursive_dict_update(target, source)
78+
assert target == {"a": 2}
79+
80+
81+
def test_recursive_dict_update_empty_dicts():
82+
"""Test with empty dictionaries."""
83+
target = {}
84+
source = {"a": 1}
85+
_recursive_dict_update(target, source)
86+
assert target == {"a": 1}
87+
88+
target = {"a": 1}
89+
source = {}
90+
_recursive_dict_update(target, source)
91+
assert target == {"a": 1}
92+
93+
94+
def test_recursive_dict_update_none_values():
95+
"""Test that None values are properly handled."""
96+
target = {"a": 1, "b": None}
97+
source = {"b": 2, "c": None}
98+
_recursive_dict_update(target, source)
99+
assert target == {"a": 1, "b": 2, "c": None}

0 commit comments

Comments
 (0)