Skip to content

Commit fcda159

Browse files
committed
remove ensemble from hypo_gen
1 parent e2f537a commit fcda159

File tree

3 files changed

+10
-31
lines changed

3 files changed

+10
-31
lines changed

rdagent/scenarios/data_science/proposal/exp_gen/ensemble/ensemble.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
23
from rdagent.core.proposal import ExpGen
34
from rdagent.core.scenario import Scenario
45
from rdagent.oai.llm_utils import APIBackend

rdagent/scenarios/data_science/proposal/exp_gen/prompts_v2.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ output_format:
573573
"problem name 1 (should be exactly same as the problem name provided)": {
574574
{% if enable_idea_pool %}"inspired": "True or False. Set to True if the hypothesis is inspired by the user provided ideas. Otherwise, set it to False.",{% endif %}
575575
"reason": "Provide a clear, logical progression from problem identification to hypothesis formulation, grounded in evidence (e.g., trace history, domain principles, or competition constraints). Refer to the Hypothesis Guidelines for better understanding. Reason should be short with no more than two sentences.",
576-
"component": "The component tag of the hypothesis. Must be one of ('DataLoadSpec', 'FeatureEng', 'Model', 'Ensemble', 'Workflow').",
576+
"component": "The component tag of the hypothesis. Must be one of ('DataLoadSpec', 'FeatureEng', 'Model', 'Workflow').",
577577
"hypothesis": "A concise, testable statement derived from previous experimental outcomes. Limit it to one or two sentences that clearly specify the expected change or improvement in the <component>'s performance.",
578578
"evaluation": {
579579
"alignment_score": "The alignment of the proposed hypothesis with the identified problem.",

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1+
import asyncio
12
import json
3+
import re
24
from enum import Enum
35
from typing import Any, Dict, List, Optional, Tuple
46

57
import pandas as pd
68
from pydantic import BaseModel, Field
7-
from rdagent.oai.backend.base import RD_Agent_TIMER_wrapper
8-
from rdagent.log.timer import RDAgentTimer
9-
from rdagent.core.conf import RD_AGENT_SETTINGS
10-
import asyncio
119

1210
from rdagent.app.data_science.conf import DS_RD_SETTING
1311
from rdagent.components.coder.data_science.ensemble.exp import EnsembleTask
@@ -16,9 +14,12 @@
1614
from rdagent.components.coder.data_science.pipeline.exp import PipelineTask
1715
from rdagent.components.coder.data_science.raw_data_loader.exp import DataLoaderTask
1816
from rdagent.components.coder.data_science.workflow.exp import WorkflowTask
17+
from rdagent.core.conf import RD_AGENT_SETTINGS
1918
from rdagent.core.proposal import ExpGen
2019
from rdagent.core.scenario import Scenario
2120
from rdagent.log import rdagent_logger as logger
21+
from rdagent.log.timer import RDAgentTimer
22+
from rdagent.oai.backend.base import RD_Agent_TIMER_wrapper
2223
from rdagent.oai.llm_utils import APIBackend, md5_hash
2324
from rdagent.scenarios.data_science.dev.feedback import ExperimentFeedback
2425
from rdagent.scenarios.data_science.experiment.experiment import DSExperiment
@@ -30,7 +31,6 @@
3031
from rdagent.utils.agent.tpl import T
3132
from rdagent.utils.repo.diff import generate_diff_from_dict
3233
from rdagent.utils.workflow import wait_retry
33-
import re
3434

3535
_COMPONENT_META: Dict[str, Dict[str, Any]] = {
3636
"DataLoadSpec": {
@@ -583,8 +583,6 @@ def hypothesis_gen(
583583
sys_prompt = T(".prompts_v2:hypothesis_gen.system").r(
584584
hypothesis_output_format=(
585585
T(".prompts_v2:output_format.hypothesis").r(pipeline=pipeline, enable_idea_pool=enable_idea_pool)
586-
if not self.supports_response_schema
587-
else None
588586
),
589587
pipeline=pipeline,
590588
enable_idea_pool=enable_idea_pool,
@@ -600,30 +598,10 @@ def hypothesis_gen(
600598
response = APIBackend().build_messages_and_create_chat_completion(
601599
user_prompt=user_prompt,
602600
system_prompt=sys_prompt,
603-
response_format=HypothesisList if self.supports_response_schema else {"type": "json_object"},
604-
json_target_type=(
605-
Dict[str, Dict[str, str | Dict[str, str | int]]] if not self.supports_response_schema else None
606-
),
601+
response_format={"type": "json_object"},
602+
json_target_type=Dict[str, Dict[str, str | Dict[str, str | int]]],
607603
)
608-
if self.supports_response_schema:
609-
hypotheses = HypothesisList(**json.loads(response))
610-
resp_dict = {
611-
h.caption: {
612-
"reason": h.challenge,
613-
"component": h.component,
614-
"hypothesis": h.hypothesis,
615-
"evaluation": {
616-
"alignment_score": h.evaluation.alignment.score,
617-
"impact_score": h.evaluation.impact.score,
618-
"novelty_score": h.evaluation.novelty.score,
619-
"feasibility_score": h.evaluation.feasibility.score,
620-
"risk_reward_balance_score": h.evaluation.risk_reward_balance.score,
621-
},
622-
}
623-
for h in hypotheses.hypotheses
624-
}
625-
else:
626-
resp_dict = json.loads(response)
604+
resp_dict = json.loads(response)
627605
logger.info(f"Generated hypotheses:\n" + json.dumps(resp_dict, indent=2))
628606

629607
# make sure the problem name is aligned

0 commit comments

Comments
 (0)