
import copy

+import tensorflow as tf
from tensorflow.keras.optimizers.schedules import LearningRateSchedule

from bayesflow import default_settings
@@ -104,7 +105,7 @@ def extract_current_lr(optimizer):

def format_loss_string(ep, it, loss, avg_dict, slope=None, lr=None,
                       ep_str="Epoch", it_str='Iter', scalar_loss_str='Loss'):
-    """ Prepare loss string for displaying on progress bar."""
+    """Prepare loss string for displaying on progress bar."""

    disp_str = f"{ep_str}: {ep}, {it_str}: {it}"
    if type(loss) is dict:
@@ -123,6 +124,50 @@ def format_loss_string(ep, it, loss, avg_dict, slope=None, lr=None,
    return disp_str


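For illustration, a minimal sketch of calling format_loss_string with a dict-valued loss such as the one returned by the new backprop_step below; passing avg_dict=None is an assumption, since the full function body is not visible in this hunk.

import tensorflow as tf

loss = {'Post.Loss': tf.constant(1.234), 'KL': tf.constant(0.056)}
progress_str = format_loss_string(ep=2, it=150, loss=loss, avg_dict=None, lr=5e-4)
# The resulting string is meant for a progress bar, e.g. via tqdm's set_postfix_str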
+def backprop_step(input_dict, amortizer, optimizer, **kwargs):
+    """Computes the loss of the provided amortizer given an input dictionary and applies gradients.
+
+    Parameters
+    ----------
+    input_dict : dict
+        The configured output of the generative model.
+    amortizer  : tf.keras.Model
+        The custom amortizer. Needs to implement a compute_loss method.
+    optimizer  : tf.keras.optimizers.Optimizer
+        The optimizer used to update the amortizer's parameters.
+    **kwargs   : dict
+        Optional keyword arguments passed to the network's compute_loss method.
+
+    Returns
+    -------
+    loss : dict
+        The outputs of the compute_loss() method of the amortizer comprising all
+        loss components, such as divergences or regularization.
+    """
+
+    # Forward pass and loss computation
+    with tf.GradientTape() as tape:
+        # Compute custom loss
+        loss = amortizer.compute_loss(input_dict, training=True, **kwargs)
+        # If dict, add components
+        if type(loss) is dict:
+            _loss = tf.add_n(list(loss.values()))
+        else:
+            _loss = loss
+        # Collect regularization loss, if any
+        if amortizer.losses != []:
+            reg = tf.add_n(amortizer.losses)
+            _loss += reg
+            if type(loss) is dict:
+                loss['W.Decay'] = reg
+            else:
+                loss = {'Loss': loss, 'W.Decay': reg}
+    # One step backprop and return loss
+    gradients = tape.gradient(_loss, amortizer.trainable_variables)
+    optimizer.apply_gradients(zip(gradients, amortizer.trainable_variables))
+    return loss
+
+
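For illustration, a minimal sketch of driving the new backprop_step from a custom training loop; generative_model, configurator, amortizer, and batch_size are hypothetical stand-ins, and only the compute_loss contract described in the docstring above is assumed.

optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)

# Simulate and configure one batch (hypothetical helpers, not defined in this module)
forward_dict = generative_model(batch_size)
input_dict = configurator(forward_dict)

# One forward/backward pass; returns a dict of loss components,
# including 'W.Decay' when the amortizer carries regularization losses
losses = backprop_step(input_dict, amortizer, optimizer)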
def check_posterior_prior_shapes(post_samples, prior_samples):
    """Checks requirements for the shapes of posterior and prior draws as
    necessitated by most diagnostic functions.