Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion pymc/variational/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import numpy as np

__all__ = ["Callback", "CheckParametersConvergence", "Tracker"]
__all__ = ["Callback", "CheckParametersConvergence", "ReduceLROnPlateau", "Tracker"]


class Callback:
Expand Down Expand Up @@ -93,6 +93,78 @@ def flatten_shared(shared_list):
return np.concatenate([sh.get_value().flatten() for sh in shared_list])


class ReduceLROnPlateau(Callback):
    """Reduce learning rate when the loss has stopped improving.

    This is inspired by the Keras callback of the same name:
    https://github.com/keras-team/keras/blob/v2.14.0/keras/callbacks.py

    Parameters
    ----------
    optimiser: callable
        PyMC optimiser. It must expose a mutable ``keywords`` mapping holding a
        ``"learning_rate"`` entry (as a ``functools.partial`` object does),
        because the rate is read from and written back to that mapping.
    factor: float
        factor by which the learning rate will be reduced: `new_lr = lr * factor`.
        Must be smaller than 1.
    patience: int
        number of calls with no improvement (counted since the last improvement
        or reduction) after which the learning rate will be reduced
    min_lr: float
        lower bound on the learning rate
    cooldown: int
        number of iterations to wait before resuming normal operation after lr has been reduced

    Raises
    ------
    ValueError
        If ``factor`` is greater than or equal to 1, since the learning rate
        could then never be reduced.
    """

    def __init__(
        self,
        optimiser,
        factor=0.1,
        patience=10,
        min_lr=1e-6,
        cooldown=0,
    ):
        if factor >= 1.0:
            raise ValueError(
                f"ReduceLROnPlateau does not support a factor >= 1.0 (got {factor})."
            )

        self.optimiser = optimiser
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.cooldown = cooldown

        # Calls still to be skipped after the most recent reduction.
        self.cooldown_counter = 0
        # Non-improving calls seen since the last improvement or reduction.
        self.wait = 0
        # Lowest finite loss observed so far.
        self.best = float("inf")
        # Never updated by any method of this class; kept so that existing
        # attribute access keeps working.
        self.old_lr = None

    def __call__(self, approx, loss_hist, i):
        """Inspect the latest loss and reduce the learning rate on a plateau.

        Parameters
        ----------
        approx
            Not used by this callback.
        loss_hist: sequence of float
            Loss history; only the last entry is read.
        i: int
            Not used by this callback.
        """
        # ``len`` rather than truthiness: the history may be a NumPy array,
        # whose truth value is ambiguous.
        if len(loss_hist) == 0:
            return

        current = loss_hist[-1]

        # An infinite loss carries no information about progress; skip it
        # without touching any counter.
        if np.isinf(current):
            return

        # Track the best loss on every call, cooldown included, so that a
        # later and worse loss is never mistaken for an improvement.
        improved = current < self.best
        if improved:
            self.best = current

        if self.in_cooldown():
            self.cooldown_counter -= 1
            self.wait = 0
            return

        if improved:
            self.wait = 0
        elif not np.isinf(self.best):
            # Only count towards ``patience`` once a finite best loss exists.
            self.wait += 1
            if self.wait >= self.patience:
                self.reduce_lr()
                self.cooldown_counter = self.cooldown
                self.wait = 0

    def reduce_lr(self):
        """Scale the optimiser's learning rate by ``factor``, never below ``min_lr``.

        Does nothing when the current rate is already at or below ``min_lr``.
        """
        # NOTE(review): feedback on the originating pull request asked for this
        # to be done symbolically with shared variables, because that would
        # allow composition between learning rate annealing strategies. Here
        # the rate is a plain float stored in ``optimiser.keywords``.
        old_lr = float(self.optimiser.keywords["learning_rate"])
        if old_lr > self.min_lr:
            new_lr = max(old_lr * self.factor, self.min_lr)
            self.optimiser.keywords["learning_rate"] = new_lr

    def in_cooldown(self):
        """Return True while calls are still being skipped after a reduction."""
        return self.cooldown_counter > 0


class Tracker(Callback):
"""
Helper class to record arbitrary stats during VI
Expand Down
30 changes: 30 additions & 0 deletions tests/variational/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,33 @@ def test_tracker_callback():
tracker = pm.callbacks.Tracker(bad=lambda t: t) # bad signature
with pytest.raises(TypeError):
tracker(None, None, 1)


def test_reducelronplateau_callback():
    """Feed a growing loss history to the callback and check rate and best loss."""
    optimiser = pm.adam(learning_rate=0.1)
    callback = pm.variational.callbacks.ReduceLROnPlateau(
        optimiser=optimiser,
        patience=1,
        min_lr=0.001,
    )

    losses = [float("inf"), 2, 1, 99, 0.9, 99, 99]
    # One (expected learning rate, expected best loss) pair per call; call
    # number ``k`` sees the first ``k`` losses.
    expectations = [
        (0.1, float("inf")),
        (0.1, 2),
        (0.1, 1),
        (0.01, 1),
        (0.01, 0.9),
        (0.001, 0.9),
        (0.001, 0.9),
    ]

    for seen, (expected_lr, expected_best) in enumerate(expectations, start=1):
        callback(None, losses[:seen], 1)
        np.testing.assert_almost_equal(optimiser.keywords["learning_rate"], expected_lr)
        assert callback.best == expected_best