Skip to content

Commit 953a858

Browse files
committed
remove other commits and focus on add more package info for persistent problems
1 parent d84c7a7 commit 953a858

File tree

3 files changed

+95
-19
lines changed

3 files changed

+95
-19
lines changed

rdagent/scenarios/data_science/proposal/exp_gen/package_info.py

Lines changed: 73 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,72 @@ def get_installed_packages():
66
return {dist.metadata["Name"].lower(): dist.version for dist in distributions()}
77

88

9+
# Kaggle competition packages - based on usage frequency
10+
PYTHON_BASE_PACKAGES = ["catboost", "lightgbm", "numpy", "optuna", "pandas", "scikit-learn", "scipy", "shap", "xgboost"]
11+
12+
PYTHON_ADVANCED_PACKAGES = [
13+
"accelerate",
14+
"albumentations",
15+
"category_encoders",
16+
"cudf-cu12",
17+
"cuml-cu12",
18+
"datasets",
19+
"featuretools",
20+
"imbalanced-learn",
21+
"opencv-python",
22+
"pillow",
23+
"polars",
24+
"sentence-transformers",
25+
"spacy",
26+
"tensorflow",
27+
"timm",
28+
"tokenizers",
29+
"torch",
30+
"torchvision",
31+
"transformers",
32+
]
33+
34+
PYTHON_AUTO_ML_PACKAGES = ["autogluon"]
35+
36+
37+
def get_available_packages_prompt():
38+
"""Generate prompt template for dynamically detected available packages"""
39+
installed_packages = get_installed_packages()
40+
41+
# Check which packages are actually installed
42+
base_available = [pkg for pkg in PYTHON_BASE_PACKAGES if pkg.lower() in installed_packages]
43+
advanced_available = [pkg for pkg in PYTHON_ADVANCED_PACKAGES if pkg.lower() in installed_packages]
44+
automl_available = [pkg for pkg in PYTHON_AUTO_ML_PACKAGES if pkg.lower() in installed_packages]
45+
46+
# Build prompt
47+
prompt_parts = ["Available packages in environment:\n"]
48+
49+
if base_available:
50+
prompt_parts.append("【Basic Libraries】(core tools for most competitions):")
51+
prompt_parts.append(f"- {', '.join(base_available)}")
52+
prompt_parts.append("")
53+
54+
if advanced_available:
55+
prompt_parts.append("【Advanced Tools】(specialized for specific domains):")
56+
prompt_parts.append(f"- {', '.join(advanced_available)}")
57+
prompt_parts.append("")
58+
59+
if automl_available:
60+
prompt_parts.append("【AutoML Tools】(automated machine learning):")
61+
prompt_parts.append(f"- {', '.join(automl_available)}")
62+
prompt_parts.append("")
63+
64+
prompt_parts.append("Choose appropriate tool combinations based on the competition type.")
65+
66+
return "\n".join(prompt_parts).strip()
67+
68+
69+
def get_all_available_packages():
70+
"""Get flattened list of all packages"""
71+
all_packages = PYTHON_BASE_PACKAGES + PYTHON_ADVANCED_PACKAGES + PYTHON_AUTO_ML_PACKAGES
72+
return sorted(set(all_packages))
73+
74+
975
def print_filtered_packages(installed_packages, filtered_packages):
1076
to_print = []
1177
for package_name in filtered_packages:
@@ -26,24 +92,8 @@ def get_python_packages():
2692
# Example: `python package_info.py pandas torch scikit-learn`
2793
# If no extra arguments are provided we fall back to the original default list
2894
# to keep full backward-compatibility.
29-
packages_list = [ # default packages
30-
"transformers",
31-
"accelerate",
32-
"torch",
33-
"tensorflow",
34-
"pandas",
35-
"numpy",
36-
"scikit-learn",
37-
"scipy",
38-
"xgboost",
39-
"sklearn",
40-
"lightgbm",
41-
"vtk",
42-
"opencv-python",
43-
"keras",
44-
"matplotlib",
45-
"pydicom",
46-
]
95+
# Use our Kaggle-optimized package list as default
96+
packages_list = get_all_available_packages()
4797
if len(sys.argv) > 1:
4898
packages_list = list(set(packages_list) | set(sys.argv[1:]))
4999

@@ -61,4 +111,8 @@ def get_python_packages():
61111

62112

63113
if __name__ == "__main__":
64-
get_python_packages()
114+
# Check for special argument to get prompt instead of package list
115+
if len(sys.argv) > 1 and sys.argv[1] == "--prompt":
116+
print(get_available_packages_prompt())
117+
else:
118+
get_python_packages()

rdagent/scenarios/data_science/proposal/exp_gen/proposal.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
DSDraftExpGen, # TODO: DSDraftExpGen should be moved to router in the further
2424
)
2525
from rdagent.scenarios.data_science.proposal.exp_gen.idea_pool import DSIdea
26+
from rdagent.scenarios.data_science.proposal.exp_gen.package_info import (
27+
get_available_packages_prompt,
28+
)
2629
from rdagent.scenarios.data_science.proposal.exp_gen.planner import (
2730
DSExperimentPlan,
2831
RD_Agent_TIMER_wrapper,
@@ -601,6 +604,12 @@ def hypothesis_gen(
601604
for i, (problem_name, problem_dict) in enumerate(problems.items()):
602605
problem_formatted_str += f"## {i+1}. {problem_name}\n"
603606
problem_formatted_str += f"{problem_dict['problem']}\n"
607+
608+
# Add package information for persistent problems
609+
if problem_dict.get("label") == "PERSISTENT_PROBLEM":
610+
packages_prompt = get_available_packages_prompt()
611+
problem_formatted_str += f"\n{packages_prompt}\n"
612+
604613
if "idea" in problem_dict:
605614
idea_formatted_str = DSIdea(problem_dict["idea"]).to_formatted_str()
606615
problem_formatted_str += f"Sampled Idea by user: \n{idea_formatted_str}\n"

rdagent/scenarios/data_science/proposal/exp_gen/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,16 @@ def get_packages(pkgs: list[str] | None = None) -> str:
103103
pkg_args = " ".join(pkgs) if pkgs else ""
104104
stdout = implementation.execute(env=env, entry=f"python {fname} {pkg_args}")
105105
return stdout
106+
107+
108+
def get_packages_prompt() -> str:
109+
"""Return available packages prompt information."""
110+
# Reuse package prompt cached during Draft stage when available.
111+
112+
env = get_ds_env()
113+
implementation = FBWorkspace()
114+
fname = "package_info.py"
115+
implementation.inject_files(**{fname: (Path(__file__).absolute().resolve().parent / "package_info.py").read_text()})
116+
117+
stdout = implementation.execute(env=env, entry=f"python {fname} --prompt")
118+
return stdout

0 commit comments

Comments
 (0)