Skip to content

Commit 06ba314

Browse files
qew21you-n-gXupeteryangms
authored
feat: merge selectively (#888)
* chore: avoid incorporate changes best as sota merge hypothesis fix: max_retrieve_num after decision chore: select last experiments and feedbacks * add the set_current_selection before the exp_gen when merging add trace.NEW_ROOT fix: no scen_prob_multiplier fix: use regex with timeout chore: hypothesis_rank with selected_idx chore: define is_parent in proposal chore: rename collect_all_ancestors to get_parent_exps --------- Co-authored-by: you-n-g <[email protected]> Co-authored-by: Young <[email protected]> Co-authored-by: Xu <[email protected]> Co-authored-by: Xu Yang <[email protected]>
1 parent 837fff2 commit 06ba314

File tree

7 files changed

+411
-154
lines changed

7 files changed

+411
-154
lines changed

rdagent/core/proposal.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@
33
from __future__ import annotations
44

55
from abc import ABC, abstractmethod
6-
from typing import Generic, TypeVar
6+
from typing import Generic, List, Tuple, TypeVar
77

88
from rdagent.core.evaluation import Feedback
99
from rdagent.core.experiment import ASpecificExp, Experiment
1010
from rdagent.core.knowledge_base import KnowledgeBase
1111
from rdagent.core.scenario import Scenario
1212

13-
# class data_ana: XXX
14-
1513

1614
class Hypothesis:
1715
"""
@@ -105,6 +103,7 @@ def __str__(self) -> str:
105103

106104
class Trace(Generic[ASpecificScen, ASpecificKB]):
107105
NodeType = tuple[Experiment, ExperimentFeedback] # Define NodeType as a new type representing the tuple
106+
NEW_ROOT: Tuple = ()
108107

109108
def __init__(self, scen: ASpecificScen, knowledge_base: ASpecificKB | None = None) -> None:
110109
self.scen: ASpecificScen = scen
@@ -116,6 +115,7 @@ def __init__(self, scen: ASpecificScen, knowledge_base: ASpecificKB | None = Non
116115

117116
# TODO: self.hist is 2-tuple now, remove hypothesis from it, change old code for this later.
118117
self.knowledge_base: ASpecificKB | None = knowledge_base
118+
self.current_selection: tuple[int, ...] = (-1,)
119119

120120
def get_sota_hypothesis_and_experiment(self) -> tuple[Hypothesis | None, Experiment | None]:
121121
"""Access the last experiment result, sub-task, and the corresponding hypothesis."""
@@ -126,6 +126,77 @@ def get_sota_hypothesis_and_experiment(self) -> tuple[Hypothesis | None, Experim
126126

127127
return None, None
128128

129+
def is_selection_new_tree(self, selection: tuple[int, ...] | None = None) -> bool:
130+
"""
131+
Check if the current trace is a new tree.
132+
- selection maybe (-1,) when the dag_parent is empty.
133+
"""
134+
if selection is None:
135+
selection = self.get_current_selection()
136+
137+
return selection == self.NEW_ROOT or len(self.dag_parent) == 0
138+
139+
def get_current_selection(self) -> tuple[int, ...]:
140+
return self.current_selection
141+
142+
def set_current_selection(self, selection: tuple[int, ...]) -> None:
143+
self.current_selection = selection
144+
145+
def get_parent_exps(
146+
self,
147+
selection: tuple[int, ...] | None = None,
148+
) -> list[Trace.NodeType]:
149+
"""
150+
Collect all ancestors of the given selection.
151+
The return list follows the order of [root->...->parent->current_node].
152+
"""
153+
if selection is None:
154+
selection = self.get_current_selection()
155+
156+
if self.is_selection_new_tree(selection):
157+
return []
158+
159+
return [self.hist[i] for i in self.get_parents(selection[0])]
160+
161+
def exp2idx(self, exp: Experiment | List[Experiment]) -> int | List[int] | None:
162+
if isinstance(exp, list):
163+
exps: List[Experiment] = exp
164+
165+
# keep the order
166+
exp_to_index: dict[Experiment, int] = {_exp: i for i, (_exp, _) in enumerate(self.hist)}
167+
return [exp_to_index[_exp] for _exp in exps]
168+
else:
169+
for i, (_exp, _) in enumerate(self.hist):
170+
if _exp == exp:
171+
return i
172+
return None
173+
174+
def idx2exp(self, idx: int | List[int]) -> Experiment | List[Experiment]:
175+
if isinstance(idx, list):
176+
idxs: List[int] = idx
177+
return [self.hist[_idx][0] for _idx in idxs]
178+
else:
179+
return self.hist[idx][0]
180+
181+
def is_parent(self, parent_idx: int, child_idx: int) -> bool:
182+
ancestors = self.get_parents(child_idx)
183+
return parent_idx in ancestors
184+
185+
def get_parents(self, child_idx: int) -> List[int]:
186+
if self.is_selection_new_tree((child_idx,)):
187+
return []
188+
189+
ancestors: List[int] = []
190+
curr = child_idx
191+
while True:
192+
ancestors.insert(0, curr)
193+
parent_tuple = self.dag_parent[curr]
194+
if not parent_tuple or parent_tuple[0] == curr:
195+
break
196+
curr = parent_tuple[0]
197+
198+
return ancestors
199+
129200

130201
class CheckpointSelector:
131202
"""

rdagent/scenarios/data_science/proposal/exp_gen/base.py

Lines changed: 9 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import abstractmethod
2-
from typing import Literal
2+
from typing import List, Literal
33

44
from rdagent.app.data_science.conf import DS_RD_SETTING
55
from rdagent.core.evolving_framework import KnowledgeBase
@@ -61,21 +61,13 @@ def __init__(self, scen: DataScienceScen, knowledge_base: KnowledgeBase | None =
6161

6262
self.knowledge_base = knowledge_base
6363

64-
self.current_selection: tuple[int, ...] = (-1,)
65-
6664
self.sota_exp_to_submit: DSExperiment | None = None # grab the global best exp to submit
6765

6866
COMPLETE_ORDER = ("DataLoadSpec", "FeatureEng", "Model", "Ensemble", "Workflow")
6967

7068
def set_sota_exp_to_submit(self, exp: DSExperiment) -> None:
7169
self.sota_exp_to_submit = exp
7270

73-
def get_current_selection(self) -> tuple[int, ...]:
74-
return self.current_selection
75-
76-
def set_current_selection(self, selection: tuple[int, ...]) -> None:
77-
self.current_selection = selection
78-
7971
@property
8072
def sub_trace_count(self) -> int:
8173
return len(self.get_leaves())
@@ -144,50 +136,11 @@ def retrieve_search_list(
144136
return self.hist
145137

146138
elif search_type == "ancestors":
147-
148-
if selection is None:
149-
selection = self.get_current_selection()
150-
151-
if len(selection) == 0:
152-
# selection is (), which means we switch to a new trace
153-
return []
154-
155-
return self.collect_all_ancestors(selection)
139+
return self.get_parent_exps(selection)
156140

157141
else:
158142
raise ValueError(f"Invalid search type: {search_type}")
159143

160-
def collect_all_ancestors(
161-
self,
162-
selection: tuple[int, ...] | None = None,
163-
) -> list[tuple[DSExperiment, ExperimentFeedback]]:
164-
"""
165-
Collect all ancestors of the given selection.
166-
The return list follows the order of [root->...->parent->current_node].
167-
"""
168-
if selection is None:
169-
selection = self.get_current_selection()
170-
171-
if len(self.dag_parent) == 0:
172-
return []
173-
174-
else:
175-
all_ancestors = []
176-
177-
# start from the latest selection
178-
current_node_idx = selection[0]
179-
180-
# add the current node to the list
181-
all_ancestors.insert(0, self.hist[current_node_idx])
182-
183-
parent_idx = self.dag_parent[current_node_idx]
184-
185-
while len(parent_idx) > 0:
186-
all_ancestors.insert(0, self.hist[parent_idx[0]])
187-
parent_idx = self.dag_parent[parent_idx[0]]
188-
189-
return all_ancestors
190-
191144
def next_incomplete_component(
192145
self,
193146
search_type: Literal["all", "ancestors"] = "ancestors",
@@ -226,10 +179,6 @@ def experiment_and_feedback_list_after_init(
226179
Retrieve a list of experiments and feedbacks based on the return_type.
227180
"""
228181
search_list = self.retrieve_search_list(search_type, selection=selection)
229-
if max_retrieve_num is not None and len(search_list) > 0:
230-
retrieve_num = min(max_retrieve_num, len(search_list))
231-
search_list = search_list[:retrieve_num]
232-
233182
final_component = self.COMPLETE_ORDER[-1]
234183
has_final_component = True if DS_RD_SETTING.coder_on_whole_pipeline else False
235184
SOTA_exp_and_feedback_list = []
@@ -243,6 +192,13 @@ def experiment_and_feedback_list_after_init(
243192
failed_exp_and_feedback_list.append((exp, fb))
244193
if exp.hypothesis.component == final_component and fb:
245194
has_final_component = True
195+
if max_retrieve_num is not None and (SOTA_exp_and_feedback_list or failed_exp_and_feedback_list):
196+
SOTA_exp_and_feedback_list = SOTA_exp_and_feedback_list[
197+
-min(max_retrieve_num, len(SOTA_exp_and_feedback_list)) :
198+
]
199+
failed_exp_and_feedback_list = failed_exp_and_feedback_list[
200+
-min(max_retrieve_num, len(failed_exp_and_feedback_list)) :
201+
]
246202
if return_type == "all":
247203
return SOTA_exp_and_feedback_list + failed_exp_and_feedback_list
248204
elif return_type == "failed":

rdagent/scenarios/data_science/proposal/exp_gen/ckp_select.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
5656
5757
Returns:
5858
(-1,): Continue with the current latest trial
59-
(): Start a new sub-trace if max trace limit not reached
59+
trace.NEW_ROOT: Start a new sub-trace if max trace limit not reached
6060
"""
6161

6262
if self.time_limit_pre_trace is None:
@@ -69,8 +69,8 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
6969
logger.info(f"Starting initial sub-trace {trace.sub_trace_count} at {current_time}")
7070
return (-1,) # Continue with latest trial for new sub-trace
7171

72-
# Calculate elapsed time for current sub-trace
73-
elapsed_time = current_time - self.sub_trace_start_times[trace.sub_trace_count - 1]
72+
# Calculate elapsed time for current sub-trace, Trace count may be larger than MAX_TRACE_NUM druing merge process
73+
elapsed_time = current_time - self.sub_trace_start_times[min(trace.sub_trace_count, self.MAX_TRACE_NUM) - 1]
7474

7575
if elapsed_time < self.time_limit_pre_trace:
7676
# Continue with current sub-trace
@@ -94,7 +94,7 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
9494
f"Elapsed time {elapsed_time} exceeds time limit {self.time_limit_pre_trace}, jump to a new sub-trace"
9595
)
9696
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
97-
return tuple() # Empty tuple signals starting a new sub-trace
97+
return trace.NEW_ROOT # Empty tuple signals starting a new sub-trace
9898

9999

100100
class SOTAJumpCKPSelector(CheckpointSelector):
@@ -140,7 +140,7 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
140140
f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump to a new sub-trace"
141141
)
142142
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
143-
return ()
143+
return trace.NEW_ROOT
144144
else:
145145
logger.info(
146146
f"SOTA count {sota_count} is above threshold {self.SOTA_COUNT_THRESHOLD}, continue the current latest trial"
@@ -201,7 +201,7 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
201201
logger.info(
202202
f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump a new sub-trace"
203203
)
204-
return () # reboot a new sub-trace
204+
return trace.NEW_ROOT # reboot a new sub-trace
205205
else:
206206
logger.info(
207207
f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump back to the last second SOTA in hist (may not in current sub-trace)"
@@ -227,7 +227,7 @@ def get_selection(self, trace: Trace) -> tuple[int, ...]:
227227
f"SOTA count {sota_count} is below threshold {self.SOTA_COUNT_THRESHOLD}, jump a new sub-trace"
228228
)
229229
logger.info(f"current sub-trace count: {trace.sub_trace_count}")
230-
return () # reboot a new sub-trace
230+
return trace.NEW_ROOT # reboot a new sub-trace
231231

232232
else:
233233
logger.info(

0 commit comments

Comments
 (0)