Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
8db828f
make realnvp and nsf layers as part of the pkg
zuhengxu Jul 5, 2025
574e257
import Functors
zuhengxu Jul 5, 2025
cd63565
add fully connect nn constructor
zuhengxu Jul 13, 2025
e03213c
update realnvp default constructor
zuhengxu Jul 13, 2025
34a964e
minor typo fix
zuhengxu Jul 13, 2025
1c4118b
minor update of realnnvp constructor and add some doc
zuhengxu Jul 14, 2025
d6b86cb
fixing bug in Distribution types
zuhengxu Jul 14, 2025
aa2adeb
exclude nsf for now
zuhengxu Jul 14, 2025
00ca29c
minor ed in nsf
zuhengxu Jul 14, 2025
731e657
fix typo in realnvp
zuhengxu Jul 14, 2025
e4fa67b
add realnvp test
zuhengxu Jul 14, 2025
1dfebd1
export nsf layer
zuhengxu Jul 14, 2025
55fb607
update demos, debugging mooncake with elbo
zuhengxu Jul 14, 2025
a2f6fbe
add AD tests for realnvp elbo
zuhengxu Jul 14, 2025
e39b8a8
wip debug mooncake on coupling layers
zuhengxu Jul 23, 2025
84ce45f
found that bug revealed by mooncake 0.4.124
zuhengxu Jul 25, 2025
9ba0a3b
add compat mooncake v0.4.142, fixed the autograd error on nested struct
zuhengxu Jul 31, 2025
81000b9
add mooncake compat >= v0.4.142
zuhengxu Aug 2, 2025
4caae49
add nsf interface
zuhengxu Aug 3, 2025
cf2e674
fix a typo in elbo_batch signiture
zuhengxu Aug 3, 2025
7f9c382
rm redundant comments
zuhengxu Aug 3, 2025
9f0cbad
making target adapting to the chosen Floating type automatically
zuhengxu Aug 3, 2025
99a0fed
rm redundant flux dependencies
zuhengxu Aug 3, 2025
c4128fa
add new nsf implementation and demo; much faster than the original nsf
zuhengxu Aug 3, 2025
1f30b33
rm redundant flux from realnvp demo
zuhengxu Aug 4, 2025
2903b83
dump the previous nsf implementation
zuhengxu Aug 4, 2025
48bc3d3
add test for nsf
zuhengxu Aug 4, 2025
9494de1
add ad test for nsf
zuhengxu Aug 4, 2025
8f61fc9
fix typo in nsf test
zuhengxu Aug 4, 2025
977caaf
fix nsf test error regarding rand()
zuhengxu Aug 4, 2025
0b9e656
relax rtol for nsf invertibility error in FLoat32
zuhengxu Aug 4, 2025
48829ad
update doc
zuhengxu Aug 8, 2025
4e6bfbe
wip doc build erro
zuhengxu Aug 8, 2025
eb19664
updating docs
zuhengxu Aug 8, 2025
ae53100
update gha perms
zuhengxu Aug 8, 2025
b8b229b
minor ed on docs
zuhengxu Aug 8, 2025
2d5b9c5
update docs
zuhengxu Aug 9, 2025
8a04147
update readme
zuhengxu Aug 9, 2025
2431e57
incorpoerate comments from red-portal and sunxd3
zuhengxu Aug 20, 2025
4dec51a
fix test error
zuhengxu Aug 20, 2025
fb6b80b
add planar and radial flow; updating docs
zuhengxu Aug 20, 2025
862c6dc
fixed error in doc for creat_flow
zuhengxu Aug 20, 2025
dc73605
rm redundant comments
zuhengxu Aug 20, 2025
5753150
change Any[] to Flux.Dense[]
zuhengxu Aug 20, 2025
d9525e7
minor comment update
zuhengxu Aug 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
name = "NormalizingFlows"
uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
version = "0.2.1"
version = "0.2.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Expand All @@ -27,7 +29,9 @@ CUDA = "5"
DifferentiationInterface = "0.6, 0.7"
Distributions = "0.25"
DocStringExtensions = "0.9"
Flux = "0.16"
Functors = "0.5.2"
Optimisers = "0.2.16, 0.3, 0.4"
ProgressMeter = "1.0.0"
StatsBase = "0.33, 0.34"
julia = "1.10"
julia = "1.11"
6 changes: 6 additions & 0 deletions example/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
Expand All @@ -17,6 +19,10 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"

[compat]
Mooncake = "0.4.142"
127 changes: 11 additions & 116 deletions example/demo_RealNVP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,114 +11,6 @@ using NormalizingFlows
include("SyntheticTargets.jl")
include("utils.jl")

##################################
# define affine coupling layer using Bijectors.jl interface
#################################
"""
    AffineCoupling(dim, mask, s, t)

Affine coupling layer implemented against the Bijectors.jl `Bijector` interface.

# Fields
- `dim`: dimension of the input.
- `mask`: `Bijectors.PartitionMask` used to split the input into blocks.
- `s`: `Flux.Chain` whose output is used as the elementwise scale.
- `t`: `Flux.Chain` whose output is used as the elementwise shift.

`PartitionMask` and `Flux.Chain` are parametric types; the fields are typed through
struct parameters so that every field of a concrete `AffineCoupling` is concretely typed.
"""
struct AffineCoupling{M<:Bijectors.PartitionMask,S<:Flux.Chain,R<:Flux.Chain} <: Bijectors.Bijector
    dim::Int
    mask::M
    s::S
    t::R
end

# let params track field s and t
@functor AffineCoupling (s, t)

"""
    AffineCoupling(dim::Int, hdims::Int, mask_idx::AbstractVector)

Build an `AffineCoupling` layer whose scale network `s` and shift network `t` are
created with `mlp3`.

# Arguments
- `dim`: dimension of the input.
- `hdims`: dimension of the hidden units of `s` and `t`.
- `mask_idx`: indices of the dimensions the transformation is applied to.

`s` and `t` are evaluated on the block of the input that is *not* transformed, so they
take `dim - length(mask_idx)` inputs and emit one value per transformed dimension.
"""
function AffineCoupling(
    dim::Int, # dimension of input
    hdims::Int, # dimension of hidden units for s and t
    mask_idx::AbstractVector, # indices of the dimensions to apply the transformation to
)
    cdims = length(mask_idx) # number of transformed dimensions (output size of s and t)
    # s and t receive the complementary block, which has `dim - cdims` entries
    # (assumes `PartitionMask(dim, mask_idx)` puts every other index there — TODO confirm)
    s = mlp3(dim - cdims, hdims, cdims)
    t = mlp3(dim - cdims, hdims, cdims)
    mask = PartitionMask(dim, mask_idx)
    return AffineCoupling(dim, mask, s, t)
end

"""
    Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)

Apply the affine coupling map to `x`: the first block of the `af.mask` partition is
multiplied by `af.s` and shifted by `af.t`, both evaluated on the second block. The
second and third blocks are handed back to `combine` unchanged.
"""
function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
    # split `x` into blocks using `af.mask::PartitionMask`
    head, conditioner, rest = partition(af.mask, x)
    scale = af.s(conditioner)
    shift = af.t(conditioner)
    return combine(af.mask, head .* scale .+ shift, conditioner, rest)
end

# Calling an `AffineCoupling` on an array forwards to `transform`.
(af::AffineCoupling)(x::AbstractArray) = transform(af, x)

"""
    Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)

Return `(y, logjac)` for a single input vector `x`: `y` is `x` with its first partition
block replaced by `af.s(x_2) .* x_1 .+ af.t(x_2)`, and `logjac` is the scalar
`sum(log ∘ abs, af.s(x_2))`.
"""
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
    x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
    # evaluate the scale network once; it feeds both the output and the log-Jacobian
    scale = af.s(x_2)
    y_1 = scale .* x_1 .+ af.t(x_2)
    logjac = sum(log ∘ abs, scale) # this is a scalar
    return combine(af.mask, y_1, x_2, x_3), logjac
end

"""
    Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)

Matrix version of the forward pass. Returns `(y, logjac)` where `logjac` is a vector
with one entry per column of `x`, obtained by summing `log(abs(s))` over `dims = 1` of
the scale `s = af.s(x_2)`.
"""
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
    x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
    # evaluate the scale network once; it feeds both the output and the log-Jacobian
    scale = af.s(x_2)
    y_1 = scale .* x_1 .+ af.t(x_2)
    logjac = sum(log ∘ abs, scale; dims=1) # 1 × size(x, 2)
    return combine(af.mask, y_1, x_2, x_3), vec(logjac)
end


"""
    Bijectors.with_logabsdet_jacobian(iaf::Inverse{<:AffineCoupling}, y::AbstractVector)

Inverse pass for a single vector `y`. Returns `(x, logjac)` where the first partition
block of `x` is `(y_1 .- af.t(y_2)) ./ af.s(y_2)` and `logjac` is the scalar
`-sum(log ∘ abs, af.s(y_2))`.
"""
function Bijectors.with_logabsdet_jacobian(
    iaf::Inverse{<:AffineCoupling}, y::AbstractVector
)
    af = iaf.orig
    # partition vector using `af.mask::PartitionMask`
    y_1, y_2, y_3 = partition(af.mask, y)
    shift = af.t(y_2)
    # evaluate the scale network once; it feeds both the inverse map and the log-Jacobian
    scale = af.s(y_2)
    # inverse transformation
    x_1 = (y_1 .- shift) ./ scale
    logjac = -sum(log ∘ abs, scale)
    return combine(af.mask, x_1, y_2, y_3), logjac
end

"""
    Bijectors.with_logabsdet_jacobian(iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix)

Matrix version of the inverse pass. Returns `(x, logjac)` where `logjac` is a vector
with one entry per column of `y`, obtained by summing `log(abs(s))` over `dims = 1` of
the scale `s = af.s(y_2)` and negating.
"""
function Bijectors.with_logabsdet_jacobian(
    iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix
)
    af = iaf.orig
    # partition matrix using `af.mask::PartitionMask`
    y_1, y_2, y_3 = partition(af.mask, y)
    shift = af.t(y_2)
    # evaluate the scale network once; it feeds both the inverse map and the log-Jacobian
    scale = af.s(y_2)
    # inverse transformation
    x_1 = (y_1 .- shift) ./ scale
    logjac = -sum(log ∘ abs, scale; dims=1)
    return combine(af.mask, x_1, y_2, y_3), vec(logjac)
end

###################
# an equivalent definition of AffineCoupling using Bijectors.Coupling
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
###################

# struct AffineCoupling <: Bijectors.Bijector
# dim::Int
# mask::Bijectors.PartitionMask
# s::Flux.Chain
# t::Flux.Chain
# end

# # let params track field s and t
# @functor AffineCoupling (s, t)

# function AffineCoupling(dim, mask, s, t)
# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
# end

# function AffineCoupling(
# dim::Int, # dimension of input
# hdims::Int, # dimension of hidden units for s and t
#     mask_idx::AbstractVector, # indices of the dimensions that one wants to apply transformations on
# )
# cdims = length(mask_idx) # dimension of parts used to construct coupling law
# s = mlp3(cdims, hdims, cdims)
# t = mlp3(cdims, hdims, cdims)
# mask = PartitionMask(dim, mask_idx)
# return AffineCoupling(dim, mask, s, t)
# end



##################################
# start demo
#################################
Expand All @@ -132,30 +24,33 @@ T = Float32
target = Banana(2, 1.0f0, 100.0f0)
logp = Base.Fix1(logpdf, target)


######################################
# learn the target using Affine coupling flow
######################################
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), ones(T, 2))
q0 = MvNormal(zeros(T, 2), I)

d = 2
hdims = 32

# alternating the coupling layers
Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3]
hdims = [16, 16]
nlayers = 3

flow = create_flow(Ls, q0)
# use NormalizingFlows.realnvp to create a RealNVP flow
flow = realnvp(q0, hdims, nlayers; paramtype=T)
flow_untrained = deepcopy(flow)


######################################
# start training
######################################
sample_per_iter = 64
sample_per_iter = 16

# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
# TODO: now using AutoMooncake the example broke, but AutoZygote works, need to debug
adtype = ADTypes.AutoMooncake(; config = nothing)
# adtype = ADTypes.AutoZygote()

checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
rng,
Expand Down
99 changes: 1 addition & 98 deletions example/demo_neural_spline_flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,104 +11,6 @@ using NormalizingFlows
include("SyntheticTargets.jl")
include("utils.jl")

##################################
# define neural spline layer using Bijectors.jl interface
#################################
"""
    NeuralSplineLayer

Neural Rational quadratic Spline layer

`Flux.Chain` and `Bijectors.PartitionMask` are parametric types; `nn` and `mask` are
typed through struct parameters so that every field of a concrete layer is concretely
typed.

# References
[1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
"""
struct NeuralSplineLayer{T,A<:Flux.Chain,M<:Bijectors.PartitionMask} <: Bijectors.Bijector
    dim::Int                    # dimension of input
    K::Int                      # number of knots
    n_dims_transferred::Int     # number of dimensions that are transformed
    nn::A                       # network that parameterizes the knots and derivatives
    B::T                        # bound of the knots
    mask::M                     # mask used to partition the input
end

"""
    NeuralSplineLayer(dim, hdims, K, B, mask_idx)

Construct a `NeuralSplineLayer` for inputs of dimension `dim` that transforms the
dimensions listed in `mask_idx`.

A single `mlp3` network with hidden size `hdims` is created. It takes the
`dim - length(mask_idx)` remaining dimensions as input and outputs `3K - 1` values per
transformed dimension, where `K` is the number of knots and `B` the bound of the knots.
"""
function NeuralSplineLayer(
    dim::T1, # dimension of input
    hdims::T1, # dimension of hidden units of the network
    K::T1, # number of knots
    B::T2, # bound of the knots
    mask_idx::AbstractVector{<:Int}, # indices of the dimensions to transform
) where {T1<:Int,T2<:Real}
    n_transformed = length(mask_idx)

    # one big mlp emitting all knots and derivatives for every transformed dimension:
    # input size is the number of untouched dimensions, output size is (3K - 1) per
    # transformed dimension
    nn = mlp3(dim - n_transformed, hdims, (3K - 1) * n_transformed)

    return NeuralSplineLayer(
        dim, K, n_transformed, nn, B, Bijectors.PartitionMask(dim, mask_idx)
    )
end

@functor NeuralSplineLayer (nn,)

# define forward and inverse transformation
"""
    instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)

Build a `Bijectors.RationalQuadraticSpline` from the output of `nsl.nn` evaluated at `x`.

Bijectors.jl already implements the inverse and logabsdetjac of the rational quadratic
spline, so the only work here is mapping the network output to the knots and
derivatives of the spline.
"""
function instantiate_rqs(nsl::NeuralSplineLayer, x::AbstractVector)
    K = nsl.K
    B = nsl.B
    # one row per transformed dimension, 3K - 1 columns of spline parameters
    params = reshape(nsl.nn(x), nsl.n_dims_transferred, :)
    # column slices: first K, next K, and the final K - 1 columns
    first_cols = view(params, :, 1:K)
    middle_cols = view(params, :, (K + 1):(2K))
    last_cols = view(params, :, (2K + 1):(3K - 1))
    return Bijectors.RationalQuadraticSpline(first_cols, middle_cols, last_cols, B)
end

"""
    Bijectors.transform(nsl::NeuralSplineLayer, x::AbstractVector)

Forward pass of the layer: the first block of the `nsl.mask` partition is pushed through
the spline built by `instantiate_rqs` from the second block; the remaining blocks are
recombined unchanged.
"""
function Bijectors.transform(nsl::NeuralSplineLayer, x::AbstractVector)
    head, conditioner, rest = Bijectors.partition(nsl.mask, x)
    # the spline's knots and derivatives are produced from the conditioning block
    spline = instantiate_rqs(nsl, conditioner)
    return Bijectors.combine(
        nsl.mask, Bijectors.transform(spline, head), conditioner, rest
    )
end

"""
    Bijectors.transform(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)

Inverse pass of the layer: rebuilds the spline from the second block of the partition of
`y` and applies `Inverse(rqs)` to the first block.
"""
function Bijectors.transform(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
    layer = insl.orig
    head, conditioner, rest = partition(layer.mask, y)
    spline = instantiate_rqs(layer, conditioner)
    recovered = Bijectors.transform(Inverse(spline), head)
    return Bijectors.combine(layer.mask, recovered, conditioner, rest)
end

# Calling a `NeuralSplineLayer` on a vector forwards to `Bijectors.transform`.
(nsl::NeuralSplineLayer)(x::AbstractVector) = Bijectors.transform(nsl, x)

# define logabsdetjac
"""
    Bijectors.logabsdetjac(nsl::NeuralSplineLayer, x::AbstractVector)

Return `logabsdetjac` of the spline built from the second partition block of `x`,
evaluated at the first partition block.
"""
function Bijectors.logabsdetjac(nsl::NeuralSplineLayer, x::AbstractVector)
    head, conditioner, _ = Bijectors.partition(nsl.mask, x)
    spline = instantiate_rqs(nsl, conditioner)
    return logabsdetjac(spline, head)
end

"""
    Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)

Return `logabsdetjac` of the inverted spline, built from the second partition block of
`y` and evaluated at the first partition block.
"""
function Bijectors.logabsdetjac(insl::Inverse{<:NeuralSplineLayer}, y::AbstractVector)
    layer = insl.orig
    head, conditioner, _ = partition(layer.mask, y)
    spline = instantiate_rqs(layer, conditioner)
    return logabsdetjac(Inverse(spline), head)
end

"""
    Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector)

Return `(y, logjac)`: the spline built from the second partition block of `x` is applied
to the first block via `with_logabsdet_jacobian`, and the result is recombined with the
untouched blocks.
"""
function Bijectors.with_logabsdet_jacobian(nsl::NeuralSplineLayer, x::AbstractVector)
    head, conditioner, rest = Bijectors.partition(nsl.mask, x)
    spline = instantiate_rqs(nsl, conditioner)
    mapped, logjac = with_logabsdet_jacobian(spline, head)
    y = Bijectors.combine(nsl.mask, mapped, conditioner, rest)
    return y, logjac
end

##################################
# start demo
#################################
Expand Down Expand Up @@ -148,6 +50,7 @@ sample_per_iter = 64

# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
# TODO: now using AutoMooncake the example broke, but AutoZygote works, need to debug
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
Expand Down
5 changes: 0 additions & 5 deletions example/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ function mlp3(input_dim::Int, hidden_dims::Int, output_dim::Int; activation=Flux
)
end

"""
    create_flow(Ls, q₀)

Compose the layers in `Ls` with `∘` (via `reduce`) and return `transformed(q₀, ts)`,
where `ts` is the resulting composition.
"""
function create_flow(Ls, q₀)
    return transformed(q₀, reduce(∘, Ls))
end

function compare_trained_and_untrained_flow(
flow_trained::Bijectors.MultivariateTransformed,
flow_untrained::Bijectors.MultivariateTransformed,
Expand Down
15 changes: 14 additions & 1 deletion src/NormalizingFlows.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
module NormalizingFlows

using ADTypes
using Bijectors
using Distributions
using LinearAlgebra
using Optimisers
using ProgressMeter
using Random
using StatsBase
using Bijectors
using Bijectors: PartitionMask, Inverse, combine, partition
using Functors
import DifferentiationInterface as DI

using DocStringExtensions
Expand Down Expand Up @@ -123,4 +125,15 @@ function _device_specific_rand(
return Random.rand(rng, td, n)
end


# interface for constructing common flow layers
include("flows/utils.jl")
include("flows/realnvp.jl")
include("flows/neuralspline.jl")

export create_flow
export RealNVP_layer, realnvp, AffineCoupling
export NeuralSplineLayer


end
Loading
Loading