127 changes: 126 additions & 1 deletion pymc/variational/callbacks.py
@@ -18,7 +18,13 @@

import numpy as np

__all__ = ["Callback", "CheckParametersConvergence", "Tracker"]
__all__ = [
"Callback",
"CheckParametersConvergence",
"ExponentialDecay",
"ReduceLROnPlateau",
"Tracker",
]


class Callback:
@@ -93,6 +99,125 @@ def flatten_shared(shared_list):
return np.concatenate([sh.get_value().flatten() for sh in shared_list])


class LearningRateScheduler(Callback):
"""Baseclass for learning rate schedulers."""

def __init__(self, optimizer):
self.optimizer = optimizer

def __call__(self, approx, loss_hist, i):
raise NotImplementedError("Must be implemented in subclass.")

def _set_new_lr(self, new_lr):
self.optimizer.keywords["learning_rate"] = new_lr
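For context on _set_new_lr (editor's note, not part of the diff): PyMC's functional-style optimizers, when called with only their hyperparameters, return a functools.partial, so the learning rate can be mutated through its keywords dict. A minimal sketch under that assumption:

import pymc as pm

opt = pm.adam(learning_rate=1e-2)      # a functools.partial storing the keyword (assumption)
print(opt.keywords["learning_rate"])   # 0.01
opt.keywords["learning_rate"] = 5e-3   # this is all _set_new_lr does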


class ExponentialDecay(LearningRateScheduler):
"""
Exponentially decays the learning rate.

This is inspired by the Keras learning rate schedule of the same name:
https://github.com/keras-team/keras/blob/v2.14.0/keras/optimizers/schedules/learning_rate_schedule.py

Parameters
----------
optimizer : callable
PyMC optimizer whose learning rate will be decayed.
decay_steps : int
Number of steps over which the learning rate decays by `decay_rate`.
decay_rate : float
Rate of decay.
min_lr : float
Lower bound on the learning rate.
staircase : bool
If True, decay the learning rate at discrete intervals.
"""

def __init__(self, optimizer, decay_steps, decay_rate, min_lr=1e-6, staircase=False):
super().__init__(optimizer)
self.decay_steps = decay_steps
self.decay_rate = decay_rate
self.staircase = staircase
self.min_lr = min_lr

self.initial_learning_rate = float(self.optimizer.keywords["learning_rate"])

def __call__(self, approx, loss_hist, i):
if self.staircase:
new_lr = self.initial_learning_rate * self.decay_rate ** (i // self.decay_steps)
else:
new_lr = self.initial_learning_rate * self.decay_rate ** (i / self.decay_steps)
if new_lr >= self.min_lr:
self._set_new_lr(new_lr)
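For intuition (editor's sketch, not part of the diff), the rule implemented above is new_lr = initial_learning_rate * decay_rate ** (i / decay_steps), with integer division when staircase=True. A small standalone illustration with made-up numbers:

initial_lr, decay_rate, decay_steps = 1e-2, 0.5, 100  # illustrative values only

for i in (0, 50, 100, 200):
    continuous = initial_lr * decay_rate ** (i / decay_steps)
    staircase = initial_lr * decay_rate ** (i // decay_steps)
    print(f"step {i:>3}: continuous={continuous:.5f}  staircase={staircase:.5f}")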


class ReduceLROnPlateau(LearningRateScheduler):
"""
Reduce learning rate when the loss has stopped improving.

This is inspired by the Keras callback of the same name:
https://github.com/keras-team/keras/blob/v2.14.0/keras/callbacks.py

Parameters
----------
optimizer : callable
PyMC optimizer.
factor : float
Factor by which the learning rate will be reduced: `new_lr = lr * factor`.
patience : int
Number of iterations with no improvement after which the learning rate will be reduced.
min_lr : float
Lower bound on the learning rate.
cooldown : int
Number of iterations to wait before resuming normal operation after the learning rate has been reduced.
"""

def __init__(
self,
optimizer,
Review comment (Member): Does the user have to provide this? Can it instead be inferred somehow from the host VI object? It's ugly to have to pass the optimizer twice (once for the VI itself, then again in the callback).

Reply (Author): Well, this would be great, but I haven't figured out whether it's possible. Probably one for someone more familiar with the codebase :)

factor=0.1,
patience=10,
min_lr=1e-6,
cooldown=0,
):
super().__init__(optimizer)
self.factor = factor
self.patience = patience
self.min_lr = min_lr
self.cooldown = cooldown
self.cooldown_counter = 0
self.wait = 0
self.best = float("inf")

def _in_cooldown(self):
return self.cooldown_counter > 0

def _reduce_lr(self):
old_lr = float(self.optimizer.keywords["learning_rate"])
new_lr = max(old_lr * self.factor, self.min_lr)
if new_lr >= self.min_lr:
self._set_new_lr(new_lr)

def __call__(self, approx, loss_hist, i):
current = loss_hist[-1]

if np.isinf(current):
return

if self._in_cooldown():
self.cooldown_counter -= 1
self.wait = 0
return

if current < self.best:
self.best = current
self.wait = 0
elif not np.isinf(self.best):
self.wait += 1
if self.wait >= self.patience:
self._reduce_lr()
self.cooldown_counter = self.cooldown
self.wait = 0
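A hypothetical end-to-end usage sketch (editor's addition, not part of this diff; it assumes the PR is merged so the callback is exposed under pm.callbacks, and that the same optimizer partial is passed both to pm.fit and to the scheduler, which is exactly the duplication questioned in the review comment above):

import pymc as pm

with pm.Model():
    x = pm.Normal("x", 0.0, 1.0)

    opt = pm.adam(learning_rate=1e-2)  # handed to the VI objective *and* to the scheduler
    scheduler = pm.callbacks.ReduceLROnPlateau(opt, factor=0.5, patience=50, cooldown=25)
    approx = pm.fit(10_000, obj_optimizer=opt, callbacks=[scheduler])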


class Tracker(Callback):
"""
Helper class to record arbitrary stats during VI
7 changes: 6 additions & 1 deletion pymc/variational/opvi.py
@@ -251,6 +251,7 @@ def updates(
if more_updates is None:
more_updates = dict()
resulting_updates = ObjectiveUpdates()

if self.test_params:
self.add_test_updates(
resulting_updates,
@@ -313,10 +314,14 @@ def add_obj_updates(
obj_target = self(
obj_n_mc, more_obj_params=more_obj_params, more_replacements=more_replacements
)

grads = pm.updates.get_or_compute_grads(obj_target, self.obj_params + more_obj_params)
if total_grad_norm_constraint is not None:
grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
updates.update(obj_optimizer(grads, self.obj_params + more_obj_params))

# Pass the loss together with the gradients to the optimizer, so that schedulers can use the loss if needed.
updates.update(obj_optimizer((obj_target, grads), self.obj_params + more_obj_params))

if self.op.returns_loss:
updates.loss = obj_target
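Since obj_optimizer now receives a (loss, grads) tuple rather than a bare list of gradients, any custom optimizer has to unpack it. A minimal SGD-style sketch of a compatible callable (an editor's illustration under that assumption, not code from this PR):

def sgd_with_loss(loss_and_grads, params, learning_rate=1e-3):
    """Toy optimizer that accepts the (loss, grads) tuple passed above."""
    loss, grads = loss_and_grads  # the loss is now available to learning rate schedulers
    return {param: param - learning_rate * grad for param, grad in zip(params, grads)}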
