Skip to content

Commit 7f9e28d

Browse files
committed
add default model
1 parent bf97527 commit 7f9e28d

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

chebai/models/classic_ml.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict
2-
2+
import pickle as pkl
3+
import numpy as np
34
import torch
45
import tqdm
56
from sklearn.exceptions import NotFittedError
@@ -15,7 +16,7 @@ class LogisticRegression(ChebaiBaseNet):
1516

1617
def __init__(self, out_dim: int, input_dim: int, **kwargs):
1718
super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
18-
self.models = [SklearnLogisticRegression(solver="liblinear") for _ in range(5)]
19+
self.models = [SklearnLogisticRegression(solver="liblinear") for _ in range(out_dim)]
1920

2021
def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
2122
print(
@@ -56,7 +57,27 @@ def fit_sklearn(self, X, y):
5657
Fit the underlying sklearn model. X and y should be numpy arrays.
5758
"""
5859
for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"):
59-
model.fit(X, y[:, i])
60+
import os
61+
if os.path.exists(f"LR_CHEBI100_model_{i}.pkl"):
62+
print(f"Loading model {i} from file")
63+
self.models[i] = pkl.load(open(f"LR_CHEBI100_model_{i}.pkl", "rb"))
64+
else:
65+
try:
66+
model.fit(X, y[:, i])
67+
except ValueError as e:
68+
self.models[i] = PlaceholderModel()
69+
# dump
70+
pkl.dump(model, open(f"LR_CHEBI100_model_{i}.pkl", "wb"))
6071

6172
def configure_optimizers(self, **kwargs):
6273
pass
74+
75+
76+
class PlaceholderModel:
77+
"""Acts like a trained model, but isn't. Use this if training fails and you need a placeholder."""
78+
79+
def __init__(self, default_prediction=1):
80+
self.default_prediction = default_prediction
81+
82+
def predict(self, preds):
83+
return np.ones(preds.shape[0]) * self.default_prediction

0 commit comments

Comments
 (0)