Skip to content

Commit c1da092

Browse files
committed
reformat
1 parent 905ffc2 commit c1da092

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

chebai/models/classic_ml.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import pickle as pkl
23
from typing import Any, Dict, List, Optional
34

@@ -8,16 +9,23 @@
89
from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
910

1011
from chebai.models.base import ChebaiBaseNet
11-
import os
1212

1313
LR_MODEL_PATH = os.path.join("models", "LR")
1414

15+
1516
class LogisticRegression(ChebaiBaseNet):
1617
"""
1718
Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface.
1819
"""
1920

20-
def __init__(self, out_dim: int, input_dim: int, only_predict_classes: Optional[List] = None, n_classes=1528, **kwargs):
21+
def __init__(
22+
self,
23+
out_dim: int,
24+
input_dim: int,
25+
only_predict_classes: Optional[List] = None,
26+
n_classes=1528,
27+
**kwargs,
28+
):
2129
super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
2230
self.models = [
2331
SklearnLogisticRegression(solver="liblinear") for _ in range(n_classes)
@@ -39,15 +47,11 @@ def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
3947
preds.append(p)
4048
except NotFittedError:
4149
preds.append(
42-
torch.zeros(
43-
(x["features"].shape[0]), device=(x["features"].device)
44-
)
50+
torch.zeros((x["features"].shape[0]), device=(x["features"].device))
4551
)
4652
except AttributeError:
4753
preds.append(
48-
torch.zeros(
49-
(x["features"].shape[0]), device=(x["features"].device)
50-
)
54+
torch.zeros((x["features"].shape[0]), device=(x["features"].device))
5155
)
5256
preds = torch.stack(preds, dim=1)
5357
print(f"preds shape {preds.shape}")
@@ -62,16 +66,22 @@ def fit_sklearn(self, X, y):
6266

6367
if os.path.exists(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl")):
6468
print(f"Loading model {i} from file")
65-
self.models[i] = pkl.load(open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "rb"))
69+
self.models[i] = pkl.load(
70+
open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "rb")
71+
)
6672
else:
67-
if self.only_predict_classes and i not in self.only_predict_classes: # only try these classes
73+
if (
74+
self.only_predict_classes and i not in self.only_predict_classes
75+
): # only try these classes
6876
continue
6977
try:
7078
model.fit(X, y[:, i])
7179
except ValueError:
7280
self.models[i] = PlaceholderModel()
7381
# dump
74-
pkl.dump(model, open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "wb"))
82+
pkl.dump(
83+
model, open(os.path.join(LR_MODEL_PATH, f"LR_model_{i}.pkl"), "wb")
84+
)
7585

7686
def configure_optimizers(self, **kwargs):
7787
pass

0 commit comments

Comments
 (0)