5353import cebra .data
5454import cebra .io
5555import cebra .models
56+ import cebra .solver .single_session as cebra_solver_single
5657from cebra .solver import register
5758from cebra .solver .base import Solver
5859from cebra .solver .schedulers import Scheduler
@@ -187,7 +188,7 @@ def _process_info(self, info):
187188
188189
189190@dataclasses .dataclass
190- class MultiobjectiveSolverBase (Solver ):
191+ class MultiobjectiveSolverBase (cebra_solver_single . SingleSessionSolver ):
191192
192193 feature_ranges : List [slice ] = None
193194 renormalize : bool = None
@@ -209,6 +210,13 @@ def __post_init__(self):
209210 renormalize = self .renormalize ,
210211 )
211212
213+ def parameters (self , session_id : Optional [int ] = None ):
214+ """Iterate over all parameters."""
215+ super ().parameters (session_id = session_id )
216+
217+ for parameter in self .regularizer .parameters ():
218+ yield parameter
219+
212220 def fit (self ,
213221 loader : cebra .data .Loader ,
214222 valid_loader : cebra .data .Loader = None ,
@@ -241,6 +249,7 @@ def _run_validation():
241249 save_hook (solver = self , step = num_steps )
242250 return stats_val
243251
252+ self ._set_fitted_params (loader )
244253 self .to (loader .device )
245254
246255 iterator = self ._get_loader (loader ,
@@ -393,11 +402,14 @@ def validation(
393402 logger = None ,
394403 weights_loss : Optional [List [float ]] = None ,
395404 ):
405+ loader .dataset .configure_for (self .model )
406+ iterator = self ._get_loader (loader )
407+
396408 self .model .eval ()
397409 total_loss = Meter ()
398410
399411 losses_dict = {}
400- for _ , batch in enumerate ( loader ) :
412+ for _ , batch in iterator :
401413 predictions = self ._inference (batch )
402414 losses = self .criterion (predictions )
403415
@@ -445,7 +457,7 @@ def validation(
445457 return stats_val
446458
447459 @torch .no_grad ()
448- def transform (self , inputs : torch .Tensor ) -> torch .Tensor :
460+ def transform_deprecated (self , inputs : torch .Tensor ) -> torch .Tensor :
449461 offset = self .model .get_offset ()
450462 self .model .eval ()
451463 X = inputs .cpu ().numpy ()
0 commit comments