@@ -67,6 +67,8 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
6767 raise ValueError (val_error_message )
6868
6969 datafit .initialize (X , y )
70+ lipschitz = datafit .get_lipschitz (X , y )
71+
7072 all_groups = np .arange (n_groups )
7173 p_objs_out = np .zeros (self .max_iter )
7274 stop_crit = 0. # prevent ref before assign when max_iter == 0
@@ -100,7 +102,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
100102
101103 for epoch in range (self .max_epochs ):
102104 # inplace update of w and Xw
103- _bcd_epoch (X , y , w [:n_features ], Xw , datafit , penalty , ws )
105+ _bcd_epoch (X , y , w [:n_features ], Xw , lipschitz , datafit , penalty , ws )
104106
105107 # update intercept
106108 if self .fit_intercept :
@@ -140,15 +142,15 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
140142
141143
142144@njit
143- def _bcd_epoch (X , y , w , Xw , datafit , penalty , ws ):
145+ def _bcd_epoch (X , y , w , Xw , lipschitz , datafit , penalty , ws ):
144146 # perform a single BCD epoch on groups in ws
145147 grp_ptr , grp_indices = penalty .grp_ptr , penalty .grp_indices
146148
147149 for g in ws :
148150 grp_g_indices = grp_indices [grp_ptr [g ]: grp_ptr [g + 1 ]]
149151 old_w_g = w [grp_g_indices ].copy ()
150152
151- lipschitz_g = datafit . lipschitz [g ]
153+ lipschitz_g = lipschitz [g ]
152154 grad_g = datafit .gradient_g (X , y , w , Xw , g )
153155
154156 w [grp_g_indices ] = penalty .prox_1group (
0 commit comments