Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
6220e96
switch to differentiationinterface from diffresults
zuhengxu Feb 16, 2025
7b4fb85
rename train.jl to optimize.jl
zuhengxu Feb 16, 2025
5ff6041
fix some compat issue and bump version
zuhengxu Feb 16, 2025
3010669
update tests to new interface
zuhengxu Feb 16, 2025
f7ee84b
add Moonkcake to extras
zuhengxu Feb 16, 2025
e68ef5f
rm all ext for now
zuhengxu Feb 16, 2025
9bdf1f7
rm enzyme test, and import mooncake for test
zuhengxu Feb 16, 2025
b7f9f08
fixing compat and test with mooncake
zuhengxu Feb 16, 2025
1970b09
fixing test bug
zuhengxu Feb 17, 2025
b0390f7
fix _value_and_grad wrapper bug
zuhengxu Feb 17, 2025
3b40c33
fix AutoReverseDiff argument typo
zuhengxu Feb 17, 2025
9552de1
minor ed
zuhengxu Feb 17, 2025
104e5dd
minor ed
zuhengxu Feb 17, 2025
d1ec834
fixing test
zuhengxu Feb 17, 2025
deba738
minor ed
zuhengxu Feb 17, 2025
8976307
rm test for mooncake
zuhengxu Feb 17, 2025
cb3db4a
fix doc
zuhengxu Feb 17, 2025
1615009
chagne CI
zuhengxu Feb 17, 2025
690f754
Merge branch 'main' into diffinterface
yebai Feb 20, 2025
1166e6b
update CI
sunxd3 Feb 26, 2025
906d788
streamline project toml
sunxd3 Feb 26, 2025
0d6302b
Apply suggestions from code review
sunxd3 Feb 26, 2025
63321c0
add enzyme to tests
sunxd3 Feb 26, 2025
7204391
add Enzyme to using list
sunxd3 Feb 26, 2025
77b9a2e
fixing enzyme readonly error by wrapping loss in Const
zuhengxu Mar 3, 2025
9a8ed04
mv enzyme related edits to ext/ and fix tests
zuhengxu Mar 3, 2025
b3487b5
fixing extension loading error
zuhengxu Mar 4, 2025
45756e6
Update Project.toml
zuhengxu Mar 4, 2025
dbe725c
remove Requires
zuhengxu Mar 4, 2025
da8593c
remove explit load ext
zuhengxu Mar 4, 2025
da66996
Update src/objectives/loglikelihood.jl
zuhengxu Mar 4, 2025
3e65cde
make ext dep explicit
zuhengxu Mar 4, 2025
eeb9a92
rm empty argument specialization for _prep_grad and _value_grad
zuhengxu Mar 4, 2025
4b97585
signal empty rng arg
zuhengxu Mar 4, 2025
9075266
drop Requires
zuhengxu Mar 4, 2025
eade7a3
drop Requires
zuhengxu Mar 4, 2025
91202ff
update test to include mooncake
zuhengxu Mar 4, 2025
dcee3c0
rm unnecessary EnzymeCoreExt
zuhengxu Mar 5, 2025
edf2f12
minor update of readme
zuhengxu Mar 5, 2025
9fc328d
typo fix in readme
zuhengxu Mar 5, 2025
41d10ef
Update src/NormalizingFlows.jl
zuhengxu Mar 5, 2025
db4872b
Update src/NormalizingFlows.jl
zuhengxu Mar 5, 2025
0fe536f
rm time_elapsed from train_flow
zuhengxu Mar 5, 2025
fc935ce
Update docs/src/api.md
yebai Mar 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
matrix:
version:
- '1'
- '1.6'
- '1.10'
os:
- ubuntu-latest
arch:
Expand Down
23 changes: 10 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
name = "NormalizingFlows"
uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
version = "0.1.1"
version = "0.1.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -16,35 +16,32 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
NormalizingFlowsEnzymeExt = "Enzyme"
NormalizingFlowsForwardDiffExt = "ForwardDiff"
NormalizingFlowsReverseDiffExt = "ReverseDiff"
NormalizingFlowsZygoteExt = "Zygote"

[compat]
ADTypes = "0.1, 0.2, 1"
Bijectors = "0.12.6, 0.13, 0.14"
DiffResults = "1"
Bijectors = "0.12.6, 0.13, 0.14, 0.15"
DifferentiationInterface = "0.6"
Distributions = "0.25"
DocStringExtensions = "0.9"
Mooncake = "0.4.95"
Enzyme = "0.11, 0.12, 0.13"
ForwardDiff = "0.10.25"
Optimisers = "0.2.16, 0.3"
Optimisers = "0.2.16, 0.3, 0.4"
ProgressMeter = "1.0.0"
Requires = "1"
ReverseDiff = "1.14"
StatsBase = "0.33, 0.34"
Zygote = "0.6"
julia = "1.6"
Zygote = "0.6, 0.7"
julia = "1.10"

[extras]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
11 changes: 2 additions & 9 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ For example of Gaussian VI, we can construct the flow as follows:
```@julia
using Distributions, Bijectors, Functors
T= Float32
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T,2)) ∘ Bijectors.Scale(ones(T, 2)))
```
Expand All @@ -23,7 +24,7 @@ To train the Gaussian VI targeting at distribution $p$ via ELBO maximization, w
using NormalizingFlows
sample_per_iter = 10
flow_trained, stats, _ = train_flow(
flow_trained, stats, _ , _ = train_flow(
elbo,
flow,
logp,
Expand Down Expand Up @@ -83,11 +84,3 @@ NormalizingFlows.loglikelihood
```@docs
NormalizingFlows.optimize
```


## Utility Functions for Taking Gradient
```@docs
NormalizingFlows.grad!
NormalizingFlows.value_and_gradient!
```

7 changes: 5 additions & 2 deletions docs/src/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Here we used the `PlanarLayer()` from `Bijectors.jl` to construct a

```julia
using Bijectors, FunctionChains
using Functors

function create_planar_flow(n_layers::Int, q₀)
d = length(q₀)
Expand All @@ -45,7 +46,9 @@ function create_planar_flow(n_layers::Int, q₀)
end

# create a 20-layer planar flow
flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I))
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(Float32, 2), I)
flow = create_planar_flow(20, q₀)
flow_untrained = deepcopy(flow) # keep a copy of the untrained flow for comparison
```
*Notice that here the flow layers are chained together using the `fchain` function from [`FunctionChains.jl`](https://github.com/oschulz/FunctionChains.jl).
Expand Down Expand Up @@ -116,4 +119,4 @@ plot!(title = "Comparison of Trained and Untrained Flow", xlabel = "X", ylabel=

## Reference

- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning
- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning
25 changes: 0 additions & 25 deletions ext/NormalizingFlowsEnzymeExt.jl

This file was deleted.

28 changes: 0 additions & 28 deletions ext/NormalizingFlowsForwardDiffExt.jl

This file was deleted.

22 changes: 0 additions & 22 deletions ext/NormalizingFlowsReverseDiffExt.jl

This file was deleted.

23 changes: 0 additions & 23 deletions ext/NormalizingFlowsZygoteExt.jl

This file was deleted.

49 changes: 14 additions & 35 deletions src/NormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,14 @@ using Bijectors
using Optimisers
using LinearAlgebra, Random, Distributions, StatsBase
using ProgressMeter
using ADTypes, DiffResults
using ADTypes
using DifferentiationInterface

using DocStringExtensions

export train_flow, elbo, loglikelihood, value_and_gradient!

using ADTypes
using DiffResults
export train_flow, elbo, loglikelihood

"""
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
""" train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)

Train the given normalizing flow `flow` by calling `optimize`.

Expand Down Expand Up @@ -56,47 +53,29 @@ function train_flow(
# use FunctionChains instead of simple compositions to construct the flow when many flow layers are involved
# otherwise the compilation time for destructure will be too long
θ_flat, re = Optimisers.destructure(flow)

loss(θ, rng, args...) = -vo(rng, re(θ), args...)

# Normalizing flow training loop
θ_flat_trained, opt_stats, st = optimize(
rng,
θ_flat_trained, opt_stats, st, time_elapsed = optimize(
ADbackend,
vo,
loss,
θ_flat,
re,
args...;
re,
(rng, args...)...;
max_iters=max_iters,
optimiser=optimiser,
kwargs...,
)

flow_trained = re(θ_flat_trained)
return flow_trained, opt_stats, st
return flow_trained, opt_stats, st, time_elapsed
end

include("train.jl")


include("optimize.jl")
include("objectives.jl")

# optional dependencies
if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
using Requires
end

# Question: should Exts be loaded here or in train.jl?
# Module initialisation hook.
# When `Base.get_extension` is not defined, the AD-backend glue files under `../ext/`
# are registered through Requires.jl's `@require`, one per optional backend
# (ForwardDiff, ReverseDiff, Enzyme, Zygote). When `Base.get_extension` is defined,
# the `@static if` branch is not taken and this function does nothing.
function __init__()
# `@static` evaluates the condition once at macro-expansion time rather than at every call.
@static if !isdefined(Base, :get_extension)
# Each `@require` pairs a package name and UUID with the extension file to include.
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/NormalizingFlowsForwardDiffExt.jl"
)
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
"../ext/NormalizingFlowsReverseDiffExt.jl"
)
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(
"../ext/NormalizingFlowsEnzymeExt.jl"
)
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
"../ext/NormalizingFlowsZygoteExt.jl"
)
end
end
end
2 changes: 1 addition & 1 deletion src/objectives.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
include("objectives/elbo.jl")
include("objectives/loglikelihood.jl")
include("objectives/loglikelihood.jl") # not tested
2 changes: 1 addition & 1 deletion src/objectives/elbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ end

function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
    # Convenience method: no RNG was supplied, so forward to the rng-taking method
    # using Julia's default random number generator.
    rng = Random.default_rng()
    return elbo(rng, flow, logp, n_samples)
end
end
7 changes: 5 additions & 2 deletions src/objectives/loglikelihood.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,32 @@
# training by minimizing forward KL (MLE)
####################################
"""
loglikelihood(flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat)
loglikelihood(rng, flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat)

Compute the log-likelihood for variational distribution flow at a batch of samples xs from
the target distribution p.

# Arguments
- `rng`: random number generator (empty argument, only needed to ensure the same signature as other variational objectives)
- `flow`: variational distribution to be trained. In particular
"flow = transformed(q₀, T::Bijectors.Bijector)",
q₀ is a reference distribution that one can easily sample and compute logpdf
- `xs`: samples from the target distribution p.

"""
function loglikelihood(
    rng::AbstractRNG,                       # unused here; kept for a uniform objective signature
    flow::Bijectors.UnivariateTransformed,  # variational distribution being trained
    xs::AbstractVector,                     # batch of samples from the target distribution p
)
    # Partially apply `logpdf` to the flow, then average it over the batch.
    logq = Base.Fix1(logpdf, flow)
    return mean(logq, xs)
end

function loglikelihood(
    rng::AbstractRNG,                         # unused here; kept for a uniform objective signature
    flow::Bijectors.MultivariateTransformed,  # variational distribution being trained
    xs::AbstractMatrix,                       # batch of samples from the target distribution p
)
    # Evaluate the flow's log-density on every column of `xs`, then average.
    per_sample = map(eachcol(xs)) do x
        logpdf(flow, x)
    end
    return mean(per_sample)
end
end
Loading
Loading