diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py
index 618feea..a7dc109 100644
--- a/python/nutpie/compiled_pyfunc.py
+++ b/python/nutpie/compiled_pyfunc.py
@@ -1,6 +1,6 @@
 import dataclasses
 from dataclasses import dataclass
-from functools import partial
+from functools import partial, wraps
 from typing import Any, Callable
 
 import numpy as np
@@ -8,6 +8,20 @@
 from nutpie import _lib  # type: ignore
 from nutpie.sample import CompiledModel
 
+from importlib.util import find_spec
+
+# Importing from transform_adapter requires flowjax, which not all users will
+# have installed. If it is missing, with_transform_adapt cannot be used anyway,
+# so a dummy function keeps the docstring wrapper valid.
+if find_spec("flowjax") is not None:
+    from nutpie.transform_adapter import make_transform_adapter
+else:
+
+    def make_transform_adapter(*args, **kwargs):
+        """Normalizing flow adaptation not available. Install flowjax to use."""
+        pass
+
+
 SeedType = int
@@ -44,6 +58,7 @@ def with_data(self, **updates):
         updated.update(**updates)
         return dataclasses.replace(self, _shared_data=updated)
 
+    @wraps(make_transform_adapter)
     def with_transform_adapt(self, **kwargs):
         return dataclasses.replace(self, _transform_adapt_args=kwargs)
@@ -71,8 +86,6 @@ def make_expand_func(seed1, seed2, chain):
         outer_kwargs = {}
 
         def make_adapter(*args, **kwargs):
-            from nutpie.transform_adapter import make_transform_adapter
-
             return make_transform_adapter(**outer_kwargs)(
                 *args, **kwargs, logp_fn=self._raw_logp_fn
            )
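A brief usage sketch of the change above (hypothetical names; `compiled` stands for any nutpie compiled model that exposes with_transform_adapt, and flowjax is assumed to be installed). The keyword arguments are only stored on the model and forwarded to make_transform_adapter once sampling starts, and @wraps makes that function's parameter documentation visible on the method:

    # Sketch only: `compiled` is a previously compiled nutpie model (hypothetical).
    tuned = compiled.with_transform_adapt(num_layers=6, show_progress=False)
    # @wraps(make_transform_adapter) copies the docstring, so the parameters are documented here:
    help(tuned.with_transform_adapt)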
diff --git a/python/nutpie/transform_adapter.py b/python/nutpie/transform_adapter.py
index 41d18f4..3404d33 100644
--- a/python/nutpie/transform_adapter.py
+++ b/python/nutpie/transform_adapter.py
@@ -869,43 +869,149 @@ def inv_transform(self, position, gradient):
 def make_transform_adapter(
     *,
-    verbose=False,
-    window_size=600,
-    show_progress=False,
-    nn_depth=None,
-    nn_width=None,
-    num_layers=8,
-    num_diag_windows=6,
-    learning_rate=5e-4,
+    verbose: bool = False,
+    window_size: int = 600,
+    show_progress: bool = False,
+    nn_depth: int | None = None,
+    nn_width: int | None = None,
+    num_layers: int = 8,
+    num_diag_windows: int = 6,
+    learning_rate: float = 5e-4,
     untransformed_dim=None,
-    zero_init=True,
-    batch_size=128,
-    reuse_opt_state=False,
-    max_patience=20,
-    householder_layer=False,
-    dct_layer=False,
+    zero_init: bool = True,
+    batch_size: int = 128,
+    reuse_opt_state: bool = False,
+    max_patience: int = 20,
+    householder_layer: bool = False,
+    dct_layer: bool = False,
     gamma=None,
-    log_inside_batch=False,
-    initial_skip=120,
-    extension_windows=None,
-    extend_dct=False,
-    extension_var_count=4,
-    extension_var_trafo_count=2,
-    debug_save_bijection=False,
-    make_optimizer=None,
-    coupling_type="masked",
-    mvscale_layer=False,
-    num_project=None,
-    num_embed=None,
-    num_householder=8,
-    twin_layers=False,
-    activation=None,
-    max_epochs=200,
-    affine_transformer=False,
-    contract_transformer=True,
-    asymmetric_transformer=False,
-    reuse_embed=True,
+    log_inside_batch: bool = False,
+    initial_skip: int = 120,
+    extension_windows: list[int] | None = None,
+    extend_dct: bool = False,
+    extension_var_count: int = 4,
+    extension_var_trafo_count: int = 2,
+    debug_save_bijection: bool = False,
+    make_optimizer: Callable | None = None,
+    coupling_type: str = "masked",
+    mvscale_layer: bool = False,
+    num_project: int | None = None,
+    num_embed: int | None = None,
+    num_householder: int = 8,
+    twin_layers: bool = False,
+    activation: str | None = None,
+    max_epochs: int = 200,
+    affine_transformer: bool = False,
+    contract_transformer: bool = True,
+    asymmetric_transformer: bool = False,
+    reuse_embed: bool = True,
 ):
+    """
+    Create a TransformAdapter configured with the specified parameters.
+
+    A TransformAdapter is a utility for parameterizing the normalizing flow model used
+    inside MCMC sampling. For more details, see the documentation.
+
+    Parameters
+    ----------
+    verbose: bool, default False
+        If True, print debug information (random seed, available points, and loss values)
+        to the terminal during training.
+    window_size: int, default 600
+        ???
+    show_progress: bool, default False
+        If True, show a TQDM progress bar while training the flow network. Note that with
+        multiple chains the output quickly becomes very noisy.
+    nn_depth: int | None, default None
+        Number of layers in the neural network used for the flow. If None, defaults to 1.
+    nn_width: int | None, default None
+        Number of hidden units in each layer of the flow network. If None, defaults to 32.
+    num_layers: int, default 8
+        Number of flow layers in the flow network. Each layer is parameterized according to
+        nn_depth and nn_width.
+    num_diag_windows: int, default 6
+        Number of diagonal mass matrix updates to perform before starting the flow training.
+    learning_rate: float, default 5e-4
+        Learning rate passed to the optimizer used to train the flow network. Ignored if a
+        custom optimizer is provided via the make_optimizer argument.
+    untransformed_dim: int | None, default None
+        ???
+    zero_init: bool, default True
+        If True, all weights in the flow network are initialized to zero. Otherwise,
+        initialization follows the default flax initialization scheme (lecun_normal).
+    batch_size: int, default 128
+        Number of samples to use in each training batch.
+    reuse_opt_state: bool, default False
+        If True, the optimizer state (gradients and optimizer parameters) is stored and
+        reused between updates. Otherwise, training restarts from scratch at each update.
+    max_patience: int, default 20
+        Number of consecutive epochs with no validation-loss improvement after which
+        training is terminated.
+    householder_layer: bool, default False
+        If True, insert Householder transformation layers into the flow network. For more
+        details, see the Householder layer documentation.
+    dct_layer: bool, default False
+        If True, insert discrete cosine transform (DCT) layers into the flow network. For
+        more details, see the DCT layer documentation.
+    gamma: float | None, default None
+        ???
+    log_inside_batch: bool, default False
+        ???
+    initial_skip: int, default 120
+        Number of initial samples to ignore entirely before flow training. Early samples are
+        often not representative of the target distribution, and skipping them can help the
+        flow network converge.
+    extension_windows: list[int] | None, default None
+        ???
+    extend_dct: bool, default False
+        ???
+    extension_var_count: int, default 4
+        ???
+    extension_var_trafo_count: int, default 2
+        ???
+    debug_save_bijection: bool, default False
+        ???
+    make_optimizer: Callable | None, default None
+        A function with no arguments that returns an optax optimizer. The default is
+        optax.adamw(learning_rate), wrapped by optax.apply_if_finite.
+    coupling_type: str, default "masked"
+        One of "subset", "masked", "flowjax_coupling", or "twin". Determines the type of
+        coupling layer used to construct the normalizing flow. For more details, see the
+        coupling type documentation.
+    mvscale_layer: bool, default False
+        If True, re-scale parameters using a mean vector and covariance matrix. Ignored
+        unless coupling_type is "masked". Currently unused.
+    num_project: int | None, default None
+        ??? Default is 2 * nn_width.
+    num_embed: int | None, default None
+        ??? Default is 2 * nn_width.
+    num_householder: int, default 8
+        If greater than 0, the number of Householder layers to use in the flow network.
+        Layers added in this way are distinct from the (single) Householder layer inserted
+        when householder_layer is True. Ignored unless coupling_type is "subset" or "twin".
+    twin_layers: bool, default False
+        If True, use twin layers in the flow network. This doubles the number of flow layers
+        by masking each layer twice, where each mask is the inverse of the previous one. This
+        should allow more expressive flows to be learned.
+    activation: str | None, default None
+        Nonlinearity to insert between flow layers. One of "relu", "leaky_relu", "gelu",
+        "tanh", or "sigmoid". If None, "leaky_relu" is used.
+    max_epochs: int, default 200
+        Maximum number of epochs to run when training the flow network.
+    affine_transformer: bool, default False
+        If True, parameters are added to the flow network to learn the location and scale of
+        each sample. This can be seen as a latent non-centered parameterization, or a type of
+        batch norm. Ignored unless coupling_type is "masked" or "twin".
+    contract_transformer: bool, default True
+        ??? Ignored unless coupling_type is "masked" or "twin".
+    asymmetric_transformer: bool, default False
+        If True, parameters are added to the flow network to learn the location and scale of
+        each sample. Unlike affine_transformer, the asymmetric transformer learns two scales,
+        one for positive inputs and one for negative inputs. Ignored unless coupling_type is
+        "masked" or "twin".
+    reuse_embed: bool, default True
+        ???
+
+    Returns
+    -------
+    configured_adapter: TransformAdapter
+        A partially initialized TransformAdapter with the specified parameters.
+    """
+
     if extension_windows is None:
         extension_windows = []
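As a configuration sketch (the argument values below are illustrative only, not recommendations), the same parameters documented above can be passed directly to make_transform_adapter, which returns the partially initialized TransformAdapter described in the Returns section; with_transform_adapt forwards its keyword arguments to this function:

    from nutpie.transform_adapter import make_transform_adapter  # requires flowjax

    # Illustrative values; each keyword corresponds to a parameter documented above.
    adapter_factory = make_transform_adapter(
        num_layers=6,
        learning_rate=1e-3,
        max_epochs=100,
        show_progress=False,
    )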