Skip to content

Commit 3b233d6

Browse files
committed
streamline classic ml
1 parent abc9b53 commit 3b233d6

File tree

2 files changed

+20
-32
lines changed

2 files changed

+20
-32
lines changed

chebai/callbacks/epoch_metrics.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None:
6262
labels (torch.Tensor): Ground truth labels.
6363
"""
6464
tps = torch.sum(
65-
torch.logical_and(
66-
preds > self.threshold, labels.to(torch.bool)
67-
),
65+
torch.logical_and(preds > self.threshold, labels.to(torch.bool)),
6866
dim=0,
6967
)
7068
self.true_positives += tps

chebai/models/classic_ml.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Any, Dict
21
import pickle as pkl
2+
from typing import Any, Dict
3+
34
import numpy as np
45
import torch
56
import tqdm
@@ -16,55 +17,44 @@ class LogisticRegression(ChebaiBaseNet):
1617

1718
def __init__(self, out_dim: int, input_dim: int, **kwargs):
1819
super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
19-
self.models = [SklearnLogisticRegression(solver="liblinear") for _ in range(out_dim)]
20+
self.models = [
21+
SklearnLogisticRegression(solver="liblinear") for _ in range(300)
22+
]
2023

2124
def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
2225
print(
2326
f"forward called with x[features].shape {x['features'].shape}, self.training {self.training}"
2427
)
2528
if self.training:
2629
self.fit_sklearn(x["features"], x["labels"])
27-
try:
28-
preds = [
29-
torch.from_numpy(model.predict(x["features"]))
30-
.to(
31-
x["features"].device
32-
if isinstance(x["features"], torch.Tensor)
33-
else "cpu"
34-
)
35-
.float()
36-
for model in self.models
37-
]
38-
except NotFittedError:
39-
# Not fitted yet, return zeros
40-
print(
41-
f"returning default 0s with shape {(x['features'].shape[0], self.out_dim)}"
42-
)
43-
return torch.zeros(
44-
(x["features"].shape[0], self.out_dim),
45-
device=(
46-
x["features"].device
47-
if isinstance(x["features"], torch.Tensor)
48-
else "cpu"
49-
),
50-
)
30+
preds = []
31+
for model in self.models:
32+
try:
33+
p = torch.from_numpy(model.predict(x["features"])).float()
34+
p = p.to(x["features"].device)
35+
preds.append(p)
36+
except NotFittedError:
37+
preds.append(torch.zeros((x["features"].shape[0], 1), device=(x["features"].device)))
38+
except AttributeError:
39+
preds.append(torch.zeros((x["features"].shape[0], 1), device=(x["features"].device)))
5140
preds = torch.stack(preds, dim=1)
5241
print(f"preds shape {preds.shape}")
53-
return preds
42+
return preds.squeeze(-1)
5443

5544
def fit_sklearn(self, X, y):
5645
"""
5746
Fit the underlying sklearn model. X and y should be numpy arrays.
5847
"""
5948
for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"):
6049
import os
50+
6151
if os.path.exists(f"LR_CHEBI100_model_{i}.pkl"):
6252
print(f"Loading model {i} from file")
6353
self.models[i] = pkl.load(open(f"LR_CHEBI100_model_{i}.pkl", "rb"))
6454
else:
6555
try:
6656
model.fit(X, y[:, i])
67-
except ValueError as e:
57+
except ValueError:
6858
self.models[i] = PlaceholderModel()
6959
# dump
7060
pkl.dump(model, open(f"LR_CHEBI100_model_{i}.pkl", "wb"))
@@ -80,4 +70,4 @@ def __init__(self, default_prediction=1):
8070
self.default_prediction = default_prediction
8171

8272
def predict(self, preds):
83-
return np.ones(preds.shape[0]) * self.default_prediction
73+
return np.ones(preds.shape[0]) * self.default_prediction

0 commit comments

Comments
 (0)