# Planar Flow on a 2D Banana Distribution

This example demonstrates learning a synthetic 2D banana distribution with a planar normalizing flow [^RM2015] by maximizing the Evidence Lower BOund (ELBO).

The two required ingredients are:

- A log-density function `logp` for the target distribution.
- A parametrised invertible transformation (the planar flow) applied to a simple base distribution.

## Target Distribution

The banana target used here is defined in `example/targets/banana.jl` (see source for details):

```julia
using Random, Distributions
Random.seed!(123)

target = Banana(2, 1.0, 10.0)   # (dimension, nonlinearity, scale)
logp = Base.Fix1(logpdf, target)
```

You can visualise its contour and samples (figure shipped as `banana.png`).
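A minimal plotting sketch, assuming the `Banana` target supports `rand` and `logpdf` as used above (the axis ranges are illustrative and may need adjusting for the chosen scale):

```julia
using Plots

# evaluate the target density on a grid and overlay samples from the target
xs = range(-20, 20; length=200)
ys = range(-20, 20; length=200)
contour(xs, ys, (x, y) -> exp(logp([x, y])); xlabel="x₁", ylabel="x₂")
samples = rand(target, 500)
scatter!(samples[1, :], samples[2, :]; ms=2, alpha=0.4, label="samples")
```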

## Planar Flow

A planar flow of length N applies a sequence of planar layers to a base distribution q₀:

```math
T_{n,\theta_n}(x) = x + u_n \tanh(w_n^T x + b_n), \qquad n = 1,\ldots,N.
```

The parameters θₙ = (uₙ, wₙ, bₙ) of each layer are learned; `Bijectors.jl` provides this transformation as `PlanarLayer`.
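Composing the layers and pushing x₀ ∼ q₀ through them, the log-density of the flow follows from the change-of-variables formula applied layer by layer:

```math
\log q_N(x_N) = \log q_0(x_0) - \sum_{n=1}^{N} \log \left| \det \frac{\partial T_{n,\theta_n}(x_{n-1})}{\partial x_{n-1}} \right|,
\qquad x_n = T_{n,\theta_n}(x_{n-1}).
```

For planar layers the log-Jacobian determinant has a closed form (via the matrix determinant lemma), which keeps each training iteration cheap.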

```julia
using Bijectors
using Functors        # provides @leaf
using LinearAlgebra   # provides I, used for the isotropic base covariance

function create_planar_flow(n_layers::Int, q₀)
    d = length(q₀)
    Ls = [PlanarLayer(d) for _ in 1:n_layers]
    ts = reduce(∘, Ls)  # alternatively: FunctionChains.fchain(Ls)
    return transformed(q₀, ts)
end

@leaf MvNormal  # prevent updating base distribution parameters
q₀ = MvNormal(zeros(2), I)
flow = create_planar_flow(10, q₀)
flow_untrained = deepcopy(flow)  # keep a copy for comparison
```
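As a quick check (a hypothetical snippet, reusing the definitions above), the untrained flow is already a valid `Bijectors.TransformedDistribution` and can be sampled:

```julia
# each column of the returned matrix is one 2-dimensional sample
x = rand(flow, 5)
size(x)  # (2, 5)
```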

If you compose *many* layers (e.g. more than about 30), you can reduce compilation time by chaining them with `FunctionChains.jl` instead of `∘`:

```julia
# uncomment the following lines to use FunctionChains
# using FunctionChains
# ts = fchain([PlanarLayer(d) for _ in 1:n_layers])
```

See [this comment](https://github.com/TuringLang/NormalizingFlows.jl/blob/8f4371d48228adf368d851e221af076ff929f1cf/src/NormalizingFlows.jl#L52)
for a discussion of why compilation time can become a concern.

## Training the Flow

We maximize the ELBO (here using the minibatch estimator `elbo_batch`) with the generic `train_flow` interface.
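Concretely, with T denoting the composition of all N planar layers (parameters θ collectively), the objective and its Monte Carlo estimate are

```math
\mathrm{ELBO}(\theta)
= \mathbb{E}_{x_0 \sim q_0}\Big[\log p\big(T(x_0)\big) - \log q_0(x_0) + \log \big|\det J_{T}(x_0)\big|\Big]
\approx \frac{1}{M}\sum_{m=1}^{M}\Big[\log p\big(T(x_0^{(m)})\big) - \log q_0(x_0^{(m)}) + \log \big|\det J_{T}(x_0^{(m)})\big|\Big],
```

where the x₀ draws come from the base distribution q₀ and M is `sample_per_iter`. Maximizing this objective minimizes the reverse KL divergence from the flow to the target, up to the target's unknown normalizing constant.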

```julia
using NormalizingFlows
using ADTypes, Optimisers
using Mooncake

sample_per_iter = 32
adtype = ADTypes.AutoMooncake(; config=Mooncake.Config())  # try AutoZygote() / AutoForwardDiff() / etc.
# optional: callback to record the batch size and the AD backend used at each iteration
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter, ad=adtype)
# optional: stopping criterion that halts training once the gradient norm falls below 1e-3
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < 1e-3

flow_trained, stats, _ = train_flow(
    elbo_batch,
    flow,
    logp,
    sample_per_iter;
    max_iters = 20_000,
    optimiser = Optimisers.Adam(1e-2),
    ADbackend = adtype,
    callback = cb,
    hasconverged = checkconv,
    show_progress = false,
)

losses = map(x -> x.loss, stats)
```

Plot the losses (negative ELBO):

```julia
using Plots
plot(losses; xlabel = "iteration", ylabel = "negative ELBO", label = "", lw = 2)
```

## Evaluating the Trained Flow

The trained flow is a `Bijectors.TransformedDistribution`, so we can call `rand` to draw i.i.d. samples and `logpdf` to evaluate the log-density of the flow.
See the [documentation of `Bijectors.jl`](https://turinglang.org/Bijectors.jl/dev/distributions/) for details.

```julia
n_samples = 1_000
samples_trained = rand(flow_trained, n_samples)
samples_untrained = rand(flow_untrained, n_samples)
samples_true = rand(target, n_samples)
```
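The learned log-density can also be queried directly; for instance, a small sketch reusing the samples above (note that `logpdf` has to invert each layer, so it is more expensive than sampling):

```julia
# flow log-density at the first five trained-flow samples, compared with the target
logq_vals = [logpdf(flow_trained, samples_trained[:, i]) for i in 1:5]
logp_vals = [logp(samples_trained[:, i]) for i in 1:5]
```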

Simple visual comparison:

```julia
using Plots
scatter(samples_true[1, :], samples_true[2, :]; label="Target", ms=2, alpha=0.5)
scatter!(samples_untrained[1, :], samples_untrained[2, :]; label="Untrained", ms=2, alpha=0.5)
scatter!(samples_trained[1, :], samples_trained[2, :]; label="Trained", ms=2, alpha=0.5)
plot!(title = "Planar Flow: Before vs After Training", xlabel = "x₁", ylabel = "x₂", legend = :topleft)
```

## Notes

- Use `elbo` instead of `elbo_batch` for a single-sample estimator.
- Switch AD backends by changing `adtype` (see `ADTypes.jl` and the sketch below).
- Marking the base distribution with `@leaf` prevents its parameters from being updated during training.
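For example, a sketch of switching the AD backend (assuming the corresponding AD package is installed and loaded):

```julia
using ADTypes
# using ForwardDiff; adtype = ADTypes.AutoForwardDiff()
# using Zygote;      adtype = ADTypes.AutoZygote()
# then re-run train_flow with the new `adtype`
```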

## Reference

[^RM2015]: Rezende, D. & Mohamed, S. (2015). Variational Inference with Normalizing Flows. ICML.