Skip to content

Commit 270ff7c

Browse files
authored
refactor: Improve data handling and representation in eval tests and prompts (#740)
1 parent 759f295 commit 270ff7c

File tree

5 files changed

+40
-22
lines changed

5 files changed

+40
-22
lines changed

rdagent/components/coder/data_science/feature/eval_tests/feature_test.txt

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,33 @@
22
Tests for `feat_eng` in feature.py
33
"""
44

5-
import pickle
6-
from copy import deepcopy
75

6+
from copy import deepcopy
7+
import sys
88
import numpy as np
99
import pandas as pd
1010
from feature import feat_eng
1111
from load_data import load_data
12+
import reprlib
13+
aRepr = reprlib.Repr()
14+
aRepr.maxother=300
1215

1316
X, y, X_test, test_ids = load_data()
14-
print(f"X.shape: {X.shape}")
15-
print(f"y.shape: {y.shape}" if not isinstance(y, list) else f"y(list)'s length: {len(y)}")
16-
print(f"X_test.shape: {X_test.shape}")
17+
print("X:", aRepr.repr(X))
18+
print("y:", aRepr.repr(y))
19+
print("X_test:", aRepr.repr(X_test))
20+
print("test_ids", aRepr.repr(test_ids))
21+
22+
print(f"X.shape: {X.shape}" if hasattr(X, 'shape') else f"X length: {len(X)}")
23+
print(f"y.shape: {y.shape}" if hasattr(y, 'shape') else f"y length: {len(y)}")
24+
print(f"X_test.shape: {X_test.shape}" if hasattr(X_test, 'shape') else f"X_test length: {len(X_test)}")
1725
print(f"test_ids length: {len(test_ids)}")
26+
1827
X_loaded = deepcopy(X)
1928
y_loaded = deepcopy(y)
2029
X_test_loaded = deepcopy(X_test)
2130

22-
import sys
23-
import reprlib
2431
def debug_info_print(func):
25-
aRepr = reprlib.Repr()
26-
aRepr.maxother=300
2732
def wrapper(*args, **kwargs):
2833
def local_trace(frame, event, arg):
2934
if event == "return" and frame.f_code == func.__code__:
@@ -44,7 +49,7 @@ X, y, X_test = debug_info_print(feat_eng)(X, y, X_test)
4449

4550

4651
def get_length(data):
47-
return len(data) if isinstance(data, list) else data.shape[0]
52+
return data.shape[0] if hasattr(data, 'shape') else len(data)
4853

4954

5055
def get_width(data):

rdagent/components/coder/data_science/model/eval_tests/model_test.txt

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Tests for `model_workflow` in model01.py
33
"""
4+
import sys
45
import time
56

67
from feature import feat_eng
@@ -19,20 +20,33 @@ def log_execution_results(start_time, val_pred, test_pred, hypers, execution_lab
1920
print(feedback_str)
2021

2122

23+
import reprlib
24+
aRepr = reprlib.Repr()
25+
aRepr.maxother=300
26+
2227
# Load and preprocess data
2328
X, y, test_X, test_ids = load_data()
2429
X, y, test_X = feat_eng(X, y, test_X)
30+
31+
print(f"X.shape: {X.shape}" if hasattr(X, 'shape') else f"X length: {len(X)}")
32+
print(f"y.shape: {y.shape}" if hasattr(y, 'shape') else f"y length: {len(y)}")
33+
print(f"test_X.shape: {test_X.shape}" if hasattr(test_X, 'shape') else f"test_X length: {len(test_X)}")
34+
print(f"test_ids length: {len(test_ids)}")
35+
2536
train_X, val_X, train_y, val_y = train_test_split(X, y, test_size=0.8, random_state=42)
26-
print(f"train_X.shape: {train_X.shape}")
27-
print(f"train_y.shape: {train_y.shape}" if not isinstance(train_y, list) else f"train_y(list)'s length: {len(train_y)}")
28-
print(f"val_X.shape: {val_X.shape}")
29-
print(f"val_y.shape: {val_y.shape}" if not isinstance(val_y, list) else f"val_y(list)'s length: {len(val_y)}")
3037

31-
import sys
32-
import reprlib
38+
print("train_X:", aRepr.repr(train_X))
39+
print("train_y:", aRepr.repr(train_y))
40+
print("val_X:", aRepr.repr(val_X))
41+
print("val_y:", aRepr.repr(val_y))
42+
43+
print(f"train_X.shape: {train_X.shape}" if hasattr(train_X, 'shape') else f"train_X length: {len(train_X)}")
44+
print(f"train_y.shape: {train_y.shape}" if hasattr(train_y, 'shape') else f"train_y length: {len(train_y)}")
45+
print(f"val_X.shape: {val_X.shape}" if hasattr(val_X, 'shape') else f"val_X length: {len(val_X)}")
46+
print(f"val_y.shape: {val_y.shape}" if hasattr(val_y, 'shape') else f"val_y length: {len(val_y)}")
47+
48+
3349
def debug_info_print(func):
34-
aRepr = reprlib.Repr()
35-
aRepr.maxother=300
3650
def wrapper(*args, **kwargs):
3751
def local_trace(frame, event, arg):
3852
if event == "return" and frame.f_code == func.__code__:

rdagent/components/coder/data_science/raw_data_loader/eval_tests/data_loader_test.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ X, y, X_test, test_ids = debug_info_print(load_data)()
3333

3434

3535
def get_length(data):
36-
return len(data) if isinstance(data, list) else data.shape[0]
36+
return data.shape[0] if hasattr(data, 'shape') else len(data)
3737

3838

3939
def get_width(data):
40-
return 1 if isinstance(data, list) else data.shape[1:]
40+
return data.shape[1:] if hasattr(data, 'shape') else 1
4141

4242

4343
def get_column_list(data):

rdagent/components/coder/data_science/raw_data_loader/prompts.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,9 +410,7 @@ data_loader_eval:
410410
The data loader is part of the whole workflow. The user has executed the entire pipeline and provided additional stdout.
411411
412412
**Workflow Code:**
413-
```python
414413
{{ workflow_code }}
415-
```
416414
417415
You should evaluate both the data loader test results and the overall workflow execution. **Approve the code only if both tests pass.**
418416
{% endif %}

rdagent/scenarios/data_science/share.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ component_spec:
101101
- Optimize memory usage for large datasets using techniques like downcasting or reading data in chunks if necessary.
102102
- Domain-Specific Handling:
103103
- Apply competition-specific preprocessing steps as needed (e.g., text tokenization, image resizing).
104+
- Instead of returning binary bytes directly, convert/decode them into more useful formats like numpy.ndarrays.
104105
105106
3. Code Standards:
106107
- DO NOT use progress bars (e.g., `tqdm`).

0 commit comments

Comments
 (0)