Skip to content

Commit 960a359

Browse files
committed
keep cleaning
1 parent 6056ef1 commit 960a359

File tree

5 files changed

+72
-96
lines changed

5 files changed

+72
-96
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
/docs/build/
44
test/Manifest.toml
55
example/Manifest.toml
6+
example/LocalPreferences.toml
67

78
# Files generated by invoking Julia with --code-coverage
89
*.jl.cov

example/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
2121
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2222
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
2323
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
24+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2425
NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
2526
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
2627
PlotlyJS = "f0f68f2c-4968-5e81-91da-67840de0976a"
@@ -36,3 +37,6 @@ TickTock = "9ff05d80-102d-5586-aa04-3a8bd1a90d20"
3637
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
3738
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3839
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
40+
41+
[extras]
42+
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"

example/common.jl

Lines changed: 56 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -47,38 +47,38 @@ function compare_trained_and_untrained_flow(
4747
return p
4848
end
4949

50-
function check_trained_flow(
51-
flow_trained::Bijectors.MultivariateTransformed,
52-
true_dist::ContinuousMultivariateDistribution,
53-
n_samples::Int;
54-
kwargs...,
55-
)
56-
samples_trained = rand_batch(flow_trained, n_samples)
57-
samples_true = rand(true_dist, n_samples)
58-
59-
p = Plots.scatter(
60-
samples_true[1, :],
61-
samples_true[2, :];
62-
label="True Distribution",
63-
color=:green,
64-
markersize=2,
65-
alpha=0.5,
66-
)
67-
Plots.scatter!(
68-
p,
69-
samples_trained[1, :],
70-
samples_trained[2, :];
71-
label="Trained Flow",
72-
color=:red,
73-
markersize=2,
74-
alpha=0.5,
75-
)
76-
Plots.plot!(; kwargs...)
77-
78-
Plots.title!(p, "Trained HamFlow")
79-
80-
return p
81-
end
50+
# function check_trained_flow(
51+
# flow_trained::Bijectors.MultivariateTransformed,
52+
# true_dist::ContinuousMultivariateDistribution,
53+
# n_samples::Int;
54+
# kwargs...,
55+
# )
56+
# samples_trained = rand_batch(flow_trained, n_samples)
57+
# samples_true = rand(true_dist, n_samples)
58+
59+
# p = Plots.scatter(
60+
# samples_true[1, :],
61+
# samples_true[2, :];
62+
# label="True Distribution",
63+
# color=:green,
64+
# markersize=2,
65+
# alpha=0.5,
66+
# )
67+
# Plots.scatter!(
68+
# p,
69+
# samples_trained[1, :],
70+
# samples_trained[2, :];
71+
# label="Trained Flow",
72+
# color=:red,
73+
# markersize=2,
74+
# alpha=0.5,
75+
# )
76+
# Plots.plot!(; kwargs...)
77+
78+
# Plots.title!(p, "Trained HamFlow")
79+
80+
# return p
81+
# end
8282

8383
function create_flow(Ls, q₀)
8484
ts = fchain(Ls)
@@ -89,33 +89,33 @@ end
8989
# training function for InvertibleNetworks
9090
########################
9191

92-
function pm_next!(pm, stats::NamedTuple)
93-
return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
94-
end
92+
# function pm_next!(pm, stats::NamedTuple)
93+
# return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
94+
# end
9595

96-
function train_invertible_networks!(G, loss, data_loader, n_epoch, opt)
97-
max_iters = n_epoch * length(data_loader)
98-
prog = ProgressMeter.Progress(
99-
max_iters; desc="Training", barlen=31, showspeed=true, enabled=true
100-
)
96+
# function train_invertible_networks!(G, loss, data_loader, n_epoch, opt)
97+
# max_iters = n_epoch * length(data_loader)
98+
# prog = ProgressMeter.Progress(
99+
# max_iters; desc="Training", barlen=31, showspeed=true, enabled=true
100+
# )
101101

102-
nnls = []
102+
# nnls = []
103103

104-
# training loop
105-
time_elapsed = @elapsed for (i, xs) in enumerate(IterTools.ncycle(data_loader, n_epoch))
106-
ls = loss(G, xs) #sets gradients of G
104+
# # training loop
105+
# time_elapsed = @elapsed for (i, xs) in enumerate(IterTools.ncycle(data_loader, n_epoch))
106+
# ls = loss(G, xs) #sets gradients of G
107107

108-
push!(nnls, ls)
108+
# push!(nnls, ls)
109109

110-
grad_norm = 0
111-
for p in get_params(G)
112-
grad_norm += sum(abs2, p.grad)
113-
Flux.update!(opt, p.data, p.grad)
114-
end
115-
grad_norm = sqrt(grad_norm)
110+
# grad_norm = 0
111+
# for p in get_params(G)
112+
# grad_norm += sum(abs2, p.grad)
113+
# Flux.update!(opt, p.data, p.grad)
114+
# end
115+
# grad_norm = sqrt(grad_norm)
116116

117-
stat = (iteration=i, neg_log_llh=ls, gradient_norm=grad_norm)
118-
pm_next!(prog, stat)
119-
end
120-
return nnls
121-
end
117+
# stat = (iteration=i, neg_log_llh=ls, gradient_norm=grad_norm)
118+
# pm_next!(prog, stat)
119+
# end
120+
# return nnls
121+
# end

example/planar_radial_flow/planar_flow_main.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
using Random, Distributions, LinearAlgebra, Bijectors
22
using ADTypes
33
using Optimisers
4-
using FunctionChains
54
using NormalizingFlows
6-
using Zygote
5+
using Mooncake
6+
using CUDA
77
using Flux: f32
8+
using Flux
89
using Plots
9-
include("../common.jl")
10+
include("common.jl")
1011

1112
Random.seed!(123)
1213
rng = Random.default_rng()
14+
T = Float32
1315

1416
######################################
1517
# 2d Banana as the target distribution
1618
######################################
17-
include("../targets/banana.jl")
19+
include("targets/banana.jl")
1820

1921
# create target p
2022
p = Banana(2, 1.0f-1, 100.0f0)
@@ -26,13 +28,15 @@ logp = Base.Fix1(logpdf, p)
2628
function create_planar_flow(n_layers::Int, q₀)
2729
d = length(q₀)
2830
Ls = [f32(PlanarLayer(d)) for _ in 1:n_layers]
29-
ts = fchain(Ls)
31+
ts = reduce(, Ls)
3032
return transformed(q₀, ts)
3133
end
3234

3335
# create a 10-layer planar flow
34-
flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I))
35-
flow_untrained = deepcopy(flow)
36+
q0 = MvNormal(zeros(T, 2), ones(T, 2))
37+
flow = create_planar_flow(10, q0)
38+
39+
3640

3741
# train the flow
3842
sample_per_iter = 10

example/targets/banana.jl

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,6 @@ using Distributions, Random
22
using Plots
33
using IrrationalConstants
44

5-
"""
6-
Banana{T<:Real}
7-
8-
Multidimensional banana-shape distribution.
9-
10-
# Fields
11-
$(FIELDS)
12-
13-
# Explanation
14-
15-
The banana distribution is obtained by applying a transformation ϕ to a multivariate normal
16-
distribution ``\\mathcal{N}(0, \\text{diag}(var, 1, 1, …, 1))``. The transformation ϕ is defined as
17-
```math
18-
\phi(x_1, … , x_p) = (x_1, x_2 - B x_1^² + \text{var}*B, x_3, … , x_p)
19-
````
20-
which has a unit Jacobian determinant.
21-
22-
Hence the density "fb" of a p-dimensional banana distribution is given by
23-
```math
24-
fb(x_1, \dots, x_p) = \exp\left[ -\frac{1}{2}\frac{x_1^2}{\text{var}} -
25-
\frac{1}{2}(x_2 + B x_1^2 - \text{var}*B)^2 - \frac{1}{2}(x_3^2 + x_4^2 + \dots
26-
+ x_p^2) \right] / Z,
27-
```
28-
where "B" is the "banananicity" constant, determining the curvature of a banana, and
29-
``Z = \\sqrt{\\text{var} * (2\\pi)^p)}`` is the normalization constant.
30-
31-
32-
# Reference
33-
34-
Gareth O. Roberts and Jeffrey S. Rosenthal
35-
"Examples of Adaptive MCMC."
36-
Journal of computational and graphical statistics, Volume 18, Number 2 (2009): 349-367.
37-
"""
385
struct Banana{T<:Real} <: ContinuousMultivariateDistribution
396
"Dimension of the distribution, must be >= 2"
407
dim::Int # Dimension

0 commit comments

Comments
 (0)