@@ -193,3 +193,130 @@ def _should_prune_trial(self, trial: Trial) -> bool:
193193 return curr_pfl < worst_pfl
194194
195195 return False
196+
197+ # ┌──────────────────────────────────────────────────────────┐
198+ # Improved Predicted Final Loss (PFL) Pruner V2
199+ # └──────────────────────────────────────────────────────────┘
200+ class PFLPrunerV2 (BasePruner ):
201+ """
202+ Improved Predicted Final Loss (PFL) based pruner.
203+
204+ This pruner models learning curves using a power-law fit (y = a*x^b)
205+ and prunes a trial if its predicted final loss is worse than the
206+ actual final loss of the k-th best completed trial.
207+ """
208+
209+ def __init__ (
210+ self ,
211+ n_startup_trials : int = 10 ,
212+ n_warmup_epochs : int = 10 ,
213+ top_k : int = 10 ,
214+ target_epoch : int = 50 ,
215+ min_points_for_prediction : int = 3 ,
216+ ):
217+ super ().__init__ ()
218+ self .n_startup_trials = n_startup_trials
219+ self .n_warmup_epochs = n_warmup_epochs
220+ self .top_k = top_k
221+ self .target_epoch = target_epoch
222+ self .min_points_for_prediction = min_points_for_prediction
223+
224+ self .top_k_final_losses : List [float ] = []
225+ self .completed_trials_count = 0
226+
227+ def complete_trial (self , trial_id : int ) -> None :
228+ if trial_id in self ._trials :
229+ trial = self ._trials [trial_id ]
230+ final_loss = self ._get_final_loss (trial )
231+
232+ if np .isfinite (final_loss ):
233+ self .completed_trials_count += 1
234+ if len (self .top_k_final_losses ) < self .top_k :
235+ bisect .insort (self .top_k_final_losses , final_loss )
236+ elif final_loss < self .top_k_final_losses [- 1 ]:
237+ self .top_k_final_losses .pop ()
238+ bisect .insort (self .top_k_final_losses , final_loss )
239+
240+ super ().complete_trial (trial_id )
241+ del self ._trials [trial_id ]
242+
243+ def _get_final_loss (self , trial : Trial ) -> float :
244+ """Get the average final loss across all seeds for a completed trial."""
245+ if not trial .seed_values :
246+ return float ("inf" )
247+
248+ total_loss = 0.0
249+ n_seeds = len (trial .seed_values )
250+ for loss_vec in trial .seed_values .values ():
251+ if not loss_vec : return float ("inf" )
252+ total_loss += loss_vec [- 1 ]
253+
254+ return total_loss / n_seeds if n_seeds > 0 else float ("inf" )
255+
256+ def _predict_final_loss_power_law (self , losses : List [float ]) -> float :
257+ """
258+ Predict final loss using power-law curve fitting (y = a*x^b).
259+ This is equivalent to a linear fit in log-log space.
260+ """
261+ n_losses = len (losses )
262+ if n_losses < self .min_points_for_prediction :
263+ return float ("inf" )
264+
265+ try :
266+ # x: epochs (1-based), y: losses
267+ epochs = np .arange (1 , n_losses + 1 )
268+ # Clip losses to avoid log(0) issues
269+ safe_losses = np .maximum (losses , 1e-10 )
270+
271+ log_epochs = np .log (epochs )
272+ log_losses = np .log (safe_losses )
273+
274+ # Linear fit in log-log space
275+ b , log_a = np .polyfit (log_epochs , log_losses , 1 )
276+
277+ # Prune if the slope (b) is positive
278+ if b > 0 :
279+ return float ("inf" )
280+
281+ # Predict final loss at target_epoch
282+ predicted_log_loss = log_a + b * np .log (self .target_epoch )
283+ predicted_loss = np .exp (predicted_log_loss )
284+
285+ # Return the minimum of predicted loss and the actual final losses
286+ return min (predicted_loss , min (losses ))
287+
288+ except (np .linalg .LinAlgError , ValueError ):
289+ # If fitting fails, return a large value to indicate pruning
290+ return float ("inf" )
291+
292+ def _should_prune_trial (self , trial : Trial ) -> bool :
293+ # Check if any seed has invalid loss
294+ for losses in trial .seed_values .values ():
295+ if not losses or not np .isfinite (losses [- 1 ]):
296+ return True
297+
298+ # Don't prune during warmup period
299+ if (
300+ self .completed_trials_count < self .n_startup_trials
301+ or trial .current_epoch <= self .n_warmup_epochs
302+ ):
303+ return False
304+
305+ # Don't prune if we have not enough top_k final losses
306+ if len (self .top_k_final_losses ) < 1 :
307+ return False
308+
309+ avg_predicted_loss = 0.0
310+ n_seeds = len (trial .seed_values )
311+ if n_seeds == 0 : return False
312+
313+ for loss_vec in trial .seed_values .values ():
314+ avg_predicted_loss += self ._predict_final_loss_power_law (loss_vec )
315+
316+ avg_predicted_loss /= n_seeds
317+
318+ # Get the worst final loss from the top k trials
319+ pruning_threshold = self .top_k_final_losses [- 1 ]
320+
321+ # Prune if the predicted final loss is worse than the threshold
322+ return avg_predicted_loss > pruning_threshold
0 commit comments