
masked autoregressive flow with mixed transformer types #161

@llaurabatt


I am looking into a modification of a regular masked autoregressive flow in which the base distribution is an N-dimensional uniform, the first variable is not transformed, and the remaining variables are transformed with a rational quadratic spline. I removed the shuffling in the masked_autoregressive_flow function by removing the _add_default_permute call, and modified _flat_params_to_transformer in the MaskedAutoregressive class to apply an Identity transformer to the first dimension, as follows:

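    def _flat_params_to_transformer(self, params: Array, y_dim=1):
        """Reshape to (dim, params_per_dim), drop the first y_dim rows, then vmap."""
        dim = self.shape[-1]
        transformer_params = jnp.reshape(params, (dim, -1))
        # Discard the parameters for the first y_dim variables (left untransformed).
        transformer_params = transformer_params[y_dim:, :]
        # Build one transformer per remaining dimension.
        transformer = eqx.filter_vmap(self.transformer_constructor)(transformer_params)
        # Identity on the first y_dim variables, vmapped transformer on the rest.
        return Concatenate(
            [Identity((y_dim,)), Vmap(transformer, in_axes=eqx.if_array(0))]
        )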
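To sanity-check the structure of this change, here is a minimal NumPy sketch of a single masked autoregressive layer with the same layout: an identity on the first y_dim variables and a per-variable transformer on the rest. It is not flowjax code. The MADE-style masks with a single hidden layer, the affine transformer standing in for the rational quadratic spline, and the order reversal standing in for the default shuffling are all simplifying assumptions, and made_masks, init_layer, layer, flow and numerical_jacobian are hypothetical helpers written for illustration.

```python
import numpy as np


def made_masks(dim, hidden, ppd, rng):
    """MADE-style masks for one hidden layer (a simplified stand-in).

    Inputs get degrees 1..dim, hidden units degrees in [1, dim - 1], and the
    ppd outputs for variable i (0-based) get degree i + 1, so they only see x[:i].
    """
    in_deg = np.arange(1, dim + 1)
    hid_deg = rng.integers(1, dim, size=hidden)
    out_deg = np.repeat(in_deg, ppd)
    mask_in = (hid_deg[:, None] >= in_deg[None, :]).astype(float)
    mask_out = (out_deg[:, None] > hid_deg[None, :]).astype(float)
    return mask_in, mask_out


def init_layer(dim, hidden, ppd, rng, scale=0.5):
    mask_in, mask_out = made_masks(dim, hidden, ppd, rng)
    W1 = scale * rng.standard_normal((hidden, dim))
    b1 = scale * rng.standard_normal(hidden)
    W2 = scale * rng.standard_normal((dim * ppd, hidden))
    b2 = scale * rng.standard_normal(dim * ppd)
    return mask_in, mask_out, W1, b1, W2, b2


def layer(x, weights, y_dim=1):
    """Identity on x[:y_dim]; affine transformer (stand-in for a spline) on the rest."""
    mask_in, mask_out, W1, b1, W2, b2 = weights
    h = np.tanh((W1 * mask_in) @ x + b1)
    params = ((W2 * mask_out) @ h + b2).reshape(x.size, -1)
    params = params[y_dim:]  # parameters for x[:y_dim] are computed, then dropped
    shift, log_scale = params[:, 0], params[:, 1]
    y = x.copy()
    y[y_dim:] = x[y_dim:] * np.exp(log_scale) + shift
    return y, log_scale.sum()  # log|det J|: the identity block contributes 0


def flow(x, layers, permute):
    """Stack layers; if permute, reverse the order between layers (stand-in for shuffling)."""
    for weights in layers:
        x, _ = layer(x, weights)
        if permute:
            x = x[::-1]
    return x


def numerical_jacobian(f, x, eps=1e-6):
    """Central finite differences; J[i, j] approximates dy_i / dx_j."""
    cols = [(f(x + e) - f(x - e)) / (2 * eps) for e in eps * np.eye(x.size)]
    return np.stack(cols, axis=1)


rng = np.random.default_rng(0)
dim = 4
x = rng.uniform(size=dim)  # a sample from the uniform base on [0, 1]^dim
w = init_layer(dim, hidden=16, ppd=2, rng=rng)

y, logdet = layer(x, w)
J = numerical_jacobian(lambda v: layer(v, w)[0], x)
print(y[0] == x[0])                                 # True: first variable untouched
print(np.allclose(np.triu(J, k=1), 0.0))            # True: lower-triangular Jacobian
print(np.isclose(np.linalg.slogdet(J)[1], logdet))  # True: log-det from dims 1: only

layers = [init_layer(dim, 16, 2, rng) for _ in range(3)]
print(flow(x, layers, permute=False)[0] == x[0])    # True
```

In this sketch, calling flow with permute=True moves the original first variable to another position after the first layer, where a later layer transforms it. So the identity only stays attached to the same variable across the whole stack once the permutation between layers is removed.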

My understanding is that with this change masked_autoregressive_mlp still produces a set of spline parameters for the first variable, which are then never used, and that this should be harmless. My experiments seem to produce the expected results, but I am not sure whether this is the most efficient approach or whether I am overlooking anything relevant, so I would love to hear your opinion on how to make the best use of your package. Thanks again for all the amazing work!
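The "harmless" part can be checked in a simplified setting. The sketch below again uses NumPy, MADE-style masks and an affine stand-in for the spline; these are assumptions, not flowjax internals. With standard masking, the outputs for the first variable see no hidden units, so they reduce to bias terms. Because they are dropped before the transformer is built, perturbing them leaves both the output and the log-determinant unchanged, which means their gradient is zero.

```python
import numpy as np

rng = np.random.default_rng(1)
dim, hidden, ppd, y_dim = 4, 16, 2, 1

# MADE-style degrees (a standard convention, assumed here for illustration).
in_deg = np.arange(1, dim + 1)
hid_deg = rng.integers(1, dim, size=hidden)
out_deg = np.repeat(in_deg, ppd)
mask_in = (hid_deg[:, None] >= in_deg[None, :]).astype(float)
mask_out = (out_deg[:, None] > hid_deg[None, :]).astype(float)

W1, b1 = rng.standard_normal((hidden, dim)), rng.standard_normal(hidden)
W2, b2 = rng.standard_normal((dim * ppd, hidden)), rng.standard_normal(dim * ppd)


def layer(x, bias_out):
    """Identity on x[:y_dim]; affine stand-in for the spline on the rest."""
    h = np.tanh((W1 * mask_in) @ x + b1)
    params = ((W2 * mask_out) @ h + bias_out).reshape(dim, ppd)[y_dim:]
    shift, log_scale = params[:, 0], params[:, 1]
    y = x.copy()
    y[y_dim:] = x[y_dim:] * np.exp(log_scale) + shift
    return y, log_scale.sum()


x = rng.uniform(size=dim)
y, logdet = layer(x, b2)

# The conditioner outputs for the first variable see no hidden units at all...
print(mask_out[: y_dim * ppd].sum())  # 0.0

# ...and perturbing the biases that produce them changes nothing downstream.
b2_perturbed = b2.copy()
b2_perturbed[: y_dim * ppd] += 10.0
y2, logdet2 = layer(x, b2_perturbed)
print(np.array_equal(y, y2), logdet == logdet2)  # True True
```

Under these assumptions, the only cost is ppd extra output units (with their masked weights) per layer, which is small next to the rest of the network. With a zero gradient, optimizers without weight decay leave those entries unchanged.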
