
Commit 1eacbad

add response format
1 parent 82d3f1f commit 1eacbad

File tree

5 files changed: +41 additions, -38 deletions


rdagent/scenarios/data_science/dev/runner/eval.py

Lines changed: 7 additions & 7 deletions
@@ -55,13 +55,13 @@ def evaluate(
         else:
             running_timeout_period = DS_RD_SETTING.full_timeout
         env = get_ds_env(
-        extra_volumes={
-            f"{DS_RD_SETTING.local_data_path}/{self.scen.competition}": T(
-                "scenarios.data_science.share:scen.input_path"
-            ).r()
-        },
-        running_timeout_period=running_timeout_period,
-    )
+            extra_volumes={
+                f"{DS_RD_SETTING.local_data_path}/{self.scen.competition}": T(
+                    "scenarios.data_science.share:scen.input_path"
+                ).r()
+            },
+            running_timeout_period=running_timeout_period,
+        )
 
         stdout = implementation.execute(
             env=env, entry=get_clear_ws_cmd()
Lines changed: 2 additions & 4 deletions
@@ -1,12 +1,10 @@
-import re
+import re
 from rdagent.core.proposal import ExpGen
 from rdagent.core.scenario import Scenario
 from rdagent.oai.llm_utils import APIBackend
 
 
-
-
 class DS_EnsembleExpGen(ExpGen):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.supports_response_schema = APIBackend().supports_response_schema()
+        self.supports_response_schema = APIBackend().supports_response_schema()
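
For context, a minimal sketch (not RD-Agent code) of the pattern the hunk above introduces: probe the backend's structured-output support once in the constructor and reuse the cached flag on later calls. LLMBackend and EnsembleExpGen are hypothetical stand-ins for APIBackend and DS_EnsembleExpGen.

# Illustrative only; stand-in names, not the rdagent API.
class LLMBackend:
    def supports_response_schema(self) -> bool:
        # A real backend would inspect the provider/model; hard-coded here.
        return True


class EnsembleExpGen:
    def __init__(self) -> None:
        # Probe once at construction so later calls can branch on a cached flag.
        self.supports_response_schema = LLMBackend().supports_response_schema()


if __name__ == "__main__":
    print(EnsembleExpGen().supports_response_schema)  # True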

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 28 additions & 24 deletions
@@ -717,33 +717,33 @@ def hypothesis_rank(
             problem_desc=problem_dict.get("problem", "Problem description not provided"),
             problem_label=problem_dict.get("label", "FEEDBACK_PROBLEM"),
         )
-
-    def hypothesis_select_with_llm(self,
-                                   scenario_desc: str,
-                                   exp_feedback_list_desc: str,
-                                   sota_exp_desc: str,
-                                   hypothesis_candidates:dict):
-
+
+    def hypothesis_select_with_llm(
+        self, scenario_desc: str, exp_feedback_list_desc: str, sota_exp_desc: str, hypothesis_candidates: dict
+    ):
+
         # time_use_current = 0
         # for exp, feedback in trace.hist:
         # if exp.running_info.running_time is not None:
         # time_use_current += exp.running_info.running_time
         # res_time = 12*3600 - time_use_current
         res_time = RD_Agent_TIMER_wrapper.timer.remain_time()
         total_time = RD_Agent_TIMER_wrapper.timer.all_duration
-        use_time = round(total_time.total_seconds(),2) - round(res_time.total_seconds(),2)
-        use_ratio = 100* use_time / round(total_time.total_seconds(),2)
+        use_time = round(total_time.total_seconds(), 2) - round(res_time.total_seconds(), 2)
+        use_ratio = 100 * use_time / round(total_time.total_seconds(), 2)
         use_ratio = round(use_ratio, 2)
 
         ensemble_timeout = DS_RD_SETTING.ensemble_timeout
-        hypothesis_candidates = str(json.dumps(hypothesis_candidates, indent=2))
+        hypothesis_candidates = str(json.dumps(hypothesis_candidates, indent=2))
 
         sys_prompt = T(".prompts_v2:hypothesis_select.system").r(
-            hypothesis_candidates = hypothesis_candidates,
-            res_time = round(res_time.total_seconds(),2),
-            ensemble_timeout = ensemble_timeout,
-            use_ratio = use_ratio,
-            hypothesis_output_format = T(".prompts_v2:output_format.hypothesis_select_format").r(hypothesis_candidates = hypothesis_candidates)
+            hypothesis_candidates=hypothesis_candidates,
+            res_time=round(res_time.total_seconds(), 2),
+            ensemble_timeout=ensemble_timeout,
+            use_ratio=use_ratio,
+            hypothesis_output_format=T(".prompts_v2:output_format.hypothesis_select_format").r(
+                hypothesis_candidates=hypothesis_candidates
+            ),
         )
 
         user_prompt = T(".prompts_v2:hypothesis_select.user").r(

@@ -755,12 +755,15 @@ def hypothesis_select_with_llm(self,
         response = APIBackend().build_messages_and_create_chat_completion(
             user_prompt=user_prompt,
             system_prompt=sys_prompt,
+            response_format=HypothesisList if self.supports_response_schema else {"type": "json_object"},
+            json_target_type=(
+                Dict[str, Dict[str, str | Dict[str, str | int]]] if not self.supports_response_schema else None
+            ),
         )
 
         response_dict = json.loads(response)
         return response_dict
 
-
     def task_gen(
         self,
         component_desc: str,

@@ -846,7 +849,7 @@ def get_scenario_all_desc(self, trace: DSTrace, eda_output=None) -> str:
             raw_description=trace.scen.raw_description,
             use_raw_description=DS_RD_SETTING.use_raw_description,
             time_limit=f"{DS_RD_SETTING.full_timeout / 60 / 60 : .2f} hours",
-            ensemble_limit = f"{DS_RD_SETTING.ensemble_timeout / 60 / 60 : .2f} hours",
+            ensemble_limit=f"{DS_RD_SETTING.ensemble_timeout / 60 / 60 : .2f} hours",
             eda_output=eda_output,
         )
 

@@ -868,7 +871,7 @@ def get_all_hypotheses(self, problem_dict: dict, hypothesis_dict: dict) -> list[
 
     def gen(
         self,
-        trace: DSTrace,
+        trace: DSTrace,
     ) -> DSExperiment:
         pipeline = DS_RD_SETTING.coder_on_whole_pipeline
         if not pipeline and (draft_exp := draft_exp_in_decomposition(self.scen, trace)):

@@ -973,11 +976,12 @@ gen
             # problem_dict= all_problems,
             # )
 
-        response_dict= self.hypothesis_select_with_llm(scenario_desc=scenario_desc,
-                                                       exp_feedback_list_desc=exp_feedback_list_desc,
-                                                       sota_exp_desc=sota_exp_desc,
-                                                       hypothesis_candidates =hypothesis_dict
-                                                       )
+        response_dict = self.hypothesis_select_with_llm(
+            scenario_desc=scenario_desc,
+            exp_feedback_list_desc=exp_feedback_list_desc,
+            sota_exp_desc=sota_exp_desc,
+            hypothesis_candidates=hypothesis_dict,
+        )
         component_map = {
             "Model": HypothesisComponent.Model,
             "Ensemble": HypothesisComponent.Ensemble,

@@ -992,7 +996,7 @@ gen
         if comp_str in component_map and hypo_str is not None:
             new_hypothesis = DSHypothesis(component=component_map[comp_str], hypothesis=hypo_str)
 
-        pickled_problem_name= None
+        pickled_problem_name = None
         # Step 3.5: Update knowledge base with the picked problem
         if DS_RD_SETTING.enable_knowledge_base:
             trace.knowledge_base.update_pickled_problem(all_problems, pickled_problem_name)
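
For context, a minimal sketch (not RD-Agent code) of the response-format fallback added in hypothesis_select_with_llm above: pass a typed schema when the backend supports structured output, otherwise request a generic JSON object and parse it manually. call_llm, backend_supports_schema, and HypothesisChoice are hypothetical stand-ins for build_messages_and_create_chat_completion, supports_response_schema(), and HypothesisList.

# Illustrative only; stand-in names, not the rdagent API.
import json
from typing import Any, Dict

from pydantic import BaseModel


class HypothesisChoice(BaseModel):
    # Hypothetical schema standing in for HypothesisList.
    component: str
    hypothesis: str


def backend_supports_schema() -> bool:
    # Pretend capability probe; a real backend reports this itself.
    return False


def call_llm(user_prompt: str, system_prompt: str, response_format: Any) -> str:
    # Fake LLM call that always returns a JSON string, for illustration only.
    return json.dumps({"component": "Ensemble", "hypothesis": "average the top-3 models"})


def select_hypothesis(user_prompt: str, system_prompt: str) -> Dict[str, Any]:
    supports_schema = backend_supports_schema()
    response = call_llm(
        user_prompt=user_prompt,
        system_prompt=system_prompt,
        # Prefer the typed schema when structured output is available,
        # otherwise fall back to a generic JSON-object response format.
        response_format=HypothesisChoice if supports_schema else {"type": "json_object"},
    )
    return json.loads(response)


if __name__ == "__main__":
    print(select_hypothesis("pick one hypothesis", "you are a selector"))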

rdagent/scenarios/data_science/proposal/exp_gen/router/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
 from rdagent.scenarios.data_science.proposal.exp_gen.draft.draft import DSDraftV2ExpGen
 from rdagent.scenarios.data_science.proposal.exp_gen.proposal import DSProposalV2ExpGen
 
+
 class DraftRouterExpGen(ExpGen):
     """
     A intelligent router for drafting and proposing.

@@ -29,4 +30,4 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def gen(self, trace: DSTrace) -> DSExperiment:
-        pass
+        pass

rdagent/scenarios/data_science/scen/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -140,7 +140,7 @@ def get_competition_full_desc(self) -> str:
             metric_direction=self.metric_direction,
             raw_description=self.raw_description,
             use_raw_description=DS_RD_SETTING.use_raw_description,
-            ensemble_limit = DS_RD_SETTING.ensemble_timeout,
+            ensemble_limit=DS_RD_SETTING.ensemble_timeout,
             time_limit=None,
             eda_output=None,
         )

@@ -158,7 +158,7 @@ def get_scenario_all_desc(self, eda_output=None) -> str:
             raw_description=self.raw_description,
             use_raw_description=DS_RD_SETTING.use_raw_description,
             time_limit=f"{DS_RD_SETTING.full_timeout / 60 / 60 : .2f} hours",
-            ensemble_limit = DS_RD_SETTING.ensemble_timeout,
+            ensemble_limit=DS_RD_SETTING.ensemble_timeout,
             eda_output=eda_output,
         )
 