Skip to content

Commit 2f8ad28

Browse files
committed
PFLPruner V2
1 parent f7e9d6e commit 2f8ad28

File tree

1 file changed

+127
-0
lines changed

1 file changed

+127
-0
lines changed

pruner.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,130 @@ def _should_prune_trial(self, trial: Trial) -> bool:
193193
return curr_pfl < worst_pfl
194194

195195
return False
196+
197+
# ┌──────────────────────────────────────────────────────────┐
#   Improved Predicted Final Loss (PFL) Pruner V2
# └──────────────────────────────────────────────────────────┘
class PFLPrunerV2(BasePruner):
    """
    Improved Predicted Final Loss (PFL) based pruner.

    This pruner models learning curves using a power-law fit (y = a*x^b)
    and prunes a trial if its predicted final loss is worse than the
    actual final loss of the k-th best completed trial.
    """

    def __init__(
        self,
        n_startup_trials: int = 10,
        n_warmup_epochs: int = 10,
        top_k: int = 10,
        target_epoch: int = 50,
        min_points_for_prediction: int = 3,
    ):
        """
        Args:
            n_startup_trials: Number of trials that must complete before
                any pruning is allowed.
            n_warmup_epochs: Epochs a trial must run before it can be pruned.
            top_k: Size of the pool of best completed final losses used as
                the pruning threshold.
            target_epoch: Epoch at which the power-law fit is extrapolated.
            min_points_for_prediction: Minimum number of loss observations
                required per seed before a fit is attempted.
        """
        super().__init__()
        self.n_startup_trials = n_startup_trials
        self.n_warmup_epochs = n_warmup_epochs
        self.top_k = top_k
        self.target_epoch = target_epoch
        self.min_points_for_prediction = min_points_for_prediction

        # Ascending list of the best `top_k` final losses seen so far;
        # the last element is therefore the pruning threshold.
        self.top_k_final_losses: List[float] = []
        self.completed_trials_count = 0

    def complete_trial(self, trial_id: int) -> None:
        """Record a finished trial's final loss and release its bookkeeping.

        Maintains ``top_k_final_losses`` as a sorted top-k list via
        ``bisect.insort`` (O(k) insert, which is fine for small k).
        """
        if trial_id in self._trials:
            trial = self._trials[trial_id]
            final_loss = self._get_final_loss(trial)

            if np.isfinite(final_loss):
                self.completed_trials_count += 1
                if len(self.top_k_final_losses) < self.top_k:
                    bisect.insort(self.top_k_final_losses, final_loss)
                elif final_loss < self.top_k_final_losses[-1]:
                    # Better than the current k-th best: evict the worst.
                    self.top_k_final_losses.pop()
                    bisect.insort(self.top_k_final_losses, final_loss)

        super().complete_trial(trial_id)
        # pop() instead of del: an unknown trial_id (or one the base class
        # already removed — NOTE(review): confirm against BasePruner) must
        # not raise KeyError here.
        self._trials.pop(trial_id, None)

    def _get_final_loss(self, trial: Trial) -> float:
        """Get the average final loss across all seeds for a completed trial.

        Returns ``inf`` when any seed has no recorded losses, so the trial
        is excluded from the top-k pool by the ``isfinite`` check above.
        """
        if not trial.seed_values:
            return float("inf")

        total_loss = 0.0
        n_seeds = len(trial.seed_values)
        for loss_vec in trial.seed_values.values():
            if not loss_vec:
                return float("inf")
            total_loss += loss_vec[-1]

        return total_loss / n_seeds if n_seeds > 0 else float("inf")

    def _predict_final_loss_power_law(self, losses: List[float]) -> float:
        """
        Predict final loss using power-law curve fitting (y = a*x^b).
        This is equivalent to a linear fit in log-log space.

        Returns ``inf`` (i.e. "prune-worthy") when there are too few
        points, when the fitted slope is positive (loss is increasing),
        or when the fit itself fails.
        """
        n_losses = len(losses)
        if n_losses < self.min_points_for_prediction:
            return float("inf")

        try:
            # x: epochs (1-based), y: losses
            epochs = np.arange(1, n_losses + 1)
            # Clip losses to avoid log(0) issues
            safe_losses = np.maximum(losses, 1e-10)

            log_epochs = np.log(epochs)
            log_losses = np.log(safe_losses)

            # Linear fit in log-log space: log(y) = log(a) + b*log(x)
            b, log_a = np.polyfit(log_epochs, log_losses, 1)

            # A positive slope means the loss is trending upward.
            if b > 0:
                return float("inf")

            # Extrapolate to target_epoch in log space.
            predicted_log_loss = log_a + b * np.log(self.target_epoch)
            predicted_loss = np.exp(predicted_log_loss)

            # Be optimistic: never predict worse than what was already observed.
            return min(predicted_loss, min(losses))

        except (np.linalg.LinAlgError, ValueError):
            # If fitting fails, return a large value to indicate pruning
            return float("inf")

    def _should_prune_trial(self, trial: Trial) -> bool:
        """Decide whether to prune *trial* based on its predicted final loss."""
        # Any seed with a missing or non-finite latest loss is pruned outright.
        for losses in trial.seed_values.values():
            if not losses or not np.isfinite(losses[-1]):
                return True

        # Don't prune during warmup period
        if (
            self.completed_trials_count < self.n_startup_trials
            or trial.current_epoch <= self.n_warmup_epochs
        ):
            return False

        # No completed reference trials yet -> no threshold to compare to.
        if not self.top_k_final_losses:
            return False

        n_seeds = len(trial.seed_values)
        if n_seeds == 0:
            return False

        total_predicted = 0.0
        for loss_vec in trial.seed_values.values():
            predicted = self._predict_final_loss_power_law(loss_vec)
            if not np.isfinite(predicted):
                # An inf prediction makes the average inf regardless of the
                # other seeds, which always exceeds the threshold -> prune.
                return True
            total_predicted += predicted

        avg_predicted_loss = total_predicted / n_seeds

        # Threshold: the worst final loss among the top-k completed trials.
        pruning_threshold = self.top_k_final_losses[-1]

        # Prune if the predicted final loss is worse than the threshold
        return avg_predicted_loss > pruning_threshold

0 commit comments

Comments
 (0)