Skip to content

Commit 905ffc2

Browse files
committed
more options for LR
1 parent 03635ff commit 905ffc2

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

chebai/models/classic_ml.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pickle as pkl
2-
from typing import Any, Dict
2+
from typing import Any, Dict, List, Optional
33

44
import numpy as np
55
import torch
@@ -8,18 +8,22 @@
88
from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
99

1010
from chebai.models.base import ChebaiBaseNet
11+
import os
1112

13+
LR_MODEL_PATH = os.path.join("models", "LR")
1214

1315
class LogisticRegression(ChebaiBaseNet):
1416
"""
1517
Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface.
1618
"""
1719

18-
def __init__(self, out_dim: int, input_dim: int, **kwargs):
20+
def __init__(self, out_dim: int, input_dim: int, only_predict_classes: Optional[List] = None, n_classes=1528, **kwargs):
1921
super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
2022
self.models = [
21-
SklearnLogisticRegression(solver="liblinear") for _ in range(300)
23+
SklearnLogisticRegression(solver="liblinear") for _ in range(n_classes)
2224
]
25+
# indices of classes (in the dataset used for training) where a model should be trained
26+
self.only_predict_classes = only_predict_classes
2327

2428
def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
2529
print(
@@ -36,13 +40,13 @@ def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
3640
except NotFittedError:
3741
preds.append(
3842
torch.zeros(
39-
(x["features"].shape[0], 1), device=(x["features"].device)
43+
(x["features"].shape[0]), device=(x["features"].device)
4044
)
4145
)
4246
except AttributeError:
4347
preds.append(
4448
torch.zeros(
45-
(x["features"].shape[0], 1), device=(x["features"].device)
49+
(x["features"].shape[0]), device=(x["features"].device)
4650
)
4751
)
4852
preds = torch.stack(preds, dim=1)
@@ -56,16 +60,18 @@ def fit_sklearn(self, X, y):
5660
for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"):
5761
import os
5862

59-
if os.path.exists(f"LR_CHEBI100_model_{i}.pkl"):
63+
if os.path.exists(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl")):
6064
print(f"Loading model {i} from file")
61-
self.models[i] = pkl.load(open(f"LR_CHEBI100_model_{i}.pkl", "rb"))
65+
self.models[i] = pkl.load(open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "rb"))
6266
else:
67+
if self.only_predict_classes and i not in self.only_predict_classes: # only try these classes
68+
continue
6369
try:
6470
model.fit(X, y[:, i])
6571
except ValueError:
6672
self.models[i] = PlaceholderModel()
6773
# dump
68-
pkl.dump(model, open(f"LR_CHEBI100_model_{i}.pkl", "wb"))
74+
pkl.dump(model, open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "wb"))
6975

7076
def configure_optimizers(self, **kwargs):
7177
pass

0 commit comments

Comments
 (0)