1- from typing import Any , Dict
21import pickle as pkl
2+ from typing import Any , Dict
3+
34import numpy as np
45import torch
56import tqdm
@@ -16,55 +17,44 @@ class LogisticRegression(ChebaiBaseNet):
1617
1718 def __init__ (self , out_dim : int , input_dim : int , ** kwargs ):
1819 super ().__init__ (out_dim = out_dim , input_dim = input_dim , ** kwargs )
19- self .models = [SklearnLogisticRegression (solver = "liblinear" ) for _ in range (out_dim )]
20+ self .models = [
21+ SklearnLogisticRegression (solver = "liblinear" ) for _ in range (300 )
22+ ]
2023
2124 def forward (self , x : Dict [str , Any ], ** kwargs ) -> torch .Tensor :
2225 print (
2326 f"forward called with x[features].shape { x ['features' ].shape } , self.training { self .training } "
2427 )
2528 if self .training :
2629 self .fit_sklearn (x ["features" ], x ["labels" ])
27- try :
28- preds = [
29- torch .from_numpy (model .predict (x ["features" ]))
30- .to (
31- x ["features" ].device
32- if isinstance (x ["features" ], torch .Tensor )
33- else "cpu"
34- )
35- .float ()
36- for model in self .models
37- ]
38- except NotFittedError :
39- # Not fitted yet, return zeros
40- print (
41- f"returning default 0s with shape { (x ['features' ].shape [0 ], self .out_dim )} "
42- )
43- return torch .zeros (
44- (x ["features" ].shape [0 ], self .out_dim ),
45- device = (
46- x ["features" ].device
47- if isinstance (x ["features" ], torch .Tensor )
48- else "cpu"
49- ),
50- )
30+ preds = []
31+ for model in self .models :
32+ try :
33+ p = torch .from_numpy (model .predict (x ["features" ])).float ()
34+ p = p .to (x ["features" ].device )
35+ preds .append (p )
36+ except NotFittedError :
37+ preds .append (torch .zeros ((x ["features" ].shape [0 ], 1 ), device = (x ["features" ].device )))
38+ except AttributeError :
39+ preds .append (torch .zeros ((x ["features" ].shape [0 ], 1 ), device = (x ["features" ].device )))
5140 preds = torch .stack (preds , dim = 1 )
5241 print (f"preds shape { preds .shape } " )
53- return preds
42+ return preds . squeeze ( - 1 )
5443
5544 def fit_sklearn (self , X , y ):
5645 """
5746 Fit the underlying sklearn model. X and y should be numpy arrays.
5847 """
5948 for i , model in tqdm .tqdm (enumerate (self .models ), desc = "Fitting models" ):
6049 import os
50+
6151 if os .path .exists (f"LR_CHEBI100_model_{ i } .pkl" ):
6252 print (f"Loading model { i } from file" )
6353 self .models [i ] = pkl .load (open (f"LR_CHEBI100_model_{ i } .pkl" , "rb" ))
6454 else :
6555 try :
6656 model .fit (X , y [:, i ])
67- except ValueError as e :
57+ except ValueError :
6858 self .models [i ] = PlaceholderModel ()
6959 # dump
7060 pkl .dump (model , open (f"LR_CHEBI100_model_{ i } .pkl" , "wb" ))
@@ -80,4 +70,4 @@ def __init__(self, default_prediction=1):
8070 self .default_prediction = default_prediction
8171
8272 def predict (self , preds ):
83- return np .ones (preds .shape [0 ]) * self .default_prediction
73+ return np .ones (preds .shape [0 ]) * self .default_prediction
0 commit comments