Description
I am opening this issue based on this Discourse thread, in which I asked about estimating intractable likelihood functions with normalizing flows. As described in this paper, normalizing flows can approximate the likelihood function of a model by learning the relationship between a model's parameters and its output. A working example would be of broad interest, given that many scientific fields work with complex models for which the likelihood function is unknown or intractable.
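For concreteness, the basic recipe is: draw parameters from a prior, simulate data from the model at those parameters, and fit a conditional density estimator to the resulting pairs. Below is a minimal sketch of that loop; fit_conditional_flow is a hypothetical placeholder for whatever flow trainer we end up with, not an existing function in any package.

# hypothetical sketch of simulation-based likelihood estimation;
# `fit_conditional_flow` is a placeholder, not a real API
using Distributions

prior = Gamma(1.0, 0.5)
θs = [rand(prior, 2) for _ in 1:10_000]    # draw (μ, σ) from the prior
xs = [rand(LogNormal(θ...)) for θ in θs]   # simulate data at each parameter draw
# flow = fit_conditional_flow(θs, xs)      # learn q(x | θ) ≈ p(x | θ)
# logL = flow(θ_test, x_obs)               # evaluate the approximate log-likelihood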
The paper is associated with a Python package called SBI. Here is a simple working example based on a LogNormal distribution; my attempt at replicating it in Julia is pasted below. Note that the package generalizes to processes that emit data from multiple distributions, but I have used a single distribution for simplicity.
The architectural details in the article are a little sparse; the relevant passage reads:
For the neural spline flow architecture (Durkan et al., 2019), we transformed the reaction time data to the log-domain, used a standard normal base distribution, 2 spline transforms with 5 bins each and conditioning networks with 3 hidden layers and 10 hidden units each, and rectified linear unit activation functions. The neural network training was performed using the sbi package with the following settings: learning rate 0.0005; training batch size 100; 10% of training data as validation data, stop training after 20 epochs without validation loss improvement.
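As a rough Flux translation of those hyperparameters (my reading, not code from the paper or the sbi package; the spline transforms themselves have no off-the-shelf Flux equivalent, so this only sketches the conditioning network and optimizer settings):

using Flux

# conditioning network per the quoted settings: 3 hidden layers, 10 units, ReLU;
# a rational-quadratic spline with K = 5 bins needs roughly 3K - 1 = 14
# parameters per transformed dimension (bin widths, heights, interior derivatives),
# though the exact count depends on boundary handling
n_context = 2                 # conditioning inputs, e.g. (μ, σ)
n_spline_params = 3 * 5 - 1
conditioner = Flux.Chain(
    Flux.Dense(n_context, 10, relu),
    Flux.Dense(10, 10, relu),
    Flux.Dense(10, 10, relu),
    Flux.Dense(10, n_spline_params),
)
opt = ADAM(0.0005)            # learning rate from the quoted settings
batchsize = 100               # training batch size from the quoted settings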
The architecture appears to be heavily influenced by Sequential Neural Likelihood: Fast Likelihood-free Inference with Autoregressive Flows; what appears to be the core Python code can be found here.
For additional background, Hossein has started a related package, but it is experimental and has no documentation.
Thank you for looking into this. I don't know much about neural networks or Flux, but let me know if I can be helpful at all.
WIP Code
###########################################################################################################
# load packages
###########################################################################################################
cd(@__DIR__)
using Pkg
Pkg.activate("")
using Flux, DiffEqFlux, DifferentialEquations, Optimization, OptimizationFlux
using OptimizationOptimJL, Distributions
using Statistics: mean
using Random
Random.seed!(3411)
###########################################################################################################
# generate training data
###########################################################################################################
n_parms = 10_000
# training parameters
train_parms = map(x -> Float32.(rand(Gamma(1.0, .5), 2)), 1:n_parms)
# training samples
samples = map(p -> Float32(rand(LogNormal(p...))), train_parms)
training_data = [hcat(train_parms...); samples']
###########################################################################################################
# setup network
###########################################################################################################
nn = Flux.Chain(
    # inputs are the parameters μ and σ, plus a distribution sample
    Flux.Dense(3, 10, tanh),
    Flux.Dense(10, 10, tanh),
    # FFJORD's dynamics must map back to the data dimension (3), and the final
    # layer is typically left without an activation
    Flux.Dense(10, 3),
) |> f32
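# FFJORD treats `nn` as the dynamics of a continuous normalizing flow: integrating
# the ODE over tspan transports the standard-normal base distribution to the data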
tspan = (0.0f0, 50.0f0)
ffjord_mdl = DiffEqFlux.FFJORD(nn, tspan, Tsit5())
function loss(θ)
    # negative mean log-likelihood of the joint (μ, σ, x) points under the flow
    logpx, λ₁, λ₂ = ffjord_mdl(training_data, θ)
    -mean(logpx)
end
function cb(p, l)::Bool
    # the current loss is passed in as `l`, so there is no need to recompute it
    @info "Training" loss = l
    false
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ffjord_mdl.p)
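# first pass: ADAM for coarse optimization; a second LBFGS pass refines the result below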
res1 = Optimization.solve(optprob,
ADAM(0.1),
maxiters = 100, callback=cb)
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2,
Optim.LBFGS(),
allow_f_increases=false, callback=cb)
###########################################################################################################
# evaluate and plot
###########################################################################################################
using Plots
test_parms = Float32[1.5, 0.5]
xs = Float32[0.05:0.05:20;]
true_density = pdf.(LogNormal(test_parms...), xs)
# the flow models the joint density of (μ, σ, x), so evaluate it at the full
# 3-dimensional point; at fixed (μ, σ) this is only proportional to the
# conditional density of x
# is there a better way to get the estimated density?
est_density = map(x -> exp(ffjord_mdl(reshape([test_parms; x], 3, 1), res2.u, monte_carlo = false)[1][1]), xs)
# plot the true and estimated densities
plot(xs, true_density, label = "true density")
plot!(xs, est_density, label = "estimated density")
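One idea for getting a proper conditional density out of the joint (just a sketch of an assumption on my part, not a vetted approach): since p(x | μ, σ) ∝ p(μ, σ, x) at fixed parameters, renormalize the estimated joint over the x grid.

# sketch (my assumption, not an established API): renormalize the joint density
# over the x grid so it integrates to 1, approximating p(x | μ, σ)
Δx = xs[2] - xs[1]                                     # grid spacing
cond_density = est_density ./ (sum(est_density) * Δx)  # Riemann-sum normalization
plot!(xs, cond_density, label = "renormalized estimate")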