1+ import copy
12from pathlib import Path
2- from typing import List , Optional , Dict , Any , Union
3+ from typing import List , Optional , Dict
34
45import numpy as np
56import torch
67
7- from pytabkit .models import utils
88from pytabkit .models .alg_interfaces .alg_interfaces import SingleSplitAlgInterface , AlgInterface
99from pytabkit .models .alg_interfaces .base import SplitIdxs , InterfaceResources , RequiredResources
1010from pytabkit .models .data .data import DictDataset , TaskType
1111from pytabkit .models .torch_utils import cat_if_necessary
1212from pytabkit .models .training .logging import Logger
1313from pytabkit .models .training .metrics import Metrics
14+ from pytabkit .models .utils import ObjectLoadingContext
1415
1516
1617class WeightedPrediction :
@@ -28,25 +29,6 @@ def predict_for_weights(self, weights: np.ndarray):
2829 return weighted_sum
2930
3031
31- class ObjectLoadingContext :
32- def __init__ (self , obj : Any , filename : Optional [Union [str , Path ]] = None ):
33- self .obj = obj
34- self .filename = filename
35- self .saved = False
36-
37- def __enter__ (self ) -> Any :
38- # use pickle since it works better with torch than dill
39- if self .saved :
40- self .obj = utils .deserialize (self .filename , use_pickle = True )
41- return self .obj
42-
43- def __exit__ (self , type , value , traceback ) -> None :
44- if self .filename is not None :
45- utils .serialize (self .filename , self .obj , use_pickle = True )
46- self .saved = True
47- del self .obj
48-
49-
5032class CaruanaEnsembleAlgInterface (SingleSplitAlgInterface ):
5133 """
5234 Following a simple variant of Caruana et al. (2004), "Ensemble selection from libraries of models"
@@ -65,10 +47,15 @@ def get_refit_interface(self, n_refit: int, fit_params: Optional[List[Dict]] = N
6547
6648 def fit (self , ds : DictDataset , idxs_list : List [SplitIdxs ], interface_resources : InterfaceResources ,
6749 logger : Logger , tmp_folders : List [Optional [Path ]], name : str ) -> None :
50+ assert len (idxs_list ) == 1
51+
52+ # if tmp_folders is specified, then models will be saved there instead of holding all of them in memory
6853 tmp_folder = tmp_folders [0 ]
6954 self .alg_contexts_ = [ObjectLoadingContext (ai , None if tmp_folder is None else tmp_folder / f'model_{ i } ' ) for
7055 i , ai in enumerate (self .alg_interfaces )]
71- self .alg_interfaces = None # allow not holding all of them later, to free GPU memory
56+ # store copies here, but the ones that will actually be trained are in alg_contexts_
57+ # this means that models should not be held in RAM all the time
58+ self .alg_interfaces = copy .deepcopy (self .alg_interfaces )
7259
7360 sub_fit_params = []
7461
@@ -94,7 +81,7 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
9481 if val_metric_name is None :
9582 val_metric_name = Metrics .default_val_metric_name (task_type = self .task_type )
9683
97- n_caruana_steps = self .config .get ('n_caruana_steps' , 40 ) # default value is taken from TaskRepo paper (IIRC)
84+ n_caruana_steps = self .config .get ('n_caruana_steps' , 40 ) # default value is taken from TabRepo paper (IIRC)
9885
9986 y_preds_oob_list = []
10087 for alg_idx , alg_ctx in enumerate (self .alg_contexts_ ):
@@ -114,6 +101,8 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
114101
115102 wp = WeightedPrediction (y_preds_oob_list , self .task_type )
116103
104+ allow_negative_weights = self .config .get ('allow_negative_weights' , False )
105+
117106 for step_idx in range (n_caruana_steps ):
118107 best_step_weights = None
119108 best_step_loss = np .inf
@@ -129,6 +118,21 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
129118
130119 weights [weight_idx ] -= 1
131120
121+ # negative weights option
122+ # check weights >= 2 allowing for floating-point errors
123+ if allow_negative_weights and np .sum (weights ) >= 1.5 :
124+ weights [weight_idx ] -= 1
125+
126+ y_pred_oob = wp .predict_for_weights (weights )
127+ loss = Metrics .apply (y_pred_oob , y_oob , val_metric_name ).item ()
128+ # print(f'{weights=}, {loss=}')
129+ if loss < best_step_loss :
130+ best_step_loss = loss
131+ best_step_weights = np .copy (weights )
132+
133+ weights [weight_idx ] += 1
134+
135+
132136 if best_step_loss < best_loss :
133137 best_loss = best_step_loss
134138 best_weights = np .copy (best_step_weights )
@@ -179,13 +183,22 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
179183 logger : Logger , tmp_folders : List [Optional [Path ]], name : str ) -> None :
180184 assert len (idxs_list ) == 1
181185
186+ # if tmp_folders is specified, then models will be saved there instead of holding all of them in memory
187+ tmp_folder = tmp_folders [0 ]
188+ self .alg_contexts_ = [ObjectLoadingContext (ai , None if tmp_folder is None else tmp_folder / f'model_{ i } ' ) for
189+ i , ai in enumerate (self .alg_interfaces )]
190+ # store copies here, but the ones that will actually be trained are in alg_contexts_
191+ # this means that models should not be held in RAM all the time
192+ self .alg_interfaces = copy .deepcopy (self .alg_interfaces )
193+
182194 if self .fit_params is not None :
183195 # this is the refit stage, there is no validation data set to determine the best model on,
184196 # instead the best model index is already in fit_params
185197 best_alg_idx = self .fit_params [0 ]['best_alg_idx' ]
186198 sub_tmp_folders = [tmp_folder / str (best_alg_idx ) if tmp_folder is not None else None for tmp_folder in
187199 tmp_folders ]
188- self .alg_interfaces [best_alg_idx ].fit (ds , idxs_list , interface_resources , logger , sub_tmp_folders ,
200+ with self .alg_contexts_ [best_alg_idx ] as alg_interface :
201+ alg_interface .fit (ds , idxs_list , interface_resources , logger , sub_tmp_folders ,
189202 name + f'sub-alg-{ best_alg_idx } ' )
190203
191204 return
@@ -206,28 +219,32 @@ def fit(self, ds: DictDataset, idxs_list: List[SplitIdxs], interface_resources:
206219
207220 best_alg_idx = 0
208221 best_alg_loss = np .inf
222+ best_sub_fit_params = None
209223
210- for alg_idx , alg_interface in enumerate (self .alg_interfaces ):
211- sub_tmp_folders = [tmp_folder / str (alg_idx ) if tmp_folder is not None else None for tmp_folder in
212- tmp_folders ]
213- alg_interface .fit (ds , idxs_list , interface_resources , logger , sub_tmp_folders , name + f'sub-alg-{ alg_idx } ' )
214- y_preds = alg_interface .predict (ds )
215- # get out-of-bag predictions
216- y_pred_oob = cat_if_necessary ([y_preds [j , idxs_list [0 ].val_idxs [j ]]
217- for j in range (idxs_list [0 ].val_idxs .shape [0 ])], dim = 0 )
218- loss = Metrics .apply (y_pred_oob , y_oob , val_metric_name ).item ()
219- if loss < best_alg_loss :
220- best_alg_loss = loss
221- best_alg_idx = alg_idx
224+ for alg_idx , alg_ctx in enumerate (self .alg_contexts_ ):
225+ with alg_ctx as alg_interface :
226+ sub_tmp_folders = [tmp_folder / str (alg_idx ) if tmp_folder is not None else None for tmp_folder in
227+ tmp_folders ]
228+ alg_interface .fit (ds , idxs_list , interface_resources , logger , sub_tmp_folders , name + f'sub-alg-{ alg_idx } ' )
229+ y_preds = alg_interface .predict (ds )
230+ # get out-of-bag predictions
231+ y_pred_oob = cat_if_necessary ([y_preds [j , idxs_list [0 ].val_idxs [j ]]
232+ for j in range (idxs_list [0 ].val_idxs .shape [0 ])], dim = 0 )
233+ loss = Metrics .apply (y_pred_oob , y_oob , val_metric_name ).item ()
234+ if loss < best_alg_loss :
235+ best_alg_loss = loss
236+ best_alg_idx = alg_idx
237+ best_sub_fit_params = alg_interface .get_fit_params ()[0 ]
222238
223239 self .fit_params = [dict (best_alg_idx = best_alg_idx ,
224- sub_fit_params = self . alg_interfaces [ best_alg_idx ]. get_fit_params ()[ 0 ] )]
240+ sub_fit_params = best_sub_fit_params )]
225241 logger .log (2 , f'Best algorithm has index { best_alg_idx } ' )
226242 logger .log (2 , f'Algorithm selection fit parameters: { self .fit_params [0 ]} ' )
227243
228244 def predict (self , ds : DictDataset ) -> torch .Tensor :
229245 alg_idx = self .fit_params [0 ]['best_alg_idx' ]
230- return self .alg_interfaces [alg_idx ].predict (ds )
246+ with self .alg_contexts_ [alg_idx ] as alg_interface :
247+ return alg_interface .predict (ds )
231248
232249 def get_required_resources (self , ds : DictDataset , n_cv : int , n_refit : int , n_splits : int ,
233250 split_seeds : List [int ], n_train : int ) -> RequiredResources :
0 commit comments