Skip to content

Commit 8fe2ba1

Browse files
committed
Save the results of the model evaluation
1 parent fa1f0b9 commit 8fe2ba1

File tree

5 files changed

+82
-7
lines changed

5 files changed

+82
-7
lines changed

chebai/loss/boost_bce.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
import torch
2-
import sys
3-
sys.path.insert(1,'/home/programmer/Bachelorarbeit/python-chebai')
4-
52
import extras.weight_loader as f
63

74

chebai/models/base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
from chebai.preprocessing.structures import XYData
99

10-
import sys
11-
sys.path.insert(1,'/home/programmer/Bachelorarbeit/python-chebai')
1210

1311
import extras.weight_loader as f
1412

chebai/preprocessing/datasets/base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from torch.utils.data import DataLoader
1919

2020
from chebai.preprocessing import reader as dr
21-
import sys
22-
sys.path.insert(1,'/home/programmer/Bachelorarbeit/python-chebai')
2321

2422
import extras.weight_loader as f
2523

chebai/result/utils.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from chebai.preprocessing.datasets.base import XYBaseDataModule
1515
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
1616

17+
from extras.ev_model import create_weight_dict
18+
1719

1820
def get_checkpoint_from_wandb(
1921
epoch: int,
@@ -117,6 +119,7 @@ def evaluate_model(
117119
data_list = data_list[: data_module.data_limit]
118120
preds_list = []
119121
labels_list = []
122+
weights_list = []
120123
if buffer_dir is not None:
121124
os.makedirs(buffer_dir, exist_ok=True)
122125
save_ind = 0
@@ -132,6 +135,8 @@ def evaluate_model(
132135
preds, labels = _run_batch(data_list[i : i + batch_size], model, collate)
133136
preds_list.append(preds)
134137
labels_list.append(labels)
138+
for j in range(i,i+batch_size):
139+
weights_list.append(data_list[j])
135140

136141
if buffer_dir is not None:
137142
if n_saved * batch_size >= save_batch_size:
@@ -170,6 +175,68 @@ def evaluate_model(
170175
)
171176

172177

178+
def evaluate_model_weights(
    model: ChebaiBaseNet,
    data_module: XYBaseDataModule,
    filename: Optional[str] = None,
    buffer_dir: Optional[str] = None,
    batch_size: int = 32,
    skip_existing_preds: bool = False,
    kind: str = "test",
    result_path: str = "./result.pt",
) -> list:
    """
    Run the model on the evaluation split of the data module, pair every
    prediction/label with its sample identifier via ``create_weight_dict``,
    and save the resulting list to ``result_path``.

    Note:
        No need to provide the "filename" parameter for the ChEBI dataset;
        the "kind" parameter selects the dynamic split instead.

    Args:
        model: The model to evaluate.
        data_module: The data module containing the dataset.
        filename: Optional file name for the dataset (non-ChEBI modules).
        buffer_dir: Optional directory; created if given and consulted when
            ``skip_existing_preds`` is set.
        batch_size: The batch size for evaluation.
        skip_existing_preds: Whether to skip evaluation if predictions already exist.
        kind: Kind of split of the data to be used for testing the model.
            Default is `test`.
        result_path: File path the result list is written to with
            ``torch.save``. Defaults to "./result.pt" (the previous
            hard-coded location), so existing callers are unaffected.

    Returns:
        The list of per-sample dicts produced by ``create_weight_dict``
        (also written to ``result_path``).
    """
    model.eval()
    collate = data_module.reader.COLLATOR()

    if isinstance(data_module, _ChEBIDataExtractor):
        # The dynamic split change is implemented only for the ChEBI dataset
        # as of now.
        data_df = data_module.dynamic_split_dfs[kind]
        data_list = data_df.to_dict(orient="records")
    else:
        # NOTE(review): this branch always loads the "test" split and ignores
        # ``kind`` -- confirm this is intended for non-ChEBI modules.
        data_list = data_module.load_processed_data("test", filename)
    data_list = data_list[: data_module.data_limit]

    preds_list = []
    labels_list = []
    if buffer_dir is not None:
        os.makedirs(buffer_dir, exist_ok=True)

    for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
        # Guard buffer_dir against None so skip_existing_preds cannot crash
        # in os.path.join. The original code never advanced its save counter,
        # so the existing-predictions marker is always "preds000.pt".
        if (
            skip_existing_preds
            and buffer_dir is not None
            and os.path.isfile(os.path.join(buffer_dir, "preds000.pt"))
        ):
            continue
        preds, labels = _run_batch(data_list[i : i + batch_size], model, collate)
        preds_list.append(preds)
        labels_list.append(labels)

    result = create_weight_dict(preds_list, labels_list, data_list)
    torch.save(result, result_path)
    return result
237+
238+
239+
173240
def load_results_from_buffer(
174241
buffer_dir: str, device: torch.device
175242
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:

extras/ev_model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
def create_weight_dict(p, l, data_list):
    """
    Pair each prediction/label row with the identifier of its source sample.

    Args:
        p: Per-batch predictions; each element is an indexable batch
            (e.g. a tensor of shape ``(batch, ...)`` or a list of rows).
        l: Per-batch labels, parallel to ``p``.
        data_list: Flat list of sample records in the same order as the
            flattened batches; each record is a mapping with an "idents" key.

    Returns:
        A list of dicts with keys "pred", "label" and "ident", one per sample.
    """
    result = []
    # Walk the flat sample records in lockstep with the flattened batches
    # instead of maintaining a manual index counter.
    records = iter(data_list)
    for batch_preds, batch_labels in zip(p, l):
        for pred, label in zip(batch_preds, batch_labels):
            result.append(
                {
                    "pred": pred,
                    "label": label,
                    "ident": next(records)["idents"],
                }
            )
    return result

0 commit comments

Comments
 (0)