1313 PiecewiseRationalQuadraticCDF ,
1414)
1515from nflows .utils import torchutils
16+ from nflows .transforms .UMNN import *
1617
1718
1819class CouplingTransform (Transform ):
@@ -140,6 +141,73 @@ def _coupling_transform_inverse(self, inputs, transform_params):
140141 raise NotImplementedError ()
141142
142143
class UMNNCouplingTransform(CouplingTransform):
    """An unconstrained monotonic neural networks coupling layer that transforms the variables.

    Reference:
    > A. Wehenkel and G. Louppe, Unconstrained Monotonic Neural Networks, NeurIPS2019.

    ---- Specific arguments ----
        integrand_net_layers: the layers dimension to put in the integrand network.
            Defaults to [50, 50, 50] when None.
        cond_size: The embedding size for the conditioning factors.
        nb_steps: The number of integration steps.
        solver: The quadrature algorithm - CC or CCParallel. Both implement Clenshaw-Curtis
            quadrature with Leibniz rule for backward computation. CCParallel passes all the
            evaluation points (nb_steps) at once; it is faster but requires more memory.
    """

    def __init__(
        self,
        mask,
        transform_net_create_fn,
        integrand_net_layers=None,
        cond_size=20,
        nb_steps=20,
        solver="CCParallel",
        apply_unconditional_transform=False,
    ):
        # Avoid a mutable default argument: a shared list default would be
        # reused (and potentially mutated) across every instance.
        if integrand_net_layers is None:
            integrand_net_layers = [50, 50, 50]

        if apply_unconditional_transform:
            # Unconditional normalizer: cond_size=0, i.e. no conditioning
            # embedding for the identity-half features.
            unconditional_transform = lambda features: MonotonicNormalizer(
                integrand_net_layers, 0, nb_steps, solver
            )
        else:
            unconditional_transform = None

        # Must be set BEFORE super().__init__: the base class calls
        # _transform_dim_multiplier() during construction, which reads it.
        self.cond_size = cond_size
        super().__init__(
            mask,
            transform_net_create_fn,
            unconditional_transform=unconditional_transform,
        )

        self.transformer = MonotonicNormalizer(
            integrand_net_layers, cond_size, nb_steps, solver
        )

    def _transform_dim_multiplier(self):
        # The transform net emits `cond_size` conditioning values per feature.
        return self.cond_size

    def _coupling_transform_forward(self, inputs, transform_params):
        """Forward pass. Supports 2-D (batch, features) and 4-D (B, C, H, W) inputs."""
        if len(inputs.shape) == 2:
            z, jac = self.transformer(
                inputs,
                transform_params.reshape(inputs.shape[0], inputs.shape[1], -1),
            )
            log_det_jac = jac.log().sum(1)
            return z, log_det_jac
        else:
            # Flatten spatial dims so each (pixel, channel) entry is normalized
            # independently, then restore the original layout.
            B, C, H, W = inputs.shape
            z, jac = self.transformer(
                inputs.permute(0, 2, 3, 1).reshape(-1, inputs.shape[1]),
                transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]),
            )
            log_det_jac = jac.log().reshape(B, -1).sum(1)
            return z.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac

    def _coupling_transform_inverse(self, inputs, transform_params):
        """Inverse pass.

        MonotonicNormalizer's inverse is numeric and does not return a Jacobian,
        so we re-run the forward transform at the recovered point solely to
        obtain `jac`; the negated log-determinant is the inverse's log-det.
        """
        if len(inputs.shape) == 2:
            params = transform_params.reshape(inputs.shape[0], inputs.shape[1], -1)
            x = self.transformer.inverse_transform(inputs, params)
            _, jac = self.transformer(x, params)
            log_det_jac = -jac.log().sum(1)
            return x, log_det_jac
        else:
            B, C, H, W = inputs.shape
            params = transform_params.permute(0, 2, 3, 1).reshape(
                -1, 1, transform_params.shape[1]
            )
            x = self.transformer.inverse_transform(
                inputs.permute(0, 2, 3, 1).reshape(-1, inputs.shape[1]), params
            )
            _, jac = self.transformer(x, params)
            log_det_jac = -jac.log().reshape(B, -1).sum(1)
            return x.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac
210+
143211class AffineCouplingTransform (CouplingTransform ):
144212 """An affine coupling layer that scales and shifts part of the variables.
145213
@@ -151,7 +219,7 @@ def _transform_dim_multiplier(self):
151219 return 2
152220
153221 def _scale_and_shift (self , transform_params ):
154- unconstrained_scale = transform_params [:, self .num_transform_features :, ...]
222+ unconstrained_scale = transform_params [:, self .num_transform_features :, ...]
155223 shift = transform_params [:, : self .num_transform_features , ...]
156224 # scale = (F.softplus(unconstrained_scale) + 1e-3).clamp(0, 3)
157225 scale = torch .sigmoid (unconstrained_scale + 2 ) + 1e-3
0 commit comments