Skip to content

Commit 1581693

Browse files
committed
revert previous changes
1 parent 0322741 commit 1581693

File tree

1 file changed

+0
-66
lines changed

1 file changed

+0
-66
lines changed

chebai/result/utils.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ def evaluate_model(
119119
data_list = data_list[: data_module.data_limit]
120120
preds_list = []
121121
labels_list = []
122-
weights_list = []
123122
if buffer_dir is not None:
124123
os.makedirs(buffer_dir, exist_ok=True)
125124
save_ind = 0
@@ -135,8 +134,6 @@ def evaluate_model(
135134
preds, labels = _run_batch(data_list[i : i + batch_size], model, collate)
136135
preds_list.append(preds)
137136
labels_list.append(labels)
138-
for j in range(i,i+batch_size):
139-
weights_list.append(data_list[j])
140137

141138
if buffer_dir is not None:
142139
if n_saved * batch_size >= save_batch_size:
@@ -174,69 +171,6 @@ def evaluate_model(
174171
os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
175172
)
176173

177-
178-
def evaluate_model_weights(
179-
model: ChebaiBaseNet,
180-
data_module: XYBaseDataModule,
181-
filename: Optional[str] = None,
182-
buffer_dir: Optional[str] = None,
183-
batch_size: int = 32,
184-
skip_existing_preds: bool = False,
185-
kind: str = "test",
186-
):
187-
"""
188-
Runs the model on the test set of the data module or on the dataset found in the specified file.
189-
If buffer_dir is set, results will be saved in buffer_dir.
190-
191-
Note:
192-
No need to provide "filename" parameter for Chebi dataset, "kind" parameter should be provided.
193-
194-
Args:
195-
model: The model to evaluate.
196-
data_module: The data module containing the dataset.
197-
filename: Optional file name for the dataset.
198-
buffer_dir: Optional directory to save the results.
199-
batch_size: The batch size for evaluation.
200-
skip_existing_preds: Whether to skip evaluation if predictions already exist.
201-
kind: Kind of split of the data to be used for testing the model. Default is `test`.
202-
203-
Returns:
204-
Tensors with predictions and labels.
205-
"""
206-
model.eval()
207-
collate = data_module.reader.COLLATOR()
208-
209-
if isinstance(data_module, _ChEBIDataExtractor):
210-
# As the dynamic split change is implemented only for chebi-dataset as of now
211-
data_df = data_module.dynamic_split_dfs[kind]
212-
data_list = data_df.to_dict(orient="records")
213-
else:
214-
data_list = data_module.load_processed_data("test", filename)
215-
data_list = data_list[: data_module.data_limit]
216-
preds_list = []
217-
labels_list = []
218-
weights_list = []
219-
if buffer_dir is not None:
220-
os.makedirs(buffer_dir, exist_ok=True)
221-
save_ind = 0
222-
save_batch_size = 128
223-
n_saved = 1
224-
225-
print("")
226-
for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
227-
if not (
228-
skip_existing_preds
229-
and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"))
230-
):
231-
preds, labels = _run_batch(data_list[i : i + batch_size], model, collate)
232-
preds_list.append(preds)
233-
labels_list.append(labels)
234-
235-
result = create_weight_dict(preds_list,labels_list,data_list)
236-
torch.save(result,"./result.pt")
237-
238-
239-
240174
def load_results_from_buffer(
241175
buffer_dir: str, device: torch.device
242176
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:

0 commit comments

Comments
 (0)