11from typing import Any , Dict
2-
2+ import pickle as pkl
3+ import numpy as np
34import torch
45import tqdm
56from sklearn .exceptions import NotFittedError
@@ -15,7 +16,7 @@ class LogisticRegression(ChebaiBaseNet):
1516
1617 def __init__ (self , out_dim : int , input_dim : int , ** kwargs ):
1718 super ().__init__ (out_dim = out_dim , input_dim = input_dim , ** kwargs )
18- self .models = [SklearnLogisticRegression (solver = "liblinear" ) for _ in range (5 )]
19+ self .models = [SklearnLogisticRegression (solver = "liblinear" ) for _ in range (out_dim )]
1920
2021 def forward (self , x : Dict [str , Any ], ** kwargs ) -> torch .Tensor :
2122 print (
@@ -56,7 +57,27 @@ def fit_sklearn(self, X, y):
5657 Fit the underlying sklearn model. X and y should be numpy arrays.
5758 """
5859 for i , model in tqdm .tqdm (enumerate (self .models ), desc = "Fitting models" ):
59- model .fit (X , y [:, i ])
60+ import os
61+ if os .path .exists (f"LR_CHEBI100_model_{ i } .pkl" ):
62+ print (f"Loading model { i } from file" )
63+ self .models [i ] = pkl .load (open (f"LR_CHEBI100_model_{ i } .pkl" , "rb" ))
64+ else :
65+ try :
66+ model .fit (X , y [:, i ])
67+ except ValueError as e :
68+ self .models [i ] = PlaceholderModel ()
69+ # dump
70+ pkl .dump (model , open (f"LR_CHEBI100_model_{ i } .pkl" , "wb" ))
6071
6172 def configure_optimizers (self , ** kwargs ):
6273 pass
74+
75+
76+ class PlaceholderModel :
77+ """Acts like a trained model, but isn't. Use this if training fails and you need a placeholder."""
78+
79+ def __init__ (self , default_prediction = 1 ):
80+ self .default_prediction = default_prediction
81+
82+ def predict (self , preds ):
83+ return np .ones (preds .shape [0 ]) * self .default_prediction
0 commit comments