@@ -460,13 +460,13 @@ def evaluate_criticality_measure(self, params):
460460 d , gnew , crvmin = ctrsbox_sfista (self .model .xopt (abs_coordinates = True ), gopt , np .zeros (H .shape ), self .model .projections , 1 ,
461461 self .h , self .lh , self .prox_uh , argsh = self .argsh , argsprox = self .argsprox , func_tol = func_tol ,
462462 max_iters = params ("func_tol.max_iters" ), d_max_iters = params ("dykstra.max_iters" ), d_tol = params ("dykstra.d_tol" ),
463- scaling_changes = self .scaling_changes )
463+ scaling_changes = self .scaling_changes , sfista_iters_scale = params ( "sfista.max_iters_scaling" ) )
464464 else :
465465 proj = lambda x : pbox (x , self .model .sl , self .model .su )
466466 d , gnew , crvmin = ctrsbox_sfista (self .model .xopt (abs_coordinates = True ), gopt , np .zeros (H .shape ), [proj ], 1 ,
467467 self .h , self .lh , self .prox_uh , argsh = self .argsh , argsprox = self .argsprox , func_tol = func_tol ,
468468 max_iters = params ("func_tol.max_iters" ), d_max_iters = params ("dykstra.max_iters" ), d_tol = params ("dykstra.d_tol" ),
469- scaling_changes = self .scaling_changes )
469+ scaling_changes = self .scaling_changes , sfista_iters_scale = params ( "sfista.max_iters_scaling" ) )
470470
471471 # Calculate criticality measure
472472 criticality_measure = self .h (remove_scaling (self .model .xopt (abs_coordinates = True ), self .scaling_changes ), * self .argsh ) - model_value (gopt , np .zeros (H .shape ), d , self .model .xopt (abs_coordinates = True ), self .h , self .argsh , self .scaling_changes )
@@ -505,15 +505,15 @@ def trust_region_step(self, params, criticality_measure=1e-2):
505505 d , gnew , crvmin = ctrsbox_sfista (self .model .xopt (abs_coordinates = True ), gopt , H , self .model .projections , self .delta ,
506506 self .h , self .lh , self .prox_uh , argsh = self .argsh , argsprox = self .argsprox , func_tol = func_tol ,
507507 max_iters = params ("func_tol.max_iters" ), d_max_iters = params ("dykstra.max_iters" ), d_tol = params ("dykstra.d_tol" ),
508- scaling_changes = self .scaling_changes )
508+ scaling_changes = self .scaling_changes , sfista_iters_scale = params ( "sfista.max_iters_scaling" ) )
509509 else :
510510 # NOTE: alternative way if using trsbox
511511 # d, gnew, crvmin = trsbox(self.model.xopt(), gopt, H, self.model.sl, self.model.su, self.delta)
512512 proj = lambda x : pbox (x , self .model .sl , self .model .su )
513513 d , gnew , crvmin = ctrsbox_sfista (self .model .xopt (abs_coordinates = True ), gopt , H , [proj ], self .delta ,
514514 self .h , self .lh , self .prox_uh , argsh = self .argsh , argsprox = self .argsprox , func_tol = func_tol ,
515515 max_iters = params ("func_tol.max_iters" ), d_max_iters = params ("dykstra.max_iters" ), d_tol = params ("dykstra.d_tol" ),
516- scaling_changes = self .scaling_changes )
516+ scaling_changes = self .scaling_changes , sfista_iters_scale = params ( "sfista.max_iters_scaling" ) )
517517
518518 # NOTE: check sufficient decrease. If increase in the model, set zero step
519519 pred_reduction = self .h (remove_scaling (self .model .xopt (abs_coordinates = True ), self .scaling_changes ), * self .argsh ) - model_value (gopt , H , d , self .model .xopt (abs_coordinates = True ), self .h , self .argsh , self .scaling_changes )
0 commit comments