-
Notifications
You must be signed in to change notification settings - Fork 5
more NF examples #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
more NF examples #11
Changes from 123 commits
Commits
Show all changes
139 commits
Select commit
Hold shift + click to select a range
4f591af
add support for hasconverged
zuhengxu 4e31807
fix test error
zuhengxu 99ac0c1
rm example/Manifest.toml
zuhengxu e4f5efa
minor bug fix for trainig loop
zuhengxu b345e78
test new stopping criterion
zuhengxu f10512f
test convergent condition/ rm unready examples
zuhengxu b93dbe8
Merge branch 'TuringLang:main' into hasconverge
zuhengxu 1c1c88a
rm julia test from CI
zuhengxu 5d2844f
Revert "rm julia test from CI"
zuhengxu 3fdcb0e
make autodiff pkgs as extension + require for bwd compat
zuhengxu ddad59e
debugging Ext
zuhengxu ef60ee1
keep debugging
zuhengxu 5a5deb0
Fix AD package extension loading issues
sunxd3 25d4211
Applying @devmotion's comment
sunxd3 df5eddd
patch last commit
sunxd3 d44143b
patch for julia 1.6
sunxd3 b03e922
loading dep pkgs from main pkg instead of functions for explicitness
zuhengxu e9acf70
fixing test err
zuhengxu c6bf68b
rm unready examples
zuhengxu fabf20a
update realnvp
zuhengxu 7964f91
minor ed
zuhengxu 01de9d4
removing unnecessary import
zuhengxu 6993d33
refactor affinecoupling and example/
zuhengxu 810881f
debug affinecoupling flow
zuhengxu 0277c9e
adapt to the updated autoforwarddiff to resolve test err
zuhengxu 4bc02bf
fix test err
zuhengxu 144c668
add new implementation of affcoupling using Bijectors.Coupling
zuhengxu 2678198
implement ham flow
zuhengxu ab7ac64
finish hamflow implementation
zuhengxu 308ab61
minor update
zuhengxu ceafbde
rename hamflow.jl to hamiltonian_layer.jl
zuhengxu 5589f36
upadting readme
zuhengxu de01e0e
rm hamflow.jl
zuhengxu 93a8572
Merge branch 'main' of github.com:zuhengxu/NormalizingFlows.jl into m…
zuhengxu b8a8f5f
sync with main
zuhengxu c858b10
fix minor bugs in affine coupling layer
zuhengxu baef6a9
test affine coupling flow on banana
zuhengxu 0323680
rename simple flow run files
zuhengxu 144b06a
update loglikelihood to fit in optimize interface
zuhengxu f70a4b7
fix minor bugs in nsf_layer
zuhengxu 30cfc32
rm unused data in nsf lfow
zuhengxu 00387d3
rm @view to avoid zygote mutation error
zuhengxu 2cacfc6
update optimize for direct minibatch training
zuhengxu 24b1564
add BatchNorm for real nvp and make real nvp compatible with batching
zuhengxu f720eab
testing
zuhengxu 880976c
update resblock
zuhengxu f6e4189
update resnet arch
zuhengxu 0a78963
minor update
zuhengxu 7c953c2
Merge branch 'stability' of github.com:zuhengxu/NormalizingFlows.jl i…
zuhengxu e41c3f6
udpate
zuhengxu 91cf585
fix affine bugs
zuhengxu c523f94
add invertiblemlp
zuhengxu 1f67b10
start testing invertiblenetworks
zuhengxu 6cde9cc
minor update
zuhengxu eada0e6
minor bug fix
zuhengxu 8255949
update MLP to make it work
zuhengxu 37bd3ab
fix type instability in invertible MLP
zuhengxu a67a6e2
update deep mlp
zuhengxu d6c6a2b
update deep hamflow result
zuhengxu 799343a
update flows
zuhengxu b908a80
add shadowing window computation
zuhengxu 8bde6bf
minor ed
zuhengxu eaf1775
refactoring training and setup files
zuhengxu 519fd66
minor setup change to allow convenient precision switch
zuhengxu 77816fd
get MLP figs and res'
zuhengxu 8db13c9
obtain figs for ham
zuhengxu 4a656e8
rm duplication
zuhengxu 9f68fdd
add some MLP result
zuhengxu 8c703d5
add shf res
zuhengxu 9e27814
add mom normalization layer implementation
zuhengxu c63abfc
update shf res
zuhengxu 61862ff
add script for deep shf
zuhengxu 4fe363f
update shf_big res
zuhengxu a9d46bf
minor update
zuhengxu 8bd27ee
update log reg and new shf res with q0 being std normal
zuhengxu 156faeb
minor bug update
zuhengxu 7e7e782
update sfh banana figs
zuhengxu 8be77ad
add some lip constant figs for shf banana
zuhengxu 0381259
add some lip constant reg
zuhengxu 37835c8
update shf banana/cross
zuhengxu 2f15b1b
update shf/stab
zuhengxu 2ae4f27
update banana plots
zuhengxu 3e453e5
update banana res
zuhengxu ed088eb
fix conflict
zuhengxu 7db8236
update cross res
zuhengxu 167af8e
update cross res
zuhengxu d491fee
update
zuhengxu 1aaad46
update cross
zuhengxu fa80afb
update
zuhengxu 3ffe162
Merge branch 'stability' of github.com:zuhengxu/NormalizingFlows.jl i…
zuhengxu e25f31b
update cross elbo res
zuhengxu 4e78a5a
update shadowing res
zuhengxu 7704896
udpate banana and cross res
zuhengxu b4b4aaf
update banana figs
zuhengxu 9d4f6ba
update cross res
zuhengxu b7d7b67
update cross figs and res
zuhengxu 3b148d7
update cross lp scaling figs
zuhengxu ae35d63
minor update
zuhengxu ae9b9f2
new cross figs
zuhengxu deaf04a
fix conflict
zuhengxu 2dd855a
rm readme
zuhengxu b30d961
update
zuhengxu aeda365
Merge branch 'stability' into more_examples
zuhengxu 2c4d793
rm many unrelated code
zuhengxu 6056ef1
merge conflict from upstream
zuhengxu 960a359
keep cleaning
zuhengxu 51bc5b9
merge from turing/main
zuhengxu f5a2f9a
fix warpgaussian logpdf error/making neal funnel logpdf working with …
zuhengxu d916b64
add easier model loading code
zuhengxu d84fce4
restructure example folder and refactor planar and radial examples
zuhengxu 6872760
clean demos for realnvp/planar/radial/ fix a bug in nsf
zuhengxu e3fb428
rm enzyme dependency
zuhengxu ddf0ba9
tune nsf flow--enlarge B--to make it work
zuhengxu f2affe2
rm some useless files from HamVI
zuhengxu 78d5f45
rename MLP_3layer to mlp3 for convenience
zuhengxu 2229577
rename common.jl to utils
zuhengxu aefc651
minor update file naming, and better nsf implementation
zuhengxu 7f1580c
minor ed
zuhengxu 750c1e1
rm redundant nsf file
zuhengxu ab48d11
shrink stepsize for nsf
zuhengxu 42bc4a1
cleaned hamiltonian flow
zuhengxu edf1bac
minor ed
zuhengxu c05789b
rm redundant pkgs
zuhengxu 228d0e8
MLutils and itertools from dependencies
zuhengxu 6b1500c
rm useless example utils functions
zuhengxu e257fab
add elbo_batch implementation; much faster for invertible NN based flows
zuhengxu 17c8c5d
rm simple unpack from ham flows
zuhengxu c57becd
rm simple unpack dependency from example env
zuhengxu 738c2dd
bump version
zuhengxu 0672626
add test for elbo_batch
zuhengxu 427225e
use elbo_batch for real_nvp; achieved 4+ times speed up
zuhengxu c77dc2d
rm load_model and use constructor directly
zuhengxu a18bb12
fix doc building error
zuhengxu 2bd914d
update DI bounds
zuhengxu cc3c761
fix cuda test error
zuhengxu 06e7651
reduce the iter number
zuhengxu 5508a9f
add CI group for exampels
zuhengxu 4bf8fe1
update examples CI
zuhengxu 5f2a8da
fixing CI bugs
zuhengxu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
using DocStringExtensions | ||
using Distributions, Random, LinearAlgebra | ||
using IrrationalConstants | ||
using Plots | ||
|
||
|
||
include("targets/banana.jl") | ||
include("targets/cross.jl") | ||
include("targets/neal_funnel.jl") | ||
include("targets/warped_gaussian.jl") | ||
|
||
|
||
function load_model(name::String) | ||
if name == "Banana" | ||
return Banana(2, 1.0, 10.0) | ||
elseif name == "Cross" | ||
return Cross() | ||
elseif name == "Funnel" | ||
return Funnel(2) | ||
elseif name == "WarpedGaussian" | ||
return WarpedGauss() | ||
else | ||
error("Model not defined") | ||
end | ||
end | ||
zuhengxu marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
function visualize(p::ContinuousMultivariateDistribution, samples=rand(p, 1000)) | ||
xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100) | ||
yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100) | ||
z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange] | ||
fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2) | ||
scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright) | ||
return fig | ||
end |
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
using Flux | ||
using Bijectors | ||
using Bijectors: partition, combine, PartitionMask | ||
|
||
using Random, Distributions, LinearAlgebra | ||
using Functors | ||
using Optimisers, ADTypes | ||
using Mooncake | ||
using NormalizingFlows | ||
|
||
include("SyntheticTargets.jl") | ||
include("utils.jl") | ||
|
||
################################## | ||
# define affine coupling layer using Bijectors.jl interface | ||
################################# | ||
struct AffineCoupling <: Bijectors.Bijector | ||
dim::Int | ||
mask::Bijectors.PartitionMask | ||
s::Flux.Chain | ||
t::Flux.Chain | ||
end | ||
|
||
# let params track field s and t | ||
@functor AffineCoupling (s, t) | ||
|
||
function AffineCoupling( | ||
dim::Int, # dimension of input | ||
hdims::Int, # dimension of hidden units for s and t | ||
mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on | ||
) | ||
cdims = length(mask_idx) # dimension of parts used to construct coupling law | ||
s = mlp3(cdims, hdims, cdims) | ||
t = mlp3(cdims, hdims, cdims) | ||
mask = PartitionMask(dim, mask_idx) | ||
return AffineCoupling(dim, mask, s, t) | ||
end | ||
|
||
function Bijectors.transform(af::AffineCoupling, x::AbstractVector) | ||
# partition vector using 'af.mask::PartitionMask` | ||
x₁, x₂, x₃ = partition(af.mask, x) | ||
y₁ = x₁ .* af.s(x₂) .+ af.t(x₂) | ||
return combine(af.mask, y₁, x₂, x₃) | ||
end | ||
|
||
function (af::AffineCoupling)(x::AbstractArray) | ||
return transform(af, x) | ||
end | ||
|
||
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector) | ||
x_1, x_2, x_3 = Bijectors.partition(af.mask, x) | ||
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2) | ||
logjac = sum(log ∘ abs, af.s(x_2)) | ||
return combine(af.mask, y_1, x_2, x_3), logjac | ||
end | ||
|
||
function Bijectors.with_logabsdet_jacobian( | ||
iaf::Inverse{<:AffineCoupling}, y::AbstractVector | ||
) | ||
af = iaf.orig | ||
# partition vector using `af.mask::PartitionMask` | ||
y_1, y_2, y_3 = partition(af.mask, y) | ||
# inverse transformation | ||
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2) | ||
logjac = -sum(log ∘ abs, af.s(y_2)) | ||
return combine(af.mask, x_1, y_2, y_3), logjac | ||
end | ||
|
||
function Bijectors.logabsdetjac(af::AffineCoupling, x::AbstractVector) | ||
_, x_2, _ = partition(af.mask, x) | ||
logjac = sum(log ∘ abs, af.s(x_2)) | ||
return logjac | ||
end | ||
|
||
################### | ||
# an equivalent definition of AffineCoupling using Bijectors.Coupling | ||
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1) | ||
################### | ||
|
||
# struct AffineCoupling <: Bijectors.Bijector | ||
# dim::Int | ||
# mask::Bijectors.PartitionMask | ||
# s::Flux.Chain | ||
# t::Flux.Chain | ||
# end | ||
|
||
# # let params track field s and t | ||
# @functor AffineCoupling (s, t) | ||
|
||
# function AffineCoupling(dim, mask, s, t) | ||
# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask) | ||
# end | ||
|
||
# function AffineCoupling( | ||
# dim::Int, # dimension of input | ||
# hdims::Int, # dimension of hidden units for s and t | ||
# mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on | ||
# ) | ||
# cdims = length(mask_idx) # dimension of parts used to construct coupling law | ||
# s = mlp3(cdims, hdims, cdims) | ||
# t = mlp3(cdims, hdims, cdims) | ||
# mask = PartitionMask(dim, mask_idx) | ||
# return AffineCoupling(dim, mask, s, t) | ||
# end | ||
|
||
|
||
|
||
################################## | ||
# start demo | ||
################################# | ||
Random.seed!(123) | ||
rng = Random.default_rng() | ||
T = Float32 | ||
|
||
###################################### | ||
# a difficult banana target | ||
###################################### | ||
target = Banana(2, 1.0f0, 100.0f0) | ||
logp = Base.Fix1(logpdf, target) | ||
|
||
###################################### | ||
# learn the target using Affine coupling flow | ||
###################################### | ||
@leaf MvNormal | ||
q0 = MvNormal(zeros(T, 2), ones(T, 2)) | ||
|
||
d = 2 | ||
hdims = 32 | ||
Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3] | ||
|
||
flow = create_flow(Ls, q0) | ||
flow_untrained = deepcopy(flow) | ||
|
||
|
||
###################################### | ||
# start training | ||
###################################### | ||
sample_per_iter = 64 | ||
|
||
# callback function to log training progress | ||
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) | ||
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) | ||
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 | ||
flow_trained, stats, _ = train_flow( | ||
elbo, | ||
flow, | ||
logp, | ||
sample_per_iter; | ||
max_iters=50_000, | ||
optimiser=Optimisers.Adam(5e-4), | ||
ADbackend=adtype, | ||
show_progress=true, | ||
callback=cb, | ||
hasconverged=checkconv, | ||
) | ||
θ, re = Optimisers.destructure(flow_trained) | ||
losses = map(x -> x.loss, stats) | ||
|
||
###################################### | ||
# evaluate trained flow | ||
###################################### | ||
plot(losses; label="Loss", linewidth=2) # plot the loss | ||
compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.