Skip to content

Commit 78c203d

Browse files
you-n-g, XianBW, SunsetWolf
authored
fix: add async to direct_exp_gen avoid infinite loop (#992)
* refactor: convert direct_exp_gen to async and enforce parallel limit
* fix bug
* change coroutine function position
* fix fin_quant's direct_exp_gen
* format with isort

---------

Co-authored-by: Bowen Xian <[email protected]>
Co-authored-by: SunsetWolf <[email protected]>
1 parent 9e60c32 commit 78c203d

File tree

4 files changed: 39 additions and 30 deletions.

rdagent/app/qlib_rd_loop/factor_from_report.py

Lines changed: 17 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
extract_first_page_screenshot_from_pdf,
1212
load_and_process_pdfs_by_langchain,
1313
)
14+
from rdagent.core.conf import RD_AGENT_SETTINGS
1415
from rdagent.core.proposal import Hypothesis
1516
from rdagent.log import rdagent_logger as logger
1617
from rdagent.oai.llm_utils import APIBackend
@@ -105,21 +106,23 @@ def __init__(self, report_folder: str = None):
105106

106107
self.loop_n = min(len(self.judge_pdf_data_items), FACTOR_FROM_REPORT_PROP_SETTING.report_limit)
107108

108-
def direct_exp_gen(self, prev_out: dict[str, Any]):
109+
async def direct_exp_gen(self, prev_out: dict[str, Any]):
109110
while True:
110-
report_file_path = self.judge_pdf_data_items[self.loop_idx]
111-
logger.info(f"Processing number {self.loop_idx} report: {report_file_path}")
112-
exp = extract_hypothesis_and_exp_from_reports(str(report_file_path))
113-
if exp is None:
114-
continue
115-
exp.based_experiments = [QlibFactorExperiment(sub_tasks=[], hypothesis=exp.hypothesis)] + [
116-
t[0] for t in self.trace.hist if t[1]
117-
]
118-
exp.sub_workspace_list = exp.sub_workspace_list[: FACTOR_FROM_REPORT_PROP_SETTING.max_factors_per_exp]
119-
exp.sub_tasks = exp.sub_tasks[: FACTOR_FROM_REPORT_PROP_SETTING.max_factors_per_exp]
120-
logger.log_object(exp.hypothesis, tag="hypothesis generation")
121-
logger.log_object(exp.sub_tasks, tag="experiment generation")
122-
return exp
111+
if self.get_unfinished_loop_cnt(self.loop_idx) < RD_AGENT_SETTINGS.get_max_parallel():
112+
report_file_path = self.judge_pdf_data_items[self.loop_idx]
113+
logger.info(f"Processing number {self.loop_idx} report: {report_file_path}")
114+
exp = extract_hypothesis_and_exp_from_reports(str(report_file_path))
115+
if exp is None:
116+
continue
117+
exp.based_experiments = [QlibFactorExperiment(sub_tasks=[], hypothesis=exp.hypothesis)] + [
118+
t[0] for t in self.trace.hist if t[1]
119+
]
120+
exp.sub_workspace_list = exp.sub_workspace_list[: FACTOR_FROM_REPORT_PROP_SETTING.max_factors_per_exp]
121+
exp.sub_tasks = exp.sub_tasks[: FACTOR_FROM_REPORT_PROP_SETTING.max_factors_per_exp]
122+
logger.log_object(exp.hypothesis, tag="hypothesis generation")
123+
logger.log_object(exp.sub_tasks, tag="experiment generation")
124+
return exp
125+
await asyncio.sleep(1)
123126

124127
def coding(self, prev_out: dict[str, Any]):
125128
exp = self.coder.develop(prev_out["direct_exp_gen"])

rdagent/app/qlib_rd_loop/quant.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
from rdagent.app.qlib_rd_loop.conf import QUANT_PROP_SETTING
1111
from rdagent.components.workflow.conf import BasePropSetting
1212
from rdagent.components.workflow.rd_loop import RDLoop
13+
from rdagent.core.conf import RD_AGENT_SETTINGS
1314
from rdagent.core.developer import Developer
1415
from rdagent.core.exception import FactorEmptyError, ModelEmptyError
1516
from rdagent.core.proposal import (
@@ -64,15 +65,18 @@ def __init__(self, PROP_SETTING: BasePropSetting):
6465
self.trace = QuantTrace(scen=scen)
6566
super(RDLoop, self).__init__()
6667

67-
def direct_exp_gen(self, prev_out: dict[str, Any]):
68-
hypo = self._propose()
69-
assert hypo.action in ["factor", "model"]
70-
if hypo.action == "factor":
71-
exp = self.factor_hypothesis2experiment.convert(hypo, self.trace)
72-
else:
73-
exp = self.model_hypothesis2experiment.convert(hypo, self.trace)
74-
logger.log_object(exp.sub_tasks, tag="experiment generation")
75-
return {"propose": hypo, "exp_gen": exp}
68+
async def direct_exp_gen(self, prev_out: dict[str, Any]):
69+
while True:
70+
if self.get_unfinished_loop_cnt(self.loop_idx) < RD_AGENT_SETTINGS.get_max_parallel():
71+
hypo = self._propose()
72+
assert hypo.action in ["factor", "model"]
73+
if hypo.action == "factor":
74+
exp = self.factor_hypothesis2experiment.convert(hypo, self.trace)
75+
else:
76+
exp = self.model_hypothesis2experiment.convert(hypo, self.trace)
77+
logger.log_object(exp.sub_tasks, tag="experiment generation")
78+
return {"propose": hypo, "exp_gen": exp}
79+
await asyncio.sleep(1)
7680

7781
def coding(self, prev_out: dict[str, Any]):
7882
if prev_out["direct_exp_gen"]["propose"].action == "factor":

rdagent/components/workflow/rd_loop.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,11 @@
33
It is from `rdagent/app/qlib_rd_loop/model.py` and try to replace `rdagent/app/qlib_rd_loop/RDAgent.py`
44
"""
55

6+
import asyncio
67
from typing import Any
78

89
from rdagent.components.workflow.conf import BasePropSetting
10+
from rdagent.core.conf import RD_AGENT_SETTINGS
911
from rdagent.core.developer import Developer
1012
from rdagent.core.proposal import (
1113
Experiment2Feedback,
@@ -55,10 +57,13 @@ def _exp_gen(self, hypothesis: Hypothesis):
5557
return exp
5658

5759
# included steps
58-
def direct_exp_gen(self, prev_out: dict[str, Any]):
59-
hypo = self._propose()
60-
exp = self._exp_gen(hypo)
61-
return {"propose": hypo, "exp_gen": exp}
60+
async def direct_exp_gen(self, prev_out: dict[str, Any]):
61+
while True:
62+
if self.get_unfinished_loop_cnt(self.loop_idx) < RD_AGENT_SETTINGS.get_max_parallel():
63+
hypo = self._propose()
64+
exp = self._exp_gen(hypo)
65+
return {"propose": hypo, "exp_gen": exp}
66+
await asyncio.sleep(1)
6267

6368
def coding(self, prev_out: dict[str, Any]):
6469
exp = self.coder.develop(prev_out["direct_exp_gen"]["exp_gen"])

rdagent/scenarios/data_science/loop.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -156,9 +156,6 @@ async def direct_exp_gen(self, prev_out: dict[str, Any]):
156156
exp = await self.exp_gen.async_gen(self.trace, self)
157157

158158
logger.log_object(exp)
159-
160-
# FIXME: this is for LLM debug webapp, remove this when the debugging is done.
161-
logger.log_object(exp, tag="debug_exp_gen")
162159
return exp
163160

164161
def coding(self, prev_out: dict[str, Any]):

0 commit comments

Comments
 (0)