Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions examples/usage.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
# Switch To MKL For Faster Computation
using MKL
# using MKL

# Enable Logging
using Logging, TerminalLoggers
global_logger(TerminalLogger())

# Data
using Distributions
ndata = 1024
ndimension = 1
data_dist = Beta{Float32}(2.0f0, 4.0f0)
r = rand(data_dist, ndimension, ndata)
r = convert.(Float32, r)

# Parameters
nvars = 1
nvars = size(r, 1)
naugs = nvars
# n_in = nvars # without augmentation
n_in = nvars + naugs # with augmentation
n = 1024
n_in = nvars + naugs

# Model
using ContinuousNormalizingFlows,
Lux, OrdinaryDiffEqDefault, SciMLSensitivity, ADTypes, Zygote, MLDataDevices

# To use gpu, add related packages
# using LuxCUDA, CUDA, cuDNN

nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
icnf = construct(
RNODE,
Expand All @@ -24,6 +34,7 @@ icnf = construct(
compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
inplace = false, # not using the inplace version of functions
device = cpu_device(), # process data by CPU
# device = gpu_device(), # process data by GPU
tspan = (0.0f0, 13.0f0), # have bigger time span
steer_rate = 1.0f-1, # add random noise to end of the time span
λ₁ = 1.0f-2, # regulate flow
Expand All @@ -36,12 +47,6 @@ icnf = construct(
), # pass to the solver
)

# Data
using Distributions
data_dist = Beta{Float32}(2.0f0, 4.0f0)
r = rand(data_dist, nvars, n)
r = convert.(Float32, r)

# Fit It
using DataFrames, MLJBase, Zygote, ADTypes, OptimizationOptimisers
df = DataFrame(transpose(r), :auto)
Expand All @@ -55,25 +60,26 @@ model = ICNFModel(
)
mach = machine(model, df)
fit!(mach)
# CUDA.@allowscalar fit!(mach) # needed for gpu
ps, st = fitted_params(mach)

# Store It
using JLD2, UnPack
jldsave("fitted.jld2"; ps, st) # save
@unpack ps, st = load("fitted.jld2") # load
jldsave("fitted.jld2"; ps, st) # save it
@unpack ps, st = load("fitted.jld2") # load it

# Use It
d = ICNFDist(icnf, TestMode(), ps, st) # direct way
# d = ICNFDist(mach, TestMode()) # alternative way
actual_pdf = pdf.(data_dist, vec(r))
estimated_pdf = pdf(d, r)
new_data = rand(d, n)
new_data = rand(d, ndata)

# Evaluate It
using Distances
mad_ = meanad(estimated_pdf, actual_pdf)
msd_ = msd(estimated_pdf, actual_pdf)
tv_dis = totalvariation(estimated_pdf, actual_pdf) / n
tv_dis = totalvariation(estimated_pdf, actual_pdf) / ndata
res_df = DataFrame(; mad_, msd_, tv_dis)
display(res_df)

Expand Down
Loading