Skip to content
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
6220e96
switch to differentiationinterface from diffresults
zuhengxu Feb 16, 2025
7b4fb85
rename train.jl to optimize.jl
zuhengxu Feb 16, 2025
5ff6041
fix some compat issue and bump version
zuhengxu Feb 16, 2025
3010669
update tests to new interface
zuhengxu Feb 16, 2025
f7ee84b
add Moonkcake to extras
zuhengxu Feb 16, 2025
e68ef5f
rm all ext for now
zuhengxu Feb 16, 2025
9bdf1f7
rm enzyme test, and import mooncake for test
zuhengxu Feb 16, 2025
b7f9f08
fixing compat and test with mooncake
zuhengxu Feb 16, 2025
1970b09
fixing test bug
zuhengxu Feb 17, 2025
b0390f7
fix _value_and_grad wrapper bug
zuhengxu Feb 17, 2025
3b40c33
fix AutoReverseDiff argument typo
zuhengxu Feb 17, 2025
9552de1
minor ed
zuhengxu Feb 17, 2025
104e5dd
minor ed
zuhengxu Feb 17, 2025
d1ec834
fixing test
zuhengxu Feb 17, 2025
deba738
minor ed
zuhengxu Feb 17, 2025
8976307
rm test for mooncake
zuhengxu Feb 17, 2025
cb3db4a
fix doc
zuhengxu Feb 17, 2025
1615009
chagne CI
zuhengxu Feb 17, 2025
690f754
Merge branch 'main' into diffinterface
yebai Feb 20, 2025
1166e6b
update CI
sunxd3 Feb 26, 2025
906d788
streamline project toml
sunxd3 Feb 26, 2025
0d6302b
Apply suggestions from code review
sunxd3 Feb 26, 2025
63321c0
add enzyme to tests
sunxd3 Feb 26, 2025
7204391
add Enzyme to using list
sunxd3 Feb 26, 2025
77b9a2e
fixing enzyme readonly error by wrapping loss in Const
zuhengxu Mar 3, 2025
9a8ed04
mv enzyme related edits to ext/ and fix tests
zuhengxu Mar 3, 2025
b3487b5
fixing extension loading error
zuhengxu Mar 4, 2025
45756e6
Update Project.toml
zuhengxu Mar 4, 2025
dbe725c
remove Requires
zuhengxu Mar 4, 2025
da8593c
remove explit load ext
zuhengxu Mar 4, 2025
da66996
Update src/objectives/loglikelihood.jl
zuhengxu Mar 4, 2025
3e65cde
make ext dep explicit
zuhengxu Mar 4, 2025
eeb9a92
rm empty argument specialization for _prep_grad and _value_grad
zuhengxu Mar 4, 2025
4b97585
signal empty rng arg
zuhengxu Mar 4, 2025
9075266
drop Requires
zuhengxu Mar 4, 2025
eade7a3
drop Requires
zuhengxu Mar 4, 2025
91202ff
update test to include mooncake
zuhengxu Mar 4, 2025
dcee3c0
rm unnecessary EnzymeCoreExt
zuhengxu Mar 5, 2025
edf2f12
minor update of readme
zuhengxu Mar 5, 2025
9fc328d
typo fix in readme
zuhengxu Mar 5, 2025
41d10ef
Update src/NormalizingFlows.jl
zuhengxu Mar 5, 2025
db4872b
Update src/NormalizingFlows.jl
zuhengxu Mar 5, 2025
0fe536f
rm time_elapsed from train_flow
zuhengxu Mar 5, 2025
fc935ce
Update docs/src/api.md
yebai Mar 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
name: CI

on:
push:
branches:
- main
tags: ['*']
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
Expand All @@ -19,17 +22,17 @@ jobs:
matrix:
version:
- '1'
- '1.6'
- 'min'
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
35 changes: 10 additions & 25 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,50 +1,35 @@
name = "NormalizingFlows"
uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
version = "0.1.1"
version = "0.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
NormalizingFlowsEnzymeExt = "Enzyme"
NormalizingFlowsForwardDiffExt = "ForwardDiff"
NormalizingFlowsReverseDiffExt = "ReverseDiff"
NormalizingFlowsZygoteExt = "Zygote"
NormalizingFlowsEnzymeCoreExt = ["EnzymeCore", "ADTypes", "DifferentiationInterface"]

[compat]
ADTypes = "0.1, 0.2, 1"
Bijectors = "0.12.6, 0.13, 0.14"
DiffResults = "1"
ADTypes = "1"
Bijectors = "0.12.6, 0.13, 0.14, 0.15"
DifferentiationInterface = "0.6.42"
Distributions = "0.25"
DocStringExtensions = "0.9"
Enzyme = "0.11, 0.12, 0.13"
ForwardDiff = "0.10.25"
Optimisers = "0.2.16, 0.3"
Optimisers = "0.2.16, 0.3, 0.4"
ProgressMeter = "1.0.0"
Requires = "1"
ReverseDiff = "1.14"
StatsBase = "0.33, 0.34"
Zygote = "0.6"
julia = "1.6"
julia = "1.10"

[extras]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
11 changes: 2 additions & 9 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ For example of Gaussian VI, we can construct the flow as follows:
```@julia
using Distributions, Bijectors
T= Float32
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T,2)) ∘ Bijectors.Scale(ones(T, 2)))
```
Expand All @@ -23,7 +24,7 @@ To train the Gaussian VI targeting the distribution $p$ via ELBO maximization, w
using NormalizingFlows

sample_per_iter = 10
flow_trained, stats, _ = train_flow(
flow_trained, stats, _ , _ = train_flow(
elbo,
flow,
logp,
Expand Down Expand Up @@ -83,11 +84,3 @@ NormalizingFlows.loglikelihood
```@docs
NormalizingFlows.optimize
```


## Utility Functions for Taking Gradient
```@docs
NormalizingFlows.grad!
NormalizingFlows.value_and_gradient!
```

7 changes: 5 additions & 2 deletions docs/src/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Here we used the `PlanarLayer()` from `Bijectors.jl` to construct a

```julia
using Bijectors, FunctionChains
using Functors

function create_planar_flow(n_layers::Int, q₀)
d = length(q₀)
Expand All @@ -45,7 +46,9 @@ function create_planar_flow(n_layers::Int, q₀)
end

# create a 20-layer planar flow
flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I))
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(Float32, 2), I)
flow = create_planar_flow(20, q₀)
flow_untrained = deepcopy(flow) # keep a copy of the untrained flow for comparison
```
*Notice that here the flow layers are chained together using the `fchain` function from [`FunctionChains.jl`](https://github.com/oschulz/FunctionChains.jl).
Expand Down Expand Up @@ -116,4 +119,4 @@ plot!(title = "Comparison of Trained and Untrained Flow", xlabel = "X", ylabel=

## Reference

- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning
- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning
29 changes: 29 additions & 0 deletions ext/NormalizingFlowsEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
module NormalizingFlowsEnzymeCoreExt

using EnzymeCore
using NormalizingFlows
using NormalizingFlows: ADTypes, DifferentiationInterface

# Workaround for the Enzyme "readonly" error; background discussion:
# https://discourse.julialang.org/t/enzyme-autodiff-readonly-error-and-working-with-batches-of-data/123012
#
# Both methods below are specialised on `ADTypes.AutoEnzyme`. Each one wraps `loss` in
# `EnzymeCore.Const` and every trailing argument in `DifferentiationInterface.Constant`
# before delegating to the corresponding DifferentiationInterface function.

# `AutoEnzyme` method of `NormalizingFlows._prepare_gradient`.
# Returns whatever `DifferentiationInterface.prepare_gradient` returns.
function NormalizingFlows._prepare_gradient(
    loss, adbackend::ADTypes.AutoEnzyme, θ, args...
)
    const_loss = EnzymeCore.Const(loss)
    contexts = map(DifferentiationInterface.Constant, args)
    return DifferentiationInterface.prepare_gradient(const_loss, adbackend, θ, contexts...)
end

# `AutoEnzyme` method of `NormalizingFlows._value_and_gradient`.
# `prep` is forwarded untouched; the result of
# `DifferentiationInterface.value_and_gradient` is returned as-is.
function NormalizingFlows._value_and_gradient(
    loss, prep, adbackend::ADTypes.AutoEnzyme, θ, args...
)
    const_loss = EnzymeCore.Const(loss)
    contexts = map(DifferentiationInterface.Constant, args)
    return DifferentiationInterface.value_and_gradient(
        const_loss, prep, adbackend, θ, contexts...
    )
end

end
25 changes: 0 additions & 25 deletions ext/NormalizingFlowsEnzymeExt.jl

This file was deleted.

28 changes: 0 additions & 28 deletions ext/NormalizingFlowsForwardDiffExt.jl

This file was deleted.

22 changes: 0 additions & 22 deletions ext/NormalizingFlowsReverseDiffExt.jl

This file was deleted.

23 changes: 0 additions & 23 deletions ext/NormalizingFlowsZygoteExt.jl

This file was deleted.

45 changes: 11 additions & 34 deletions src/NormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@ using Bijectors
using Optimisers
using LinearAlgebra, Random, Distributions, StatsBase
using ProgressMeter
using ADTypes, DiffResults
using ADTypes
using DifferentiationInterface

using DocStringExtensions

export train_flow, elbo, loglikelihood, value_and_gradient!

using ADTypes
using DiffResults
export train_flow, elbo, loglikelihood

"""
"""
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)

Train the given normalizing flow `flow` by calling `optimize`.
Expand Down Expand Up @@ -57,46 +55,25 @@ function train_flow(
# otherwise the compilation time for destructure will be too long
θ_flat, re = Optimisers.destructure(flow)

loss(θ, rng, args...) = -vo(rng, re(θ), args...)

# Normalizing flow training loop
θ_flat_trained, opt_stats, st = optimize(
rng,
θ_flat_trained, opt_stats, st, time_elapsed = optimize(
ADbackend,
vo,
loss,
θ_flat,
re,
args...;
(rng, args...)...;
max_iters=max_iters,
optimiser=optimiser,
kwargs...,
)

flow_trained = re(θ_flat_trained)
return flow_trained, opt_stats, st
return flow_trained, opt_stats, st, time_elapsed
end

include("train.jl")
include("optimize.jl")
include("objectives.jl")

# optional dependencies
if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
using Requires
end

# Question: should Exts be loaded here or in train.jl?
function __init__()
@static if !isdefined(Base, :get_extension)
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/NormalizingFlowsForwardDiffExt.jl"
)
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
"../ext/NormalizingFlowsReverseDiffExt.jl"
)
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(
"../ext/NormalizingFlowsEnzymeExt.jl"
)
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
"../ext/NormalizingFlowsZygoteExt.jl"
)
end
end
end
2 changes: 1 addition & 1 deletion src/objectives.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
include("objectives/elbo.jl")
include("objectives/loglikelihood.jl")
include("objectives/loglikelihood.jl") # not fully tested
2 changes: 1 addition & 1 deletion src/objectives/elbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ end

# Convenience method without an explicit RNG: fetches `Random.default_rng()` and
# forwards to the `elbo(rng, flow, logp, n_samples)` method.
function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
    rng = Random.default_rng()
    return elbo(rng, flow, logp, n_samples)
end
end
Loading
Loading