Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
139 commits
Select commit Hold shift + click to select a range
4f591af
add support for hasconverged
zuhengxu Jul 11, 2023
4e31807
fix test error
zuhengxu Jul 11, 2023
99ac0c1
rm example/Manifest.toml
zuhengxu Jul 11, 2023
e4f5efa
minor bug fix for trainig loop
zuhengxu Jul 11, 2023
b345e78
test new stopping criterion
zuhengxu Jul 11, 2023
f10512f
test convergent condition/ rm unready examples
zuhengxu Jul 11, 2023
b93dbe8
Merge branch 'TuringLang:main' into hasconverge
zuhengxu Jul 11, 2023
1c1c88a
rm julia test from CI
zuhengxu Jul 11, 2023
5d2844f
Revert "rm julia test from CI"
zuhengxu Jul 11, 2023
3fdcb0e
make autodiff pkgs as extension + require for bwd compat
zuhengxu Jul 12, 2023
ddad59e
debugging Ext
zuhengxu Jul 12, 2023
ef60ee1
keep debugging
zuhengxu Jul 13, 2023
5a5deb0
Fix AD package extension loading issues
sunxd3 Jul 13, 2023
25d4211
Applying @devmotion's comment
sunxd3 Jul 13, 2023
df5eddd
patch last commit
sunxd3 Jul 13, 2023
d44143b
patch for julia 1.6
sunxd3 Jul 13, 2023
b03e922
loading dep pkgs from main pkg instead of functions for explicitness
zuhengxu Jul 13, 2023
e9acf70
fixing test err
zuhengxu Jul 13, 2023
c6bf68b
rm unready examples
zuhengxu Jul 13, 2023
fabf20a
update realnvp
zuhengxu Jul 24, 2023
7964f91
minor ed
zuhengxu Jul 24, 2023
01de9d4
removing unnecessary import
zuhengxu Jul 24, 2023
6993d33
refactor affinecoupling and example/
zuhengxu Jul 31, 2023
810881f
debug affinecoupling flow
zuhengxu Jul 31, 2023
0277c9e
adapt to the updated autoforwarddiff to resolve test err
zuhengxu Jul 31, 2023
4bc02bf
fix test err
zuhengxu Jul 31, 2023
144c668
add new implementation of affcoupling using Bijectors.Coupling
zuhengxu Jul 31, 2023
2678198
implement ham flow
zuhengxu Aug 1, 2023
ab7ac64
finish hamflow implementation
zuhengxu Aug 1, 2023
308ab61
minor update
zuhengxu Aug 1, 2023
ceafbde
rename hamflow.jl to hamiltonian_layer.jl
zuhengxu Aug 3, 2023
5589f36
upadting readme
zuhengxu Aug 3, 2023
de01e0e
rm hamflow.jl
zuhengxu Aug 3, 2023
93a8572
Merge branch 'main' of github.com:zuhengxu/NormalizingFlows.jl into m…
zuhengxu Aug 8, 2023
b8a8f5f
sync with main
zuhengxu Aug 9, 2023
c858b10
fix minor bugs in affine coupling layer
zuhengxu Aug 16, 2023
baef6a9
test affine coupling flow on banana
zuhengxu Aug 16, 2023
0323680
rename simple flow run files
zuhengxu Aug 16, 2023
144b06a
update loglikelihood to fit in optimize interface
zuhengxu Aug 16, 2023
f70a4b7
fix minor bugs in nsf_layer
zuhengxu Aug 17, 2023
30cfc32
rm unused data in nsf lfow
zuhengxu Aug 17, 2023
00387d3
rm @view to avoid zygote mutation error
zuhengxu Aug 17, 2023
2cacfc6
update optimize for direct minibatch training
zuhengxu Sep 28, 2023
24b1564
add BatchNorm for real nvp and make real nvp compatible with batching
zuhengxu Sep 28, 2023
f720eab
testing
zuhengxu Sep 29, 2023
880976c
update resblock
zuhengxu Oct 4, 2023
f6e4189
update resnet arch
zuhengxu Oct 5, 2023
0a78963
minor update
zuhengxu Oct 5, 2023
7c953c2
Merge branch 'stability' of github.com:zuhengxu/NormalizingFlows.jl i…
zuhengxu Oct 5, 2023
e41c3f6
udpate
zuhengxu Oct 5, 2023
91cf585
fix affine bugs
zuhengxu Oct 5, 2023
c523f94
add invertiblemlp
zuhengxu Oct 5, 2023
1f67b10
start testing invertiblenetworks
zuhengxu Oct 5, 2023
6cde9cc
minor update
zuhengxu Oct 5, 2023
eada0e6
minor bug fix
zuhengxu Oct 5, 2023
8255949
update MLP to make it work
zuhengxu Oct 9, 2023
37bd3ab
fix type instability in invertible MLP
zuhengxu Oct 9, 2023
a67a6e2
update deep mlp
zuhengxu Oct 10, 2023
d6c6a2b
update deep hamflow result
zuhengxu Oct 10, 2023
799343a
update flows
zuhengxu Oct 10, 2023
b908a80
add shadowing window computation
zuhengxu Oct 11, 2023
8bde6bf
minor ed
zuhengxu Oct 11, 2023
eaf1775
refactoring training and setup files
zuhengxu Oct 11, 2023
519fd66
minor setup change to allow convenient precision switch
zuhengxu Oct 11, 2023
77816fd
get MLP figs and res'
zuhengxu Oct 12, 2023
8db13c9
obtain figs for ham
zuhengxu Oct 12, 2023
4a656e8
rm duplication
zuhengxu Oct 12, 2023
9f68fdd
add some MLP result
zuhengxu Oct 12, 2023
8c703d5
add shf res
zuhengxu Oct 12, 2023
9e27814
add mom normalization layer implementation
zuhengxu Oct 13, 2023
c63abfc
update shf res
zuhengxu Oct 13, 2023
61862ff
add script for deep shf
zuhengxu Oct 13, 2023
4fe363f
update shf_big res
zuhengxu Oct 13, 2023
a9d46bf
minor update
zuhengxu Oct 13, 2023
8bd27ee
update log reg and new shf res with q0 being std normal
zuhengxu Oct 16, 2023
156faeb
minor bug update
zuhengxu Oct 16, 2023
7e7e782
update sfh banana figs
zuhengxu Oct 17, 2023
8be77ad
add some lip constant figs for shf banana
zuhengxu Oct 17, 2023
0381259
add some lip constant reg
zuhengxu Oct 17, 2023
37835c8
update shf banana/cross
zuhengxu Oct 18, 2023
2f15b1b
update shf/stab
zuhengxu Oct 18, 2023
2ae4f27
update banana plots
zuhengxu Oct 18, 2023
3e453e5
update banana res
zuhengxu Oct 18, 2023
ed088eb
fix conflict
zuhengxu Oct 18, 2023
7db8236
update cross res
zuhengxu Oct 18, 2023
167af8e
update cross res
zuhengxu Oct 18, 2023
d491fee
update
zuhengxu Oct 18, 2023
1aaad46
update cross
zuhengxu Oct 18, 2023
fa80afb
update
zuhengxu Oct 18, 2023
3ffe162
Merge branch 'stability' of github.com:zuhengxu/NormalizingFlows.jl i…
zuhengxu Oct 18, 2023
e25f31b
update cross elbo res
zuhengxu Oct 18, 2023
4e78a5a
update shadowing res
zuhengxu Oct 18, 2023
7704896
udpate banana and cross res
zuhengxu Oct 18, 2023
b4b4aaf
update banana figs
zuhengxu Oct 18, 2023
9d4f6ba
update cross res
zuhengxu Oct 19, 2023
b7d7b67
update cross figs and res
zuhengxu Oct 19, 2023
3b148d7
update cross lp scaling figs
zuhengxu Oct 20, 2023
ae35d63
minor update
zuhengxu Oct 20, 2023
ae9b9f2
new cross figs
zuhengxu Oct 20, 2023
deaf04a
fix conflict
zuhengxu Oct 20, 2023
2dd855a
rm readme
zuhengxu Jan 9, 2025
b30d961
update
zuhengxu Mar 16, 2025
aeda365
Merge branch 'stability' into more_examples
zuhengxu Mar 16, 2025
2c4d793
rm many unrelated code
zuhengxu Mar 16, 2025
6056ef1
merge conflict from upstream
zuhengxu Mar 16, 2025
960a359
keep cleaning
zuhengxu Mar 17, 2025
51bc5b9
merge from turing/main
zuhengxu Apr 9, 2025
f5a2f9a
fix warpgaussian logpdf error/making neal funnel logpdf working with …
zuhengxu Apr 9, 2025
d916b64
add easier model loading code
zuhengxu Apr 9, 2025
d84fce4
restructure example folder and refactor planar and radial examples
zuhengxu Apr 10, 2025
6872760
clean demos for realnvp/planar/radial/ fix a bug in nsf
zuhengxu Apr 10, 2025
e3fb428
rm enzyme dependency
zuhengxu Apr 10, 2025
ddf0ba9
tune nsf flow--enlarge B--to make it work
zuhengxu Apr 10, 2025
f2affe2
rm some useless files from HamVI
zuhengxu Apr 10, 2025
78d5f45
rename MLP_3layer to mlp3 for convenience
zuhengxu Apr 10, 2025
2229577
rename common.jl to utils
zuhengxu Apr 10, 2025
aefc651
minor update file naming, and better nsf implementation
zuhengxu Apr 11, 2025
7f1580c
minor ed
zuhengxu Apr 11, 2025
750c1e1
rm redundant nsf file
zuhengxu Apr 11, 2025
ab48d11
shrink stepsize for nsf
zuhengxu Apr 11, 2025
42bc4a1
cleaned hamiltonian flow
zuhengxu Apr 12, 2025
edf1bac
minor ed
zuhengxu Apr 12, 2025
c05789b
rm redundant pkgs
zuhengxu Apr 12, 2025
228d0e8
MLutils and itertools from dependencies
zuhengxu May 26, 2025
6b1500c
rm useless example utils functions
zuhengxu May 26, 2025
e257fab
add elbo_batch implementation; much faster for invertible NN based flows
zuhengxu May 26, 2025
17c8c5d
rm simple unpack from ham flows
zuhengxu May 26, 2025
c57becd
rm simple unpack dependency from example env
zuhengxu May 26, 2025
738c2dd
bump version
zuhengxu May 26, 2025
0672626
add test for elbo_batch
zuhengxu May 26, 2025
427225e
use elbo_batch for real_nvp; achieved 4+ times speed up
zuhengxu May 26, 2025
c77dc2d
rm load_model and use constructor directly
zuhengxu May 26, 2025
a18bb12
fix doc building error
zuhengxu May 26, 2025
2bd914d
update DI bounds
zuhengxu May 26, 2025
cc3c761
fix cuda test error
zuhengxu May 26, 2025
06e7651
reduce the iter number
zuhengxu May 26, 2025
5508a9f
add CI group for exampels
zuhengxu May 27, 2025
4bf8fe1
update examples CI
zuhengxu May 27, 2025
5f2a8da
fixing CI bugs
zuhengxu May 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions .github/workflows/Examples.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
name: NF Examples

on:
push:
branches:
- main
tags: ['*']
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
run-examples:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: '1'
arch: x64
- uses: julia-actions/cache@v2
- name: Run NF examples
run: |
cd example
julia --project=. --color=yes -e '
using Pkg;
Pkg.develop(PackageSpec(path=joinpath(pwd(), "..")));
Pkg.instantiate();
@info "Running planar flow demo";
include("demo_planar_flow.jl");
@info "Running radial flow demo";
include("demo_radial_flow.jl");
@info "Running Real NVP demo";
include("demo_RealNVP.jl");
@info "Running neural spline flow demo";
include("demo_neural_spline_flow.jl");
@info "Running Hamiltonian flow demo";
include("demo_hamiltonian_flow.jl");'
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
/docs/build/
test/Manifest.toml
example/Manifest.toml
example/LocalPreferences.toml

# Files generated by invoking Julia with --code-coverage
*.jl.cov
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "NormalizingFlows"
uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
version = "0.2.0"
version = "0.2.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -24,7 +24,7 @@ NormalizingFlowsCUDAExt = "CUDA"
ADTypes = "1"
Bijectors = "0.12.6, 0.13, 0.14, 0.15"
CUDA = "5"
DifferentiationInterface = "0.6.42"
DifferentiationInterface = "0.6, 0.7"
Distributions = "0.25"
DocStringExtensions = "0.9"
Optimisers = "0.2.16, 0.3, 0.4"
Expand Down
5 changes: 5 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ and hope to generate approximate samples from it.
```@docs
NormalizingFlows.elbo
```

```@docs
NormalizingFlows.elbo_batch
```

#### Log-likelihood

By maximizing the log-likelihood, it is equivalent to minimizing the forward KL divergence between $q_\theta$ and $p$, i.e.,
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ See the [documentation](https://turinglang.org/NormalizingFlows.jl/dev/) for mor
To install the package, run the following command in the Julia REPL:
```
] # enter Pkg mode
(@v1.9) pkg> add [email protected]:TuringLang/NormalizingFlows.jl.git
(@v1.11) pkg> add [email protected]:TuringLang/NormalizingFlows.jl.git
```
Then simply run the following command to use the package:
```julia
Expand Down
13 changes: 9 additions & 4 deletions example/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
6 changes: 3 additions & 3 deletions example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ normalizing flow to approximate the target distribution using `NormalizingFlows.
Currently, all examples share the same [Julia project](https://pkgdocs.julialang.org/v1/environments/#Using-someone-else's-project). To run the examples, first activate the project environment:

```julia
# pwd() = "NormalizingFlows.jl/"
using Pkg; Pkg.activate("example"); Pkg.instantiate()
# pwd() = "NormalizingFlows.jl/example"
using Pkg; Pkg.activate("."); Pkg.instantiate()
```
This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with include("<example-to-run>.jl"), or by running the example script line-by-line.
This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with `include("<example-to-run>.jl")`, or by running the example script line-by-line.
19 changes: 19 additions & 0 deletions example/SyntheticTargets.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using DocStringExtensions
using Distributions, Random, LinearAlgebra
using IrrationalConstants
using Plots


include("targets/banana.jl")
include("targets/cross.jl")
include("targets/neal_funnel.jl")
include("targets/warped_gaussian.jl")

function visualize(p::ContinuousMultivariateDistribution, samples=rand(p, 1000))
xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100)
yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100)
z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange]
fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2)
scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright)
return fig
end
180 changes: 180 additions & 0 deletions example/demo_RealNVP.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
using Flux
using Bijectors
using Bijectors: partition, combine, PartitionMask

using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Mooncake
using NormalizingFlows

include("SyntheticTargets.jl")
include("utils.jl")

##################################
# define affine coupling layer using Bijectors.jl interface
#################################
struct AffineCoupling <: Bijectors.Bijector
dim::Int
mask::Bijectors.PartitionMask
s::Flux.Chain
t::Flux.Chain
end

# let params track field s and t
@functor AffineCoupling (s, t)

function AffineCoupling(
dim::Int, # dimension of input
hdims::Int, # dimension of hidden units for s and t
mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on
)
cdims = length(mask_idx) # dimension of parts used to construct coupling law
s = mlp3(cdims, hdims, cdims)
t = mlp3(cdims, hdims, cdims)
mask = PartitionMask(dim, mask_idx)
return AffineCoupling(dim, mask, s, t)
end

function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
# partition vector using 'af.mask::PartitionMask`
x₁, x₂, x₃ = partition(af.mask, x)
y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)
return combine(af.mask, y₁, x₂, x₃)
end

function (af::AffineCoupling)(x::AbstractArray)
return transform(af, x)
end

function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
logjac = sum(log abs, af.s(x_2)) # this is a scalar
return combine(af.mask, y_1, x_2, x_3), logjac
end

function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractMatrix)
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
logjac = sum(log abs, af.s(x_2); dims = 1) # 1 × size(x, 2)
return combine(af.mask, y_1, x_2, x_3), vec(logjac)
end


function Bijectors.with_logabsdet_jacobian(
iaf::Inverse{<:AffineCoupling}, y::AbstractVector
)
af = iaf.orig
# partition vector using `af.mask::PartitionMask`
y_1, y_2, y_3 = partition(af.mask, y)
# inverse transformation
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
logjac = -sum(log abs, af.s(y_2))
return combine(af.mask, x_1, y_2, y_3), logjac
end

function Bijectors.with_logabsdet_jacobian(
iaf::Inverse{<:AffineCoupling}, y::AbstractMatrix
)
af = iaf.orig
# partition vector using `af.mask::PartitionMask`
y_1, y_2, y_3 = partition(af.mask, y)
# inverse transformation
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
logjac = -sum(log abs, af.s(y_2); dims = 1)
return combine(af.mask, x_1, y_2, y_3), vec(logjac)
end

###################
# an equivalent definition of AffineCoupling using Bijectors.Coupling
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
###################

# struct AffineCoupling <: Bijectors.Bijector
# dim::Int
# mask::Bijectors.PartitionMask
# s::Flux.Chain
# t::Flux.Chain
# end

# # let params track field s and t
# @functor AffineCoupling (s, t)

# function AffineCoupling(dim, mask, s, t)
# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
# end

# function AffineCoupling(
# dim::Int, # dimension of input
# hdims::Int, # dimension of hidden units for s and t
# mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on
# )
# cdims = length(mask_idx) # dimension of parts used to construct coupling law
# s = mlp3(cdims, hdims, cdims)
# t = mlp3(cdims, hdims, cdims)
# mask = PartitionMask(dim, mask_idx)
# return AffineCoupling(dim, mask, s, t)
# end



##################################
# start demo
#################################
Random.seed!(123)
rng = Random.default_rng()
T = Float32

######################################
# a difficult banana target
######################################
target = Banana(2, 1.0f0, 100.0f0)
logp = Base.Fix1(logpdf, target)

######################################
# learn the target using Affine coupling flow
######################################
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), ones(T, 2))

d = 2
hdims = 32

# alternating the coupling layers
Ls = [AffineCoupling(d, hdims, [1]) AffineCoupling(d, hdims, [2]) for i in 1:3]

flow = create_flow(Ls, q0)
flow_untrained = deepcopy(flow)


######################################
# start training
######################################
sample_per_iter = 64

# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
rng,
elbo_batch, # using elbo_batch instead of elbo achieves 4-5 times speedup
flow,
logp,
sample_per_iter;
max_iters=100, # change to larger number of iterations (e.g., 50_000) for better results
optimiser=Optimisers.Adam(5e-4),
ADbackend=adtype,
show_progress=true,
callback=cb,
hasconverged=checkconv,
)
θ, re = Optimisers.destructure(flow_trained)
losses = map(x -> x.loss, stats)

######################################
# evaluate trained flow
######################################
plot(losses; label="Loss", linewidth=2) # plot the loss
compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)
Loading