Skip to content

Commit c208209

Browse files
authored
feat: enable meta planner (#1103)
* enable meta planner * fix a small bug * ADD PLAN TO GEN * remove ensemble in planner * fix CI * fix CI * fix planner threshold
1 parent 13c92a9 commit c208209

File tree

13 files changed

+301
-68
lines changed

13 files changed

+301
-68
lines changed

rdagent/app/data_science/conf.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
1818
- For custom data science scenarios, use: "rdagent.scenarios.data_science.scen.DataScienceScen"
1919
"""
2020

21-
hypothesis_gen: str = "rdagent.scenarios.data_science.proposal.exp_gen.proposal.DSProposalV2ExpGen"
21+
planner: str = "rdagent.scenarios.data_science.proposal.exp_gen.planner.DSExpPlannerHandCraft"
22+
hypothesis_gen: str = "rdagent.scenarios.data_science.proposal.exp_gen.router.ParallelMultiTraceExpGen"
2223
"""Hypothesis generation class"""
2324

2425
summarizer: str = "rdagent.scenarios.data_science.dev.feedback.DSExperiment2Feedback"
@@ -99,21 +100,24 @@ class DataScienceBasePropSetting(KaggleBasePropSetting):
99100
# inject diverse when start a new sub-trace
100101
enable_inject_diverse: bool = False
101102

102-
# inject knowledge at the root of the trace
103-
enable_inject_knowledge_at_root: bool = False
104-
105103
# enable different version of DSExpGen for multi-trace
106104
enable_multi_version_exp_gen: bool = False
107105
exp_gen_version_list: str = "v3,v2"
108106

109107
#### multi-trace: time for final multi-trace merge
110-
merge_hours: int = 2
108+
merge_hours: int = 0
111109
"""The time for merge"""
112110

113111
#### multi-trace: max SOTA-retrieved number, used in AutoSOTAexpSelector
114112
# constrains the number of SOTA experiments to retrieve, otherwise too many SOTA experiments to retrieve will cause the exceed of the context window of LLM
115113
max_sota_retrieved_num: int = 10
116114
"""The maximum number of SOTA experiments to retrieve in a LLM call"""
117115

116+
#### enable draft before first sota experiment
117+
enable_draft_before_first_sota: bool = False
118+
enable_planner: bool = False
119+
120+
model_architecture_suggestion_time_percent: float = 0.75
121+
118122

119123
DS_RD_SETTING = DataScienceBasePropSetting()

rdagent/components/proposal/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from rdagent.core.experiment import Experiment
55
from rdagent.core.proposal import (
6+
ExperimentPlan,
67
Hypothesis,
78
Hypothesis2Experiment,
89
HypothesisGen,
@@ -25,7 +26,11 @@ def prepare_context(self, trace: Trace) -> Tuple[dict, bool]: ...
2526
@abstractmethod
2627
def convert_response(self, response: str) -> Hypothesis: ...
2728

28-
def gen(self, trace: Trace) -> Hypothesis:
29+
def gen(
30+
self,
31+
trace: Trace,
32+
plan: ExperimentPlan | None = None,
33+
) -> Hypothesis:
2934
context_dict, json_flag = self.prepare_context(trace)
3035

3136
system_prompt = T(".prompts:hypothesis_gen.system_prompt").r(

rdagent/core/experiment.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,12 @@ def __str__(self) -> str:
292292
ASpecificWSForSubTasks = TypeVar("ASpecificWSForSubTasks", bound=Workspace)
293293

294294

295+
class ExperimentPlan(dict[str, Any]):
296+
"""
297+
A plan for the experiment, which is a dictionary that contains the plan to each stage.
298+
"""
299+
300+
295301
class Experiment(
296302
ABC,
297303
Generic[ASpecificTask, ASpecificWSForExperiment, ASpecificWSForSubTasks],
@@ -337,6 +343,9 @@ def __init__(
337343

338344
# For parallel multi-trace support
339345
self.local_selection: tuple[int, ...] | None = None
346+
self.plan: ExperimentPlan | None = (
347+
None # To store the planning information for this experiment, should be generated inside exp_gen.gen
348+
)
340349

341350
@property
342351
def result(self) -> object:
@@ -348,6 +357,7 @@ def result(self, value: object) -> None:
348357

349358

350359
ASpecificExp = TypeVar("ASpecificExp", bound=Experiment)
360+
ASpecificPlan = TypeVar("ASpecificPlan", bound=ExperimentPlan)
351361

352362
TaskOrExperiment = TypeVar("TaskOrExperiment", Task, Experiment)
353363

rdagent/core/proposal.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88

99
from rdagent.core.conf import RD_AGENT_SETTINGS
1010
from rdagent.core.evaluation import Feedback
11-
from rdagent.core.experiment import ASpecificExp, Experiment
11+
from rdagent.core.experiment import (
12+
ASpecificExp,
13+
ASpecificPlan,
14+
Experiment,
15+
ExperimentPlan,
16+
)
1217
from rdagent.core.knowledge_base import KnowledgeBase
1318
from rdagent.core.scenario import Scenario
1419

@@ -268,15 +273,34 @@ def get_sota_exp_to_submit(self, trace: Trace) -> Experiment | None:
268273
"""
269274

270275

276+
class ExpPlanner(ABC, Generic[ASpecificPlan]):
277+
"""
278+
An abstract class for planning the experiment.
279+
The planner should generate a plan for the experiment based on the trace.
280+
"""
281+
282+
def __init__(self, scen: Scenario) -> None:
283+
self.scen = scen
284+
285+
@abstractmethod
286+
def plan(self, trace: Trace) -> ASpecificPlan:
287+
"""
288+
Generate a plan for the experiment based on the trace.
289+
The plan should be a dictionary that contains the plan to each stage.
290+
"""
291+
292+
271293
class ExpGen(ABC):
272294

273295
def __init__(self, scen: Scenario) -> None:
274296
self.scen = scen
275297

276298
@abstractmethod
277-
def gen(self, trace: Trace) -> Experiment:
299+
def gen(self, trace: Trace, plan: ExperimentPlan | None = None) -> Experiment:
278300
"""
279301
Generate the experiment based on the trace.
302+
Planning is part of gen, but since we may support multi-stage planning,
303+
we need to pass plan as optional argument.
280304
281305
`ExpGen().gen()` play a role like
282306
@@ -306,7 +330,11 @@ def __init__(self, scen: Scenario) -> None:
306330
self.scen = scen
307331

308332
@abstractmethod
309-
def gen(self, trace: Trace) -> Hypothesis:
333+
def gen(
334+
self,
335+
trace: Trace,
336+
plan: ExperimentPlan | None = None,
337+
) -> Hypothesis:
310338
# def gen(self, scenario_desc: str, ) -> Hypothesis:
311339
"""
312340
Motivation of the variable `scenario_desc`:

rdagent/scenarios/data_science/proposal/exp_gen/draft/draft.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from rdagent.oai.llm_utils import APIBackend
1616
from rdagent.scenarios.data_science.experiment.experiment import COMPONENT, DSExperiment
1717
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSHypothesis, DSTrace
18+
from rdagent.scenarios.data_science.proposal.exp_gen.planner import DSExperimentPlan
1819
from rdagent.scenarios.data_science.proposal.exp_gen.utils import (
1920
CodingSketch,
2021
get_component,
@@ -61,6 +62,7 @@ def gen(
6162
self,
6263
component: COMPONENT,
6364
trace: DSTrace,
65+
plan: DSExperimentPlan | None = None,
6466
) -> DSExperiment:
6567
"""Handle any component using a unified approach.
6668
@@ -234,7 +236,11 @@ def task_gen(
234236
exp.pending_tasks_list.append([workflow_task])
235237
return exp
236238

237-
def gen(self, trace: DSTrace) -> DSExperiment:
239+
def gen(
240+
self,
241+
trace: DSTrace,
242+
plan: DSExperimentPlan | None = None,
243+
) -> DSExperiment:
238244
# Step 0: Prepare
239245
pipeline = DS_RD_SETTING.coder_on_whole_pipeline
240246
if pipeline:

rdagent/scenarios/data_science/proposal/exp_gen/merge.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
1414
from rdagent.scenarios.data_science.loop import DataScienceRDLoop
1515
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSHypothesis, DSTrace
16+
from rdagent.scenarios.data_science.proposal.exp_gen.planner import DSExperimentPlan
1617
from rdagent.scenarios.data_science.proposal.exp_gen.proposal import DSProposalV2ExpGen
1718
from rdagent.utils.agent.tpl import T
1819
from rdagent.utils.workflow import wait_retry
1920

2021

2122
class MergeExpGen(ExpGen):
22-
def gen(self, trace: DSTrace) -> DSExperiment:
23+
def gen(
24+
self,
25+
trace: DSTrace,
26+
plan: DSExperimentPlan | None = None,
27+
) -> DSExperiment:
2328
# Ignore the selection argument and use all leaves instead.
2429
leaves: list[int] = trace.get_leaves()
2530
trace.set_current_selection((leaves[0],)) # override the current selection.
@@ -136,7 +141,11 @@ def get_exp_index(self, trace: DSTrace) -> int:
136141
return min(trace_scores, key=lambda item: item[1])[0]
137142
return next((i for i, leaf in enumerate(leaves) if leaf != trace.current_selection[0]))
138143

139-
def gen(self, trace: DSTrace) -> DSExperiment:
144+
def gen(
145+
self,
146+
trace: DSTrace,
147+
plan: DSExperimentPlan | None = None,
148+
) -> DSExperiment:
140149
# Ignore the selection argument and use all leaves instead.
141150
sota_exp_fb = trace.sota_experiment_fb(selection=trace.current_selection)
142151

@@ -231,7 +240,11 @@ def __init__(self, *args, **kwargs):
231240
self.merge_exp_gen = MergeExpGen(self.scen)
232241
self.exp_gen = DataScienceRDLoop.default_exp_gen(self.scen)
233242

234-
def gen(self, trace: DSTrace) -> DSExperiment:
243+
def gen(
244+
self,
245+
trace: DSTrace,
246+
plan: DSExperimentPlan | None = None,
247+
) -> DSExperiment:
235248
timer: RDAgentTimer = RD_Agent_TIMER_wrapper.timer
236249
logger.info(f"Remain time: {timer.remain_time()}")
237250

@@ -257,7 +270,11 @@ def gen(self, trace: DSTrace) -> DSExperiment:
257270

258271

259272
class MergeExpGen_MultiTrace(ExpGen):
260-
def gen(self, trace: DSTrace) -> DSExperiment:
273+
def gen(
274+
self,
275+
trace: DSTrace,
276+
plan: DSExperimentPlan | None = None,
277+
) -> DSExperiment:
261278
# Ignore the selection argument and use all leaves instead.
262279
leaves: list[int] = trace.get_leaves()
263280

@@ -347,18 +364,13 @@ def reset_exp_gen_version(self, version: str = "v2"):
347364
# )
348365
raise NotImplementedError("You should not switch version with proposal_version")
349366

350-
def gen(self, trace: DSTrace, selection: tuple[int, ...] = (-1,)) -> DSExperiment:
367+
def gen(
368+
self, trace: DSTrace, plan: DSExperimentPlan | None = None, selection: tuple[int, ...] = (-1,)
369+
) -> DSExperiment:
351370
timer: RDAgentTimer = RD_Agent_TIMER_wrapper.timer
352371
logger.info(f"Remain time: {timer.remain_time()}")
353372

354373
if timer.remain_time() >= timedelta(hours=DS_RD_SETTING.merge_hours):
355-
356-
if DS_RD_SETTING.enable_inject_knowledge_at_root:
357-
if DS_RD_SETTING.knowledge_base_path is not None and DS_RD_SETTING.idea_pool_json_path is not None:
358-
if len(trace.hist) == 0:
359-
# set the knowledge base option to True for the first trace
360-
DS_RD_SETTING.enable_knowledge_base = True
361-
362374
if DS_RD_SETTING.enable_multi_version_exp_gen:
363375
exp_gen_version_list = DS_RD_SETTING.exp_gen_version_list.split(",")
364376
for version in exp_gen_version_list:
@@ -402,21 +414,15 @@ def __init__(self, *args, **kwargs):
402414
self.merge_exp_gen = ExpGen2Hypothesis(self.scen)
403415
self.exp_gen = DataScienceRDLoop.default_exp_gen(self.scen)
404416

405-
def gen(self, trace: DSTrace) -> DSExperiment:
417+
def gen(
418+
self,
419+
trace: DSTrace,
420+
plan: DSExperimentPlan | None = None,
421+
) -> DSExperiment:
406422
timer: RDAgentTimer = RD_Agent_TIMER_wrapper.timer
407423
logger.info(f"Remain time: {timer.remain_time()}")
408424

409425
if timer.remain_time() >= timedelta(hours=DS_RD_SETTING.merge_hours):
410-
411-
if DS_RD_SETTING.enable_inject_knowledge_at_root:
412-
413-
if len(trace.hist) == 0:
414-
# set the knowledge base option to True for the first trace
415-
DS_RD_SETTING.enable_knowledge_base = True
416-
417-
else:
418-
# set the knowledge base option back to False for the other traces
419-
DS_RD_SETTING.enable_knowledge_base = False
420426
return self.exp_gen.gen(trace)
421427
else:
422428
# disable reset in merging stage

rdagent/scenarios/data_science/proposal/exp_gen/naive.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,17 @@
66
from rdagent.core.proposal import ExpGen
77
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
88
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSHypothesis, DSTrace
9+
from rdagent.scenarios.data_science.proposal.exp_gen.router import DSExperimentPlan
910
from rdagent.utils.agent.tpl import T
1011
from rdagent.utils.agent.workflow import build_cls_from_json_with_retry
1112

1213

1314
class NaiveExpGen(ExpGen):
14-
def gen(self, trace: DSTrace) -> DSExperiment:
15+
def gen(
16+
self,
17+
trace: DSTrace,
18+
plan: DSExperimentPlan | None = None,
19+
) -> DSExperiment:
1520
sota_exp = trace.sota_experiment()
1621
scenario_desc = trace.scen.get_scenario_all_desc()
1722
sota_exp_desc = T("scenarios.data_science.share:describe.exp").r(

rdagent/scenarios/data_science/proposal/exp_gen/parallel.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from rdagent.app.data_science.conf import DS_RD_SETTING
88
from rdagent.core.conf import RD_AGENT_SETTINGS
9-
from rdagent.core.proposal import ExpGen
9+
from rdagent.core.proposal import ExperimentPlan, ExpGen
1010
from rdagent.log import rdagent_logger as logger
1111
from rdagent.log.timer import RD_Agent_TIMER_wrapper, RDAgentTimer
1212
from rdagent.scenarios.data_science.loop import DataScienceRDLoop
@@ -38,7 +38,11 @@ def __init__(self, *args, **kwargs):
3838
self.merge_exp_gen = ExpGen2Hypothesis(self.scen)
3939
self.trace_scheduler: TraceScheduler = RoundRobinScheduler(DS_RD_SETTING.max_trace_num)
4040

41-
def gen(self, trace: "DSTrace") -> "Experiment":
41+
def gen(
42+
self,
43+
trace: "DSTrace",
44+
plan: "ExperimentPlan" | None = None,
45+
) -> "Experiment":
4246
raise NotImplementedError(
4347
"ParallelMultiTraceExpGen is designed for async usage, please call async_gen instead."
4448
)
@@ -57,16 +61,6 @@ async def async_gen(self, trace: DSTrace, loop: LoopBase) -> DSExperiment:
5761

5862
if timer.remain_time() >= timedelta(hours=DS_RD_SETTING.merge_hours):
5963

60-
if DS_RD_SETTING.enable_inject_knowledge_at_root:
61-
62-
if len(trace.hist) == 0:
63-
# set the knowledge base option to True for the first trace
64-
DS_RD_SETTING.enable_knowledge_base = True
65-
66-
else:
67-
# set the knowledge base option back to False for the other traces
68-
DS_RD_SETTING.enable_knowledge_base = False
69-
7064
if loop.get_unfinished_loop_cnt(loop.loop_idx) < RD_AGENT_SETTINGS.get_max_parallel():
7165
local_selection = await self.trace_scheduler.next(trace)
7266

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from datetime import timedelta
2+
3+
from rdagent.app.data_science.conf import DS_RD_SETTING
4+
from rdagent.components.coder.CoSTEER import RD_Agent_TIMER_wrapper
5+
from rdagent.core.proposal import ExperimentPlan, ExpPlanner
6+
from rdagent.scenarios.data_science.proposal.exp_gen.base import DSTrace
7+
8+
9+
class DSExperimentPlan(ExperimentPlan):
10+
"""
11+
A specific plan for data science experiments.
12+
This plan can include various stages such as proposal, draft, and merge.
13+
"""
14+
15+
def __init__(self):
16+
super().__init__()
17+
self.setdefault("exp_gen", {}).setdefault("draft", False)
18+
self.setdefault("exp_gen", {}).setdefault("suggest_model_architecture", False)
19+
self.setdefault("exp_gen", {}).setdefault("suggest_model_ensemble", False)
20+
21+
22+
class DSExpPlannerHandCraft(ExpPlanner[DSExperimentPlan]):
23+
"""
24+
A specific planner for data science experiments.
25+
"""
26+
27+
def plan(self, trace: DSTrace) -> DSExperimentPlan:
28+
"""
29+
Generate a plan for the experiment based on the trace.
30+
The plan should be a dictionary that contains the plan to each stage.
31+
trace is well selected into sub trace mode
32+
"""
33+
plan = DSExperimentPlan()
34+
timer = RD_Agent_TIMER_wrapper.timer
35+
remain_percent = timer.remain_time() / timer.all_duration if timer.started else 1.0
36+
37+
if not trace.sota_experiment():
38+
plan["exp_gen"]["draft"] = True
39+
elif trace.sota_experiment() and remain_percent > DS_RD_SETTING.model_architecture_suggestion_time_percent:
40+
plan["exp_gen"]["suggest_model_architecture"] = True
41+
# elif DS_RD_SETTING.merge_hours > 0:
42+
# merge_percent = timedelta(hours=DS_RD_SETTING.merge_hours) / timer.all_duration
43+
# if merge_percent < remain_percent < merge_percent + 0.1:
44+
# plan["exp_gen"]["suggest_model_ensemble"] = True
45+
return plan

0 commit comments

Comments
 (0)