Skip to content

Commit f11d7c4

Browse files
committed
Add optional autograph functionality to all training modes
1 parent 5cb223f commit f11d7c4

File tree

4 files changed

+123
-77
lines changed

4 files changed

+123
-77
lines changed

bayesflow/coupling_networks.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def __init__(self, meta):
8585
# Optional permutation
8686
if meta['use_permutation']:
8787
self.permutation = Permutation(self.latent_dim)
88+
self.permutation.trainable = False
8889
else:
8990
self.permutation = None
9091

bayesflow/helper_functions.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import copy
2222

23+
import tensorflow as tf
2324
from tensorflow.keras.optimizers.schedules import LearningRateSchedule
2425

2526
from bayesflow import default_settings
@@ -104,7 +105,7 @@ def extract_current_lr(optimizer):
104105

105106
def format_loss_string(ep, it, loss, avg_dict, slope=None, lr=None,
106107
ep_str="Epoch", it_str='Iter', scalar_loss_str='Loss'):
107-
""" Prepare loss string for displaying on progress bar."""
108+
"""Prepare loss string for displaying on progress bar."""
108109

109110
disp_str = f"{ep_str}: {ep}, {it_str}: {it}"
110111
if type(loss) is dict:
@@ -123,6 +124,50 @@ def format_loss_string(ep, it, loss, avg_dict, slope=None, lr=None,
123124
return disp_str
124125

125126

127+
def backprop_step(input_dict, amortizer, optimizer, **kwargs):
    """Computes the loss of the provided amortizer given an input dictionary and applies gradients.

    Parameters
    ----------
    input_dict : dict
        The configured output of the generative model
    amortizer : tf.keras.Model
        The custom amortizer. Needs to implement a compute_loss method.
    optimizer : tf.keras.optimizers.Optimizer
        The optimizer used to update the amortizer's parameters.
    **kwargs : dict
        Optional keyword arguments passed to the network's compute_loss method

    Returns
    -------
    loss : dict
        The outputs of the compute_loss() method of the amortizer comprising all
        loss components, such as divergences or regularization.
    """

    # Forward pass and loss computation inside the tape so gradients can flow
    with tf.GradientTape() as tape:
        # Compute custom loss (may be a scalar tensor or a dict of named components)
        loss = amortizer.compute_loss(input_dict, training=True, **kwargs)
        # If dict, sum the components into a single scalar for backprop
        if isinstance(loss, dict):
            _loss = tf.add_n(list(loss.values()))
        else:
            _loss = loss
        # Collect regularization loss (e.g., weight decay from layer losses), if any
        if amortizer.losses:
            reg = tf.add_n(amortizer.losses)
            _loss += reg
            # Report regularization separately so the progress display can show it
            if isinstance(loss, dict):
                loss['W.Decay'] = reg
            else:
                loss = {'Loss': loss, 'W.Decay': reg}
    # One step backprop and return loss
    gradients = tape.gradient(_loss, amortizer.trainable_variables)
    optimizer.apply_gradients(zip(gradients, amortizer.trainable_variables))
    return loss
169+
170+
126171
def check_posterior_prior_shapes(post_samples, prior_samples):
127172
"""Checks requirements for the shapes of posterior and prior draws as
128173
necessitated by most diagnostic functions.

bayesflow/helper_networks.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,29 @@ def call(self, target, inverse=False):
118118
----------
119119
target : tf.Tensor of shape (batch_size, ...)
120120
The target vector to be permuted over its last axis.
121-
inverse : bool, default: False
121+
inverse : bool, optional, default: False
122122
Controls if the current pass is forward (``inverse=False``) or inverse (``inverse=True``).
123123
124124
Returns
125125
-------
126126
out : tf.Tensor of the same shape as `target`.
127-
The permuted target vector.
128-
127+
The (un-)permuted target vector.
129128
"""
130129

131130
if not inverse:
132-
return tf.transpose(tf.gather(tf.transpose(target), self.permutation))
133-
return tf.transpose(tf.gather(tf.transpose(target), self.inv_permutation))
131+
return self._forward(target)
132+
else:
133+
return self._inverse(target)
134+
135+
@tf.function
def _forward(self, target):
    """Applies the stored fixed permutation along the last axis of ``target``."""
    return tf.gather(params=target, indices=self.permutation, axis=-1)
139+
140+
@tf.function
141+
def _inverse(self, target):
142+
"""Un-does the fixed permutation over the last axis."""
143+
return tf.gather(target, self.inv_permutation, axis=-1)
134144

135145

136146
class ActNorm(tf.keras.Model):

0 commit comments

Comments
 (0)