1+ import os
12import pickle as pkl
23from typing import Any , Dict , List , Optional
34
89from sklearn .linear_model import LogisticRegression as SklearnLogisticRegression
910
1011from chebai .models .base import ChebaiBaseNet
11- import os
1212
1313LR_MODEL_PATH = os .path .join ("models" , "LR" )
1414
15+
1516class LogisticRegression (ChebaiBaseNet ):
1617 """
1718 Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface.
1819 """
1920
20- def __init__ (self , out_dim : int , input_dim : int , only_predict_classes : Optional [List ] = None , n_classes = 1528 , ** kwargs ):
21+ def __init__ (
22+ self ,
23+ out_dim : int ,
24+ input_dim : int ,
25+ only_predict_classes : Optional [List ] = None ,
26+ n_classes = 1528 ,
27+ ** kwargs ,
28+ ):
2129 super ().__init__ (out_dim = out_dim , input_dim = input_dim , ** kwargs )
2230 self .models = [
2331 SklearnLogisticRegression (solver = "liblinear" ) for _ in range (n_classes )
@@ -39,15 +47,11 @@ def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
3947 preds .append (p )
4048 except NotFittedError :
4149 preds .append (
42- torch .zeros (
43- (x ["features" ].shape [0 ]), device = (x ["features" ].device )
44- )
50+ torch .zeros ((x ["features" ].shape [0 ]), device = (x ["features" ].device ))
4551 )
4652 except AttributeError :
4753 preds .append (
48- torch .zeros (
49- (x ["features" ].shape [0 ]), device = (x ["features" ].device )
50- )
54+ torch .zeros ((x ["features" ].shape [0 ]), device = (x ["features" ].device ))
5155 )
5256 preds = torch .stack (preds , dim = 1 )
5357 print (f"preds shape { preds .shape } " )
@@ -62,16 +66,22 @@ def fit_sklearn(self, X, y):
6266
6367 if os .path .exists (os .path .join (LR_MODEL_PATH , f"LR_model_{ i } .pkl" )):
6468 print (f"Loading model { i } from file" )
65- self .models [i ] = pkl .load (open (os .path .join (LR_MODEL_PATH , f"LR_model_{ i } .pkl" ), "rb" ))
69+ self .models [i ] = pkl .load (
70+ open (os .path .join (LR_MODEL_PATH , f"LR_model_{ i } .pkl" ), "rb" )
71+ )
6672 else :
67- if self .only_predict_classes and i not in self .only_predict_classes : # only try these classes
73+ if (
74+ self .only_predict_classes and i not in self .only_predict_classes
75+ ): # only try these classes
6876 continue
6977 try :
7078 model .fit (X , y [:, i ])
7179 except ValueError :
7280 self .models [i ] = PlaceholderModel ()
7381 # dump
74- pkl .dump (model , open (os .path .join (LR_MODEL_PATH , f"LR_model_{ i } .pkl" ), "wb" ))
82+ pkl .dump (
83+ model , open (os .path .join (LR_MODEL_PATH , f"LR_model_{ i } .pkl" ), "wb" )
84+ )
7585
7686 def configure_optimizers (self , ** kwargs ):
7787 pass
0 commit comments