Skip to content

Commit 93e9b73

Browse files
committed
Save predictions to CSV for the predict operation mode
1 parent 9fc5d20 commit 93e9b73

File tree

5 files changed

+36
-14
lines changed

5 files changed

+36
-14
lines changed

chebai/ensemble/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from ._base import EnsembleBase
21
from ._consolidator import WeightedMajorityVoting
32
from ._controller import NoActivationCondition
43
from ._wrappers import NNWrapper

chebai/ensemble/_base.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from abc import ABC, abstractmethod
22
from collections import deque
33
from pathlib import Path
4-
from typing import Any, Deque, Dict, Optional
4+
from typing import Any, Deque, Dict
55

66
import pandas as pd
77
import torch
8-
from lightning import LightningModule
98

109
from chebai.result.classification import print_metrics
1110

@@ -29,7 +28,7 @@ def __init__(
2928
self,
3029
model_configs: Dict[str, Dict[str, Any]],
3130
data_processed_dir_main: str,
32-
operation: str = EVAL_OP,
31+
operation_mode: str = EVAL_OP,
3332
**kwargs: Any,
3433
) -> None:
3534
"""
@@ -42,13 +41,13 @@ def __init__(
4241
"""
4342
if bool(kwargs.get("_perform_validation_checks", True)):
4443
self._perform_validation_checks(
45-
model_configs, operation=operation, **kwargs
44+
model_configs, operation=operation_mode, **kwargs
4645
)
4746

4847
self._model_configs: Dict[str, Dict[str, Any]] = model_configs
4948
self._data_processed_dir_main: str = data_processed_dir_main
50-
self._operation: str = operation
51-
print(f"Ensemble operation: {self._operation}")
49+
self._operation_mode: str = operation_mode
50+
print(f"Ensemble operation: {self._operation_mode}")
5251

5352
# These instance variables will be set in the method `_process_input_to_ensemble`
5453
self._total_data_size: int | None = None
@@ -126,7 +125,7 @@ def _perform_validation_checks(
126125
labels_set.add(model_labels_path)
127126

128127
def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path:
129-
if self._operation == PRED_OP:
128+
if self._operation_mode == PRED_OP:
130129
p = Path(kwargs["smiles_list_file_path"])
131130
smiles_list: list[str] = []
132131
with open(p, "r") as f:
@@ -138,7 +137,7 @@ def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path:
138137
smiles_list.append(smiles)
139138
self._total_data_size = len(smiles_list)
140139
return smiles_list
141-
elif self._operation == EVAL_OP:
140+
elif self._operation_mode == EVAL_OP:
142141
processed_dir_path = Path(self._data_processed_dir_main)
143142
data_pkl_path = processed_dir_path / "data.pkl"
144143
if not data_pkl_path.exists():
@@ -183,7 +182,7 @@ def run_ensemble(self) -> None:
183182
)
184183

185184
print(
186-
f"Running {self.__class__.__name__} ensemble for {self._operation} operation..."
185+
f"Running {self.__class__.__name__} ensemble for {self._operation_mode} operation..."
187186
)
188187
while self._model_queue:
189188
model_name = self._model_queue.popleft()
@@ -204,7 +203,7 @@ def run_ensemble(self) -> None:
204203
true_scores=true_scores, false_scores=false_scores
205204
)
206205

207-
if self._operation == EVAL_OP:
206+
if self._operation_mode == EVAL_OP:
208207
assert (
209208
self._collated_labels is not None
210209
), "Collated labels must be set for evaluation operation."
@@ -214,6 +213,31 @@ def run_ensemble(self) -> None:
214213
self._device,
215214
classes=list(self._dm_labels.keys()),
216215
)
216+
else:
217+
# Get SMILES and label names
218+
smiles_list = self._ensemble_input
219+
label_names = list(self._dm_labels.keys())
220+
# Efficient conversion from tensor to NumPy
221+
preds_np = final_preds.detach().cpu().numpy()
222+
223+
assert (
224+
len(smiles_list) == preds_np.shape[0]
225+
), "Length of SMILES list does not match number of predictions."
226+
assert (
227+
len(label_names) == preds_np.shape[1]
228+
), "Number of label names does not match number of predictions."
229+
230+
# Build DataFrame
231+
df = pd.DataFrame(preds_np, columns=label_names)
232+
df.insert(0, "SMILES", smiles_list)
233+
234+
# Save to CSV
235+
output_path = (
236+
Path(self._data_processed_dir_main) / "ensemble_predictions.csv"
237+
)
238+
df.to_csv(output_path, index=False)
239+
240+
print(f"Predictions saved to {output_path}")
217241

218242
@abstractmethod
219243
def _controller(

chebai/ensemble/_controller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _controller(
5252
Dict[str, Tensor]: Dictionary containing predictions and confidence scores.
5353
"""
5454
wrapped_model = self._wrap_model(model_name)
55-
if self._operation == PRED_OP:
55+
if self._operation_mode == PRED_OP:
5656
model_output, model_props = wrapped_model.predict(model_input)
5757
else:
5858
model_output, model_props = wrapped_model.evaluate(model_input)

chebai/ensemble/_wrappers/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _generate_model_label_props(self, dm_labels) -> dict[str, torch.Tensor]:
4343
except Exception as e:
4444
raise Exception(
4545
f"Label '{label}' has an unexpected error \n Error: {e}"
46-
)
46+
) from e
4747

4848
model_label_indices.append(dm_labels[label])
4949
tpv_label_values.append(props["TPV"])

chebai/ensemble/_wrappers/_neural_network.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def _predict_from_list_of_smiles(self, smiles_list: list[str]) -> list:
126126
else:
127127
index_map[i] = len(token_dicts)
128128
token_dicts.append(d)
129-
print(f"Predicting {len(token_dicts), token_dicts} out of {len(smiles_list)}")
130129
if token_dicts:
131130
model_output = self._forward_pass(token_dicts)
132131
if not isinstance(model_output, dict) and not "logits" in model_output:

0 commit comments

Comments
 (0)