Skip to content

Commit f9c99b1

Browse files
chore: remove torch in main code (#439)
* use np.ndarray in costeer model evaluators * remove package * fix test error --------- Co-authored-by: Xu Yang <[email protected]>
1 parent ddf6d42 commit f9c99b1

File tree

7 files changed

+14
-71
lines changed

7 files changed

+14
-71
lines changed

constraints/3.10.txt

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ isort==5.13.2
7171
jaraco.classes==3.3.0
7272
jedi==0.19.1
7373
jeepney==0.8.0
74-
Jinja2==3.1.2
7574
joblib==1.4.2
7675
json5==0.9.25
7776
jsonpatch==1.33
@@ -102,15 +101,13 @@ loguru==0.7.2
102101
loguru-mypy==0.0.4
103102
lxml==5.0.0
104103
markdown-it-py==3.0.0
105-
MarkupSafe==2.1.3
106104
marshmallow==3.20.1
107105
matplotlib==3.9.1
108106
matplotlib-inline==0.1.7
109107
mdit-py-plugins==0.4.0
110108
mdurl==0.1.2
111109
mistune==3.0.2
112110
more-itertools==10.1.0
113-
mpmath==1.3.0
114111
msal==1.30.0
115112
msal-extensions==1.2.0
116113
msgpack==1.0.8
@@ -124,24 +121,11 @@ nbconvert==7.16.4
124121
nbformat==5.10.4
125122
ndindex==1.8
126123
nest-asyncio==1.6.0
127-
networkx==3.2.1
128124
nh3==0.2.15
129125
notebook==7.2.1
130126
notebook_shim==0.2.4
131127
numexpr==2.10.1
132128
numpy==1.26.2
133-
nvidia-cublas-cu12==12.1.3.1
134-
nvidia-cuda-cupti-cu12==12.1.105
135-
nvidia-cuda-nvrtc-cu12==12.1.105
136-
nvidia-cuda-runtime-cu12==12.1.105
137-
nvidia-cudnn-cu12==8.9.2.26
138-
nvidia-cufft-cu12==11.0.2.54
139-
nvidia-curand-cu12==10.3.2.106
140-
nvidia-cusolver-cu12==11.4.5.107
141-
nvidia-cusparse-cu12==12.1.0.106
142-
nvidia-nccl-cu12==2.18.1
143-
nvidia-nvjitlink-cu12==12.3.101
144-
nvidia-nvtx-cu12==12.1.105
145129
oauthlib==3.2.2
146130
openai==1.6.1
147131
overrides==7.7.0
@@ -153,7 +137,6 @@ parso==0.8.4
153137
pathspec==0.12.1
154138
patsy==0.5.6
155139
pexpect==4.9.0
156-
pillow==10.4.0
157140
pkginfo==1.9.6
158141
platformdirs==4.1.0
159142
pluggy==1.3.0
@@ -201,7 +184,6 @@ ruamel.yaml==0.18.5
201184
ruamel.yaml.clib==0.2.8
202185
ruff==0.4.5
203186
scikit-learn==1.5.1
204-
scipy==1.11.4
205187
SecretStorage==3.3.3
206188
semver==3.0.2
207189
Send2Trash==1.8.3
@@ -226,7 +208,6 @@ sphinxcontrib-serializinghtml==1.1.9
226208
SQLAlchemy==2.0.24
227209
stack-data==0.6.3
228210
statsmodels==0.14.2
229-
sympy==1.12
230211
tables==3.9.2
231212
tabulate==0.9.0
232213
tenacity==8.2.3
@@ -238,22 +219,18 @@ tinycss2==1.3.0
238219
toml-sort==0.23.1
239220
tomli==2.0.1
240221
tomlkit==0.12.3
241-
torch==2.1.2
242-
torch_geometric==2.5.3
243222
tornado==6.4
244223
tqdm==4.66.1
245224
traitlets==5.14.3
246225
tree-sitter==0.22.3
247226
tree-sitter-python==0.21.0
248-
triton==2.1.0
249227
twine==4.0.2
250228
typer==0.9.0
251229
types-psutil==6.0.0.20240621
252230
types-python-dateutil==2.9.0.20240316
253231
types-PyYAML==6.0.12.20240724
254232
types-tqdm==4.66.0.20240417
255233
typing-inspect==0.9.0
256-
typing_extensions==4.9.0
257234
tzdata==2023.4
258235
uri-template==1.3.0
259236
urllib3==2.1.0

constraints/3.11.txt

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ isort==5.13.2
6969
jaraco.classes==3.3.0
7070
jedi==0.19.1
7171
jeepney==0.8.0
72-
Jinja2==3.1.2
7372
joblib==1.4.2
7473
json5==0.9.25
7574
jsonpatch==1.33
@@ -100,15 +99,13 @@ loguru==0.7.2
10099
loguru-mypy==0.0.4
101100
lxml==5.0.0
102101
markdown-it-py==3.0.0
103-
MarkupSafe==2.1.3
104102
marshmallow==3.20.1
105103
matplotlib==3.9.1
106104
matplotlib-inline==0.1.7
107105
mdit-py-plugins==0.4.0
108106
mdurl==0.1.2
109107
mistune==3.0.2
110108
more-itertools==10.1.0
111-
mpmath==1.3.0
112109
msal==1.30.0
113110
msal-extensions==1.2.0
114111
msgpack==1.0.8
@@ -122,24 +119,11 @@ nbconvert==7.16.4
122119
nbformat==5.10.4
123120
ndindex==1.8
124121
nest-asyncio==1.6.0
125-
networkx==3.2.1
126122
nh3==0.2.15
127123
notebook==7.2.1
128124
notebook_shim==0.2.4
129125
numexpr==2.10.1
130126
numpy==1.26.2
131-
nvidia-cublas-cu12==12.1.3.1
132-
nvidia-cuda-cupti-cu12==12.1.105
133-
nvidia-cuda-nvrtc-cu12==12.1.105
134-
nvidia-cuda-runtime-cu12==12.1.105
135-
nvidia-cudnn-cu12==8.9.2.26
136-
nvidia-cufft-cu12==11.0.2.54
137-
nvidia-curand-cu12==10.3.2.106
138-
nvidia-cusolver-cu12==11.4.5.107
139-
nvidia-cusparse-cu12==12.1.0.106
140-
nvidia-nccl-cu12==2.18.1
141-
nvidia-nvjitlink-cu12==12.3.101
142-
nvidia-nvtx-cu12==12.1.105
143127
oauthlib==3.2.2
144128
openai==1.6.1
145129
overrides==7.7.0
@@ -151,7 +135,6 @@ parso==0.8.4
151135
pathspec==0.12.1
152136
patsy==0.5.6
153137
pexpect==4.9.0
154-
pillow==10.4.0
155138
pkginfo==1.9.6
156139
platformdirs==4.1.0
157140
pluggy==1.3.0
@@ -199,7 +182,6 @@ ruamel.yaml==0.18.5
199182
ruamel.yaml.clib==0.2.8
200183
ruff==0.4.5
201184
scikit-learn==1.5.1
202-
scipy==1.11.4
203185
SecretStorage==3.3.3
204186
semver==3.0.2
205187
Send2Trash==1.8.3
@@ -224,7 +206,6 @@ sphinxcontrib-serializinghtml==1.1.9
224206
SQLAlchemy==2.0.24
225207
stack-data==0.6.3
226208
statsmodels==0.14.2
227-
sympy==1.12
228209
tables==3.9.2
229210
tabulate==0.9.0
230211
tenacity==8.2.3
@@ -235,22 +216,18 @@ tiktoken==0.7.0
235216
tinycss2==1.3.0
236217
toml-sort==0.23.1
237218
tomlkit==0.12.3
238-
torch==2.1.2
239-
torch_geometric==2.5.3
240219
tornado==6.4
241220
tqdm==4.66.1
242221
traitlets==5.14.3
243222
tree-sitter==0.22.3
244223
tree-sitter-python==0.21.0
245-
triton==2.1.0
246224
twine==4.0.2
247225
typer==0.9.0
248226
types-psutil==6.0.0.20240621
249227
types-python-dateutil==2.9.0.20240316
250228
types-PyYAML==6.0.12.20240724
251229
types-tqdm==4.66.0.20240417
252230
typing-inspect==0.9.0
253-
typing_extensions==4.9.0
254231
tzdata==2023.4
255232
uri-template==1.3.0
256233
urllib3==2.1.0

rdagent/components/coder/model_coder/CoSTEER/evaluators.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import List, Tuple
55

66
import numpy as np
7-
import torch
87
from jinja2 import Environment, StrictUndefined
98

109
from rdagent.components.coder.model_coder.conf import MODEL_IMPL_SETTINGS
@@ -25,7 +24,7 @@
2524
evaluate_prompts = Prompts(file_path=Path(__file__).parent.parent / "prompts.yaml")
2625

2726

28-
def shape_evaluator(prediction: torch.Tensor | np.ndarray, target_shape: Tuple = None) -> Tuple[str, bool]:
27+
def shape_evaluator(prediction: np.ndarray, target_shape: Tuple = None) -> Tuple[str, bool]:
2928
if target_shape is None or prediction is None:
3029
return (
3130
"No output generated from the model. No shape evaluation conducted.",
@@ -42,18 +41,10 @@ def shape_evaluator(prediction: torch.Tensor | np.ndarray, target_shape: Tuple =
4241
)
4342

4443

45-
def reshape_tensor(original_tensor, target_shape):
46-
new_tensor = torch.zeros(target_shape)
47-
for i, dim in enumerate(original_tensor.shape):
48-
new_tensor = new_tensor.narrow(i, 0, dim).copy_(original_tensor)
49-
50-
return new_tensor
51-
52-
5344
def value_evaluator(
54-
prediction: torch.Tensor,
55-
target: torch.Tensor,
56-
) -> Tuple[torch.Tensor, bool]:
45+
prediction: np.ndarray,
46+
target: np.ndarray,
47+
) -> Tuple[np.ndarray, bool]:
5748
if prediction is None:
5849
return "No output generated from the model. Skip value evaluation", False
5950
elif target is None:
@@ -63,7 +54,7 @@ def value_evaluator(
6354
)
6455
else:
6556
# Calculate the mean absolute difference
66-
diff = torch.mean(torch.abs(target - prediction)).item()
57+
diff = np.mean(np.abs(target - prediction))
6758
return (
6859
f"The value of the output is correct. The mean absolute difference is {diff}.",
6960
diff < 0.1,
@@ -273,7 +264,7 @@ def evaluate(
273264
param_init_value = 0.6
274265

275266
assert isinstance(implementation, ModelFBWorkspace)
276-
model_execution_feedback, gen_tensor = implementation.execute(
267+
model_execution_feedback, gen_np_array = implementation.execute(
277268
batch_size=batch_size,
278269
num_features=num_features,
279270
num_timesteps=num_timesteps,
@@ -282,18 +273,18 @@ def evaluate(
282273
)
283274
if gt_implementation is not None:
284275
assert isinstance(gt_implementation, ModelFBWorkspace)
285-
_, gt_tensor = gt_implementation.execute(
276+
_, gt_np_array = gt_implementation.execute(
286277
batch_size=batch_size,
287278
num_features=num_features,
288279
num_timesteps=num_timesteps,
289280
input_value=input_value,
290281
param_init_value=param_init_value,
291282
)
292283
else:
293-
gt_tensor = None
284+
gt_np_array = None
294285

295-
shape_feedback, shape_decision = shape_evaluator(gen_tensor, (batch_size, 1))
296-
value_feedback, value_decision = value_evaluator(gen_tensor, gt_tensor)
286+
shape_feedback, shape_decision = shape_evaluator(gen_np_array, (batch_size, 1))
287+
value_feedback, value_decision = value_evaluator(gen_np_array, gt_np_array)
297288
code_feedback, _ = ModelCodeEvaluator(scen=self.scen).evaluate(
298289
target_task=target_task,
299290
implementation=implementation,

rdagent/components/coder/model_coder/model_execute_template_v1.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ if MODEL_TYPE == "Graph":
3737
else:
3838
out = m(data)
3939

40-
execution_model_output = out.cpu().detach()
40+
execution_model_output = out.cpu().detach().numpy()
4141
execution_feedback_str = f"Execution successful, output tensor shape: {execution_model_output.shape}"
4242

4343
pickle.dump(execution_model_output, open("execution_model_output.pkl", "wb"))

rdagent/components/coder/model_coder/model_execute_template_v2.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ valid_X = pd.DataFrame(np.random.randn(8, 30), columns=[f"{i}" for i in range(30
1212
valid_y = pd.Series(np.random.randint(0, 2, 8))
1313

1414
model = fit(train_X, train_y, valid_X, valid_y)
15-
execution_model_output = predict(model, valid_X)
15+
execution_model_output = predict(model, valid_X).cpu().detach().numpy()
1616

1717
execution_feedback_str = f"Execution successful, output numpy ndarray shape: {execution_model_output.shape}"
1818

requirements.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ pydantic-settings
33
typer[all]
44

55
cython
6-
scipy
76
python-Levenshtein
87
scikit-learn
98
filelock
@@ -14,8 +13,6 @@ fuzzywuzzy
1413
openai
1514

1615
ruamel-yaml
17-
torch
18-
torch_geometric
1916
tabulate # Convert pandas dataframe to markdown table to make it more readable to LLM
2017
numpy # we use numpy as default data format. So we have to install numpy
2118
pandas # we use pandas as default data format. So we have to install pandas
@@ -69,5 +66,4 @@ seaborn
6966
setuptools-scm
7067

7168
# This is a temporary package installed to pass the test_import test
72-
xgboost
7369
lightgbm

test/utils/test_import.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def import_all_modules_from_directory(directory):
2121
continue
2222
if "_template" in fstr:
2323
continue
24+
if "model_coder" in fstr:
25+
continue
2426
if (
2527
fstr.endswith("rdagent/log/ui/app.py")
2628
or fstr.endswith("rdagent/app/cli.py")

0 commit comments

Comments
 (0)