66
77from .base import Flow
88from .inverted import InverseFlow
9+ from .transformer .affine import AffineTransformer
910
10- __all__ = ["SplitFlow" , "MergeFlow" , "SwapFlow" , "CouplingFlow" , "WrapFlow" , "SetConstantFlow" ]
11+ __all__ = [
12+ "SplitFlow" , "MergeFlow" , "SwapFlow" , "CouplingFlow" ,
13+ "WrapFlow" , "SetConstantFlow" , "VolumePreservingWrapFlow"
14+ ]
1115
1216
1317class SplitFlow (Flow ):
@@ -207,6 +211,7 @@ def _forward(self, *xs, **kwargs):
207211 inp = (xs [i ] for i in self ._indices )
208212 output = [xs [i ] for i in range (len (xs )) if i not in self ._indices ]
209213 * yi , dlogp = self ._flow (* inp , ** kwargs )
214+ assert len (yi ) == len (self ._out_indices )
210215 for i in self ._argsort_out_indices :
211216 index = self ._out_indices [i ]
212217 output .insert (index , yi [i ])
@@ -216,11 +221,109 @@ def _inverse(self, *xs, **kwargs):
216221 inp = (xs [i ] for i in self ._out_indices )
217222 output = [xs [i ] for i in range (len (xs )) if i not in self ._out_indices ]
218223 * yi , dlogp = self ._flow (* inp , inverse = True , ** kwargs )
224+ assert len (yi ) == len (self ._indices )
219225 for i in self ._argsort_indices :
220226 index = self ._indices [i ]
221227 output .insert (index , yi [i ])
222228 return (* tuple (output ), dlogp )
223229
230+ def output_index (self , input_index ):
231+ """Output index of a non-transformed input."""
232+ if input_index in self ._indices :
233+ raise ValueError ("output_index is only defined for non-transformed inputs" )
234+ n_flow_non_inputs_before_sink = sum (i not in self ._indices for i in range (input_index ))
235+ output_index = n_flow_non_inputs_before_sink
236+ for i_out in sorted (self ._out_indices ):
237+ if i_out <= output_index :
238+ # "insert before"
239+ output_index += 1
240+ else :
241+ # everything else is inserted after
242+ break
243+ return output_index
244+
245+
246+ class VolumePreservingWrapFlow (Flow ):
247+ def __init__ (
248+ self ,
249+ flow : Flow ,
250+ volume_sink_index : int ,
251+ out_volume_sink_index : int ,
252+ cond_indices : Sequence [int ],
253+ shift_transformation : torch .nn .Module = None ,
254+ scale_transformation : torch .nn .Module = None
255+ ):
256+ """Volume-preserving wrap layer.
257+
258+ This layer operates on two or more input tensors.
259+
260+ One of these tensors (as indexed by `volume_sink_index` and `out_volume_sink_index`)
261+ acts as a volume sink, while the others are transformed by a flow.
262+ Concretely, after applying the flow, an affine layer is applied to the volume sink
263+ in such a way that the volume change of this affine "co-transform" (`co_dlogp`) counteracts
264+ the volume change (`dlogp`) of the primary flow, `dlogp + co_dlogp = 0`.
265+
266+ The parameters of the co-transform (shift and scale)
267+ can be conditioned on dlogp as well as the inputs and outputs
268+ of the primary flow.
269+
270+ It is important that the primary transform does not use the volume sink in any way,
271+ neither transform it nor condition on it.
272+
273+ Parameters
274+ ----------
275+ flow
276+ The primary transform.
277+ volume_sink_index
278+ Input index of the volume sink tensor in the forward pass.
279+ out_volume_sink_index
280+ Input index of the volume sink tensor in the inverse pass.
281+ cond_indices : Sequence[int]
282+ This is a bit tricky. These indices refer to elements of the list
283+ `[dlogp, *inputs, *outputs]`.
284+ shift_transformation : torch.nn.Module, optional
285+ scale_transformation : torch.nn.Module, optional
286+ """
287+ super ().__init__ ()
288+ self .flow = flow
289+ self .volume_sink_index = volume_sink_index
290+ self .out_volume_sink_index = out_volume_sink_index
291+ co_transform = AffineTransformer (
292+ shift_transformation = shift_transformation ,
293+ scale_transformation = scale_transformation ,
294+ preserve_volume = True ,
295+ )
296+ self .co_flow = CouplingFlow (
297+ transformer = co_transform ,
298+ transformed_indices = (1 + self .volume_sink_index , ),
299+ cond_indices = cond_indices ,
300+ cat_dim = - 1
301+ )
302+ assert all (i != 1 + self .volume_sink_index for i in cond_indices )
303+
304+ def _forward (self , * xs , ** kwargs ):
305+ * ys , dlogp = self .flow .forward (* xs , ** kwargs )
306+ co_out , co_dlogp = self ._apply_coflow (dlogp , xs , ys , inverse = False )
307+ ys [self .out_volume_sink_index ] = co_out
308+ return (* ys , dlogp + co_dlogp )
309+
310+ def _inverse (self , * ys , ** kwargs ):
311+ * xs , dlogp = self .flow .forward (* ys , inverse = True , ** kwargs )
312+ co_out , co_dlogp = self ._apply_coflow (forward_dlogp = - dlogp , xs = xs , ys = ys , inverse = True )
313+ xs [self .volume_sink_index ] = co_out
314+ return (* xs , dlogp + co_dlogp )
315+
316+ def _apply_coflow (self , forward_dlogp , xs , ys , inverse ):
317+ assert torch .allclose (xs [self .volume_sink_index ], ys [self .out_volume_sink_index ])
318+ coflow_in = [
319+ forward_dlogp ,
320+ * [x for i , x in enumerate (xs )],
321+ * [y for i , y in enumerate (ys )]
322+ ]
323+ target_dlogp = forward_dlogp if inverse else - forward_dlogp
324+ * co_out , co_dlogp = self .co_flow .forward (* coflow_in , target_dlogp = target_dlogp , inverse = inverse )
325+ return co_out [1 + self .volume_sink_index ], co_dlogp
326+
224327
225328class SetConstantFlow (Flow ):
226329 """A flow that sets some inputs constant in the forward direction and removes them in the inverse.
0 commit comments