
Commit 0925a41

Merge pull request #127 from ChEB-AI/feature/new-ensemble-models
Feature/new ensemble models
2 parents 22c517c + 08f6071 commit 0925a41

File tree

17 files changed: +3873 −36 lines


chebai/callbacks/epoch_metrics.py

Lines changed: 2 additions & 1 deletion
@@ -62,7 +62,8 @@ def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
             labels (torch.Tensor): Ground truth labels.
         """
         tps = torch.sum(
-            torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0
+            torch.logical_and(preds > self.threshold, labels.to(torch.bool)),
+            dim=0,
         )
         self.true_positives += tps
         self.positive_predictions += torch.sum(preds > self.threshold, dim=0)
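
For context, the reformatted call above accumulates per-class true-positive and positive-prediction counts by thresholding the scores. A minimal standalone sketch of that computation (the threshold value and tensor shapes below are made up for illustration, not taken from the callback's config):

import torch

# Made-up shapes: 2 samples, 2 classes; threshold assumed to be 0.5.
threshold = 0.5
preds = torch.tensor([[0.9, 0.2], [0.7, 0.8]])   # predicted scores
labels = torch.tensor([[1, 0], [0, 1]])          # multi-label ground truth

tps = torch.sum(torch.logical_and(preds > threshold, labels.to(torch.bool)), dim=0)
positive_predictions = torch.sum(preds > threshold, dim=0)
print(tps)                   # tensor([1, 1]) -- true positives per class
print(positive_predictions)  # tensor([2, 1]) -- predicted positives per class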

chebai/models/base.py

Lines changed: 12 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 import torch
 from lightning.pytorch.core.module import LightningModule
+from lightning.pytorch.utilities.rank_zero import rank_zero_info
 
 from chebai.preprocessing.structures import XYData
 
@@ -106,7 +107,8 @@ def _get_prediction_and_labels(
         Returns:
             Tuple[torch.Tensor, torch.Tensor]: Predictions and labels.
         """
-        return output, labels
+        # cast labels to int
+        return output, labels.to(torch.int) if labels is not None else labels
 
     def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor:
         """
@@ -158,6 +160,13 @@ def _process_for_loss(
         """
         return model_output, labels, loss_kwargs
 
+    def on_train_epoch_start(self) -> None:
+        # pass current epoch to datamodule if it has the attribute curr_epoch (for PubChemBatched dataset)
+        rank_zero_info(f"Starting epoch {self.current_epoch}")
+        if hasattr(self.trainer.datamodule, "curr_epoch"):
+            rank_zero_info(f"Setting datamodule.curr_epoch to {self.current_epoch}")
+            self.trainer.datamodule.curr_epoch = self.current_epoch
+
     def training_step(
         self, batch: XYData, batch_idx: int
     ) -> Dict[str, Union[torch.Tensor, Any]]:
@@ -310,6 +319,8 @@ def _execute(
         for metric_name, metric in metrics.items():
            metric.update(pr, tar)
         self._log_metrics(prefix, metrics, len(batch))
+        if isinstance(d, dict) and "loss" not in d:
+            print(f"d has keys {d.keys()}, log={log}, criterion={self.criterion}")
         return d
 
     def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
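
The new on_train_epoch_start hook only writes to datamodules that expose a curr_epoch attribute. A minimal sketch of such a datamodule (the class name and body are illustrative assumptions, not the repository's PubChem datamodule):

from lightning.pytorch import LightningDataModule


class EpochAwareDataModule(LightningDataModule):
    # Illustrative only: mimics the curr_epoch contract used by the hook above.

    def __init__(self):
        super().__init__()
        self.curr_epoch = 0  # overwritten each epoch by ChebaiBaseNet.on_train_epoch_start

    def train_dataloader(self):
        # a real datamodule would select the data chunk for self.curr_epoch here
        raise NotImplementedError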

chebai/models/classic_ml.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+import os
+import pickle as pkl
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+import tqdm
+from sklearn.exceptions import NotFittedError
+from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
+
+from chebai.models.base import ChebaiBaseNet
+
+LR_MODEL_PATH = os.path.join("models", "LR")
+
+
+class LogisticRegression(ChebaiBaseNet):
+    """
+    Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface.
+    """
+
+    def __init__(
+        self,
+        out_dim: int,
+        input_dim: int,
+        only_predict_classes: Optional[List] = None,
+        n_classes=1528,
+        **kwargs,
+    ):
+        super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
+        self.models = [
+            SklearnLogisticRegression(solver="liblinear") for _ in range(n_classes)
+        ]
+        # indices of classes (in the dataset used for training) where a model should be trained
+        self.only_predict_classes = only_predict_classes
+
+    def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
+        print(
+            f"forward called with x[features].shape {x['features'].shape}, self.training {self.training}"
+        )
+        if self.training:
+            self.fit_sklearn(x["features"], x["labels"])
+        preds = []
+        for model in self.models:
+            try:
+                p = torch.from_numpy(model.predict(x["features"])).float()
+                p = p.to(x["features"].device)
+                preds.append(p)
+            except NotFittedError:
+                preds.append(
+                    torch.zeros((x["features"].shape[0]), device=(x["features"].device))
+                )
+            except AttributeError:
+                preds.append(
+                    torch.zeros((x["features"].shape[0]), device=(x["features"].device))
+                )
+        preds = torch.stack(preds, dim=1)
+        print(f"preds shape {preds.shape}")
+        return preds.squeeze(-1)
+
+    def fit_sklearn(self, X, y):
+        """
+        Fit the underlying sklearn model. X and y should be numpy arrays.
+        """
+        for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"):
+            import os
+
+            if os.path.exists(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl")):
+                print(f"Loading model {i} from file")
+                self.models[i] = pkl.load(
+                    open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "rb")
+                )
+            else:
+                if (
+                    self.only_predict_classes and i not in self.only_predict_classes
+                ):  # only try these classes
+                    continue
+                try:
+                    model.fit(X, y[:, i])
+                except ValueError:
+                    self.models[i] = PlaceholderModel()
+                # dump
+                pkl.dump(
+                    model, open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "wb")
+                )
+
+    def configure_optimizers(self, **kwargs):
+        pass
+
+
+class PlaceholderModel:
+    """Acts like a trained model, but isn't. Use this if training fails and you need a placeholder."""
+
+    def __init__(self, default_prediction=1):
+        self.default_prediction = default_prediction
+
+    def predict(self, preds):
+        return np.ones(preds.shape[0]) * self.default_prediction
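
The new wrapper trains one binary scikit-learn classifier per ChEBI class and stacks their predictions. A standalone sketch of that one-classifier-per-class scheme (data and shapes are made up; the real wrapper receives feature/label batches from the chebai datamodules and caches fitted models under models/LR):

import numpy as np
from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression

X = np.random.rand(32, 16)            # 32 samples, 16 features (made-up sizes)
y = np.random.randint(0, 2, (32, 3))  # 3 independent binary labels per sample

models = [SklearnLogisticRegression(solver="liblinear") for _ in range(y.shape[1])]
for i, model in enumerate(models):
    model.fit(X, y[:, i])             # one binary classifier per class

preds = np.stack([model.predict(X) for model in models], axis=1)
print(preds.shape)                    # (32, 3)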

chebai/models/electra.py

Lines changed: 5 additions & 0 deletions
@@ -224,6 +224,7 @@ def __init__(
         config: Optional[Dict[str, Any]] = None,
         pretrained_checkpoint: Optional[str] = None,
         load_prefix: Optional[str] = None,
+        freeze_electra: bool = False,
         **kwargs: Any,
     ):
         # Remove this property in order to prevent it from being stored as a
@@ -262,6 +263,10 @@ def __init__(
         else:
             self.electra = ElectraModel(config=self.config)
 
+        if freeze_electra:
+            for param in self.electra.parameters():
+                param.requires_grad = False
+
     def _process_for_loss(
         self,
         model_output: Dict[str, Tensor],
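
The freeze_electra flag simply disables gradients for the encoder so that only the downstream layers are updated. A standalone sketch of the same pattern on a generic torch module (illustrative stand-ins, not the ElectraModel itself):

import torch
from torch import nn

encoder = nn.Linear(128, 128)  # stands in for self.electra
head = nn.Linear(128, 10)      # stands in for the classification head

for param in encoder.parameters():
    param.requires_grad = False  # frozen: excluded from gradient updates

trainable = [p for p in list(encoder.parameters()) + list(head.parameters()) if p.requires_grad]
print(len(trainable))  # 2 -- only the head's weight and bias remain trainable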

chebai/models/lstm.py

Lines changed: 37 additions & 13 deletions
@@ -1,31 +1,55 @@
 import logging
 
 from torch import nn
-from torch.nn.utils.rnn import pack_padded_sequence
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
 
 from chebai.models.base import ChebaiBaseNet
 
 logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
 
 
 class ChemLSTM(ChebaiBaseNet):
-    def __init__(self, in_d, out_d, num_classes, **kwargs):
-        super().__init__(num_classes, **kwargs)
-        self.lstm = nn.LSTM(in_d, out_d, batch_first=True)
-        self.embedding = nn.Embedding(800, 100)
+    def __init__(
+        self,
+        out_d,
+        in_d,
+        num_classes,
+        criterion: nn.Module = None,
+        num_layers=6,
+        dropout=0.2,
+        **kwargs,
+    ):
+        super().__init__(
+            out_dim=out_d,
+            input_dim=in_d,
+            criterion=criterion,
+            num_classes=num_classes,
+            **kwargs,
+        )
+        self.lstm = nn.LSTM(
+            in_d,
+            out_d,
+            batch_first=True,
+            dropout=dropout,
+            bidirectional=True,
+            num_layers=num_layers,
+        )
+        self.embedding = nn.Embedding(1400, in_d)
         self.output = nn.Sequential(
-            nn.Linear(out_d, in_d),
+            nn.Linear(out_d * 2, out_d),
             nn.ReLU(),
             nn.Dropout(0.2),
-            nn.Linear(in_d, num_classes),
+            nn.Linear(out_d, num_classes),
         )
 
-    def forward(self, data):
-        x = data.x
-        x_lens = data.lens
+    def forward(self, data, *args, **kwargs):
+        x = data["features"]
+        x_lens = data["model_kwargs"]["lens"]
         x = self.embedding(x)
         x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False)
-        x = self.lstm(x)[1][0]
-        # = pad_packed_sequence(x, batch_first=True)[0]
+        x = self.lstm(x)[0]
+        x = pad_packed_sequence(x, batch_first=True)[0][
+            :, 0
+        ]  # reduce sequence dimension to first element
         x = self.output(x)
-        return x.squeeze(0)
+        return x
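
The rewritten forward now runs a bidirectional LSTM and keeps only the first position of the padded output, which is why the head starts with nn.Linear(out_d * 2, ...). A standalone shape check (dimensions below are made up; the real in_d/out_d come from the training config):

import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

in_d, out_d, batch, max_len = 32, 64, 4, 10      # made-up sizes
embedding = nn.Embedding(1400, in_d)
lstm = nn.LSTM(in_d, out_d, batch_first=True, bidirectional=True, num_layers=2)

tokens = torch.randint(1, 1400, (batch, max_len))
lens = torch.tensor([10, 8, 7, 5])               # sequence lengths per sample

x = pack_padded_sequence(embedding(tokens), lens, batch_first=True, enforce_sorted=False)
out = pad_packed_sequence(lstm(x)[0], batch_first=True)[0][:, 0]
print(out.shape)  # torch.Size([4, 128]) == (batch, 2 * out_d), matching nn.Linear(out_d * 2, ...)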
