|
30 | 30 | class InvertibleNetwork(tf.keras.Model): |
31 | 31 | """Implements a chain of conditional invertible coupling layers for conditional density estimation.""" |
32 | 32 |
|
| 33 | + available_designs = ("affine", "spline", "interleaved") |
| 34 | + |
33 | 35 | def __init__( |
34 | 36 | self, |
35 | 37 | num_params, |
@@ -69,16 +71,17 @@ def __init__( |
69 | 71 | num_params : int |
70 | 72 | The number of parameters to perform inference on. Equivalently, the dimensionality of the |
71 | 73 | latent space. |
72 | | - num_coupling_layers : int, optional, default: 5 |
| 74 | + num_coupling_layers : int, optional, default: 6 |
73 | 75 | The number of coupling layers to use as defined in [1] and [2]. In general, more coupling layers |
74 | 76 | will give you more expressive power, but will be slower and may need more simulations to train. |
75 | 77 | Typically, between 4 and 10 coupling layers should suffice for most applications. |
76 | 78 | coupling_design : str or callable, optional, default: 'affine' |
77 | | - The type of internal coupling network to use. Must be in ['affine', 'spline']. |
78 | | - The former corresponds to the architecture in [3, 5], the latter corresponds to a modified |
79 | | - version of [4]. |
| 79 | + The type of internal coupling network to use. Must be in ['affine', 'spline', 'interleaved']. |
| 80 | + The first corresponds to the architecture in [3, 5], the second corresponds to a modified |
| 81 | + version of [4]. The third option will alternate between affine and spline layers, for example, |
| 82 | + if num_coupling_layers == 3, the chain will consist of ["affine", "spline", "affine"] layers. |
80 | 83 |
|
81 | | - In general, spline couplings run slower than affine couplings, but require fewers coupling |
| 84 | + In general, spline couplings run slower than affine couplings, but require fewer coupling |
82 | 85 | layers. Spline couplings may work best with complex (e.g., multimodal) low-dimensional |
83 | 86 | problems. The difference will become less and less pronounced as we move to higher dimensions. |
84 | 87 |
|
@@ -127,16 +130,15 @@ def __init__( |
127 | 130 |
|
128 | 131 | super().__init__(**kwargs) |
129 | 132 |
|
130 | | - settings = dict( |
| 133 | + layer_settings = dict( |
131 | 134 | latent_dim=num_params, |
132 | | - coupling_settings=coupling_settings, |
133 | | - coupling_design=coupling_design, |
134 | 135 | permutation=permutation, |
135 | 136 | use_act_norm=use_act_norm, |
136 | 137 | act_norm_init=act_norm_init, |
137 | 138 | ) |
138 | | - self.coupling_layers = [] |
139 | | - self.coupling_layers = [CouplingLayer(**settings) for _ in range(num_coupling_layers)] |
| 139 | + self.coupling_layers = self._create_coupling_layers( |
| 140 | + layer_settings, coupling_settings, coupling_design, num_coupling_layers |
| 141 | + ) |
140 | 142 | self.soft_flow = use_soft_flow |
141 | 143 | self.soft_low = soft_flow_bounds[0] |
142 | 144 | self.soft_high = soft_flow_bounds[1] |
@@ -230,6 +232,34 @@ def inverse(self, z, condition, **kwargs): |
230 | 232 | target = layer(target, condition, inverse=True, **kwargs) |
231 | 233 | return target |
232 | 234 |
|
| 235 | + @staticmethod |
| 236 | + def _create_coupling_layers(settings, coupling_settings, coupling_design, num_coupling_layers): |
| 237 | + """Helper method to create a list of coupling layers. Takes care |
| 238 | + of the different options for coupling design. |
| 239 | + """ |
| 240 | + |
| 241 | + if coupling_design not in InvertibleNetwork.available_designs: |
| 242 | + raise NotImplementedError("Coupling design should be one of", InvertibleNetwork.available_designs) |
| 243 | + |
| 244 | + # Case affine or spline |
| 245 | + if coupling_design != "interleaved": |
| 246 | + design = coupling_design |
| 247 | + _coupling_settings = coupling_settings |
| 248 | + coupling_layers = [ |
| 249 | + CouplingLayer(coupling_design=design, coupling_settings=_coupling_settings, **settings) |
| 250 | + for _ in range(num_coupling_layers) |
| 251 | + ] |
| 252 | + # Case interleaved, starts with affine |
| 253 | + else: |
| 254 | + coupling_layers = [] |
| 255 | + designs = (["affine", "spline"] * int(np.ceil(num_coupling_layers / 2)))[:num_coupling_layers] |
| 256 | + for design in designs: |
| 257 | + # Fail gently, if neither None, nor a dictionary with keys ("spline", "affine") |
| 258 | + _coupling_settings = None if coupling_settings is None else coupling_settings[design] |
| 259 | + layer = CouplingLayer(coupling_design=design, coupling_settings=_coupling_settings, **settings) |
| 260 | + coupling_layers.append(layer) |
| 261 | + return coupling_layers |
| 262 | + |
233 | 263 | @classmethod |
234 | 264 | def create_config(cls, **kwargs): |
235 | 265 | """ "Used to create the settings dictionary for the internal networks of the invertible |
|
0 commit comments