11import pickle as pkl
2- from typing import Any , Dict
2+ from typing import Any , Dict , List , Optional
33
44import numpy as np
55import torch
88from sklearn .linear_model import LogisticRegression as SklearnLogisticRegression
99
1010from chebai .models .base import ChebaiBaseNet
11+ import os
1112
# Relative directory ("models/LR") under which the per-class pickled models
# are looked up and stored as LR_model_<i>.pkl.
LR_MODEL_PATH = os.path.join("models", "LR")
1214
1315class LogisticRegression (ChebaiBaseNet ):
1416 """
1517 Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface.
1618 """
1719
def __init__(
    self,
    out_dim: int,
    input_dim: int,
    only_predict_classes: Optional[List] = None,
    n_classes: int = 1528,
    **kwargs,
):
    """
    Create one unfitted scikit-learn logistic-regression model per class.

    Args:
        out_dim: Passed through unchanged to the base-class constructor.
        input_dim: Passed through unchanged to the base-class constructor.
        only_predict_classes: Optional list of class indices (in the dataset
            used for training) for which a model should be trained. The value
            is stored as given; a falsy value (``None`` or an empty list)
            means no class is filtered out when fitting.
        n_classes: Number of per-class models to allocate. Defaults to 1528.
        **kwargs: Any further keyword arguments, forwarded to the base class.
    """
    super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
    # One independent classifier per class index; position i in this list
    # is the model fitted against label column i.
    self.models = [
        SklearnLogisticRegression(solver="liblinear") for _ in range(n_classes)
    ]
    # Indices of classes (in the dataset used for training) where a model
    # should be trained.
    self.only_predict_classes = only_predict_classes
2327
2428 def forward (self , x : Dict [str , Any ], ** kwargs ) -> torch .Tensor :
2529 print (
@@ -36,13 +40,13 @@ def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
3640 except NotFittedError :
3741 preds .append (
3842 torch .zeros (
39- (x ["features" ].shape [0 ], 1 ), device = (x ["features" ].device )
43+ (x ["features" ].shape [0 ]), device = (x ["features" ].device )
4044 )
4145 )
4246 except AttributeError :
4347 preds .append (
4448 torch .zeros (
45- (x ["features" ].shape [0 ], 1 ), device = (x ["features" ].device )
49+ (x ["features" ].shape [0 ]), device = (x ["features" ].device )
4650 )
4751 )
4852 preds = torch .stack (preds , dim = 1 )
@@ -56,16 +60,18 @@ def fit_sklearn(self, X, y):
5660 for i , model in tqdm .tqdm (enumerate (self .models ), desc = "Fitting models" ):
5761 import os
5862
59- if os .path .exists (f"LR_CHEBI100_model_ { i } .pkl" ):
63+ if os .path .exists (os . path . join ( LR_MODEL_PATH , f"LR_model_ { i } .pkl") ):
6064 print (f"Loading model { i } from file" )
61- self .models [i ] = pkl .load (open (f"LR_CHEBI100_model_ { i } .pkl" , "rb" ))
65+ self .models [i ] = pkl .load (open (os . path . join ( LR_MODEL_PATH , f"LR_model_ { i } .pkl") , "rb" ))
6266 else :
67+ if self .only_predict_classes and i not in self .only_predict_classes : # only try these classes
68+ continue
6369 try :
6470 model .fit (X , y [:, i ])
6571 except ValueError :
6672 self .models [i ] = PlaceholderModel ()
6773 # dump
68- pkl .dump (model , open (f"LR_CHEBI100_model_ { i } .pkl" , "wb" ))
74+ pkl .dump (model , open (os . path . join ( LR_MODEL_PATH , f"LR_model_ { i } .pkl") , "wb" ))
6975
def configure_optimizers(self, **kwargs):
    """
    Intentionally do nothing and return ``None``.

    No optimizer is created here; any keyword arguments are accepted and
    ignored. (Presumably this exists to satisfy the base-class interface,
    since the wrapped scikit-learn models are fitted directly - confirm
    against ``ChebaiBaseNet``.)
    """
    return None
0 commit comments