1 change: 1 addition & 0 deletions Project.toml
@@ -22,6 +22,7 @@ IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Member
Why was this added?

Author
Because when I precompiled NeuralPDE, it gave me an error saying this dependency is required by NeuralPDE but was not found in its dependencies section.

Member
Interesting, it doesn't happen for me. Are you sure you are using NeuralPDE's env?

Author
Yes. I have made all the changes in the same repository and Julia env.
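(For reference, a quick way to confirm which environment is active before precompiling; the path below is a hypothetical placeholder, not from this PR:)

```julia
using Pkg
println(Base.active_project())        # path of the Project.toml currently in use
Pkg.activate("path/to/NeuralPDE.jl")  # activate the package's own environment
Pkg.precompile()                      # reproduce the precompilation run
```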

LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
30 changes: 30 additions & 0 deletions lib/BayesianNeuralPDE/Project.toml
@@ -0,0 +1,30 @@
name = "BayesianNeuralPDE"
uuid = "3cea9122-e921-42ea-a9d7-c72fcb58ce53"
authors = ["paramthakkar123 <[email protected]>"]
version = "0.1.0"

[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
NeuralPDE = "315f7962-48a3-4962-8226-d0f33b1235f0"

[compat]
ChainRulesCore = "1.25.1"
ConcreteStructs = "0.2.3"
MonteCarloMeasurements = "1.4.3"
Printf = "1.11.0"
SciMLBase = "2.72.1"
@@ -228,4 +228,4 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem, alg::BNNODE, args...; dt
end

return BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params, t)
end
end
26 changes: 26 additions & 0 deletions lib/BayesianNeuralPDE/src/BayesianNeuralPDE.jl
@@ -0,0 +1,26 @@
module BayesianNeuralPDE

using MCMCChains, Distributions, OrdinaryDiffEq, OptimizationOptimisers, Lux,
AdvancedHMC, Statistics, Random, Functors, ComponentArrays, MonteCarloMeasurements
using Printf: @printf
using ConcreteStructs: @concrete
using NeuralPDE: PhysicsInformedNN
using SciMLBase: SciMLBase
using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives
using LogDensityProblems: LogDensityProblems

abstract type AbstractPINN end

abstract type AbstractTrainingStrategy end
abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end

include("advancedHMC_MCMC.jl")
include("pinn_types.jl")
include("BPINN_ode.jl")
include("discretize.jl")
include("PDE_BPINN.jl")

export BNNODE, ahmc_bayesian_pinn_ode, ahmc_bayesian_pinn_pde
export BPINNsolution, BayesianPINN

end
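A minimal sketch of how the new subpackage's exports are expected to be used, assuming `BNNODE` keeps the same constructor and `solve` interface it has in NeuralPDE (names outside the diff are illustrative):

```julia
using BayesianNeuralPDE, Lux, OrdinaryDiffEq

# out-of-place ODE, as required by the BPINN solver
f(u, p, t) = -u / 5 + exp(t / 5) * cos(t)
prob = ODEProblem(f, 0.0, (0.0, 10.0))

chain = Chain(Dense(1, 8, tanh), Dense(8, 1))
alg = BNNODE(chain; draw_samples = 1000, priorsNNw = (0.0, 3.0))
sol = solve(prob, alg)  # dispatches to the SciMLBase.__solve method above, returning a BPINNsolution
```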
@@ -507,4 +507,4 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
return BPINNsolution(
fullsolution, ensemblecurves, estimnnparams, estimated_params, timepoints)
end
end
end
240 changes: 240 additions & 0 deletions lib/BayesianNeuralPDE/src/advancedHMC_MCMC.jl
@@ -0,0 +1,240 @@
"""
ahmc_bayesian_pinn_ode(prob, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0,
l2std = [0.05], phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
MCMCkwargs = (n_leapfrog = 30,), progress = false,
verbose = false)

!!! warning

Note that `ahmc_bayesian_pinn_ode()` only supports ODEs which are written in the
out-of-place form, i.e. `du = f(u,p,t)`, and not `f(du,u,p,t)`. If not declared
out-of-place, then `ahmc_bayesian_pinn_ode()` will exit with an error.

## Example

```julia
linear = (u, p, t) -> -u / p[1] + exp(t / p[2]) * cos(t)
tspan = (0.0, 10.0)
u0 = 0.0
p = [5.0, -5.0]
prob = ODEProblem(linear, u0, tspan, p)

### CREATE DATASET (Necessity for accurate Parameter estimation)
sol = solve(prob, Tsit5(); saveat = 0.05)
u = sol.u[1:100]
time = sol.t[1:100]

### create dataset for the BPINN
x̂ = collect(Float64, Array(u) + 0.05 * randn(size(u)))
dataset = [x̂, time]

chain1 = Lux.Chain(Lux.Dense(1, 5, tanh), Lux.Dense(5, 5, tanh), Lux.Dense(5, 1))

### when simply solving the ode, it is better not to pass a dataset (the ode params specified in prob are used)
fh_mcmc_chain1, fhsamples1, fhstats1 = ahmc_bayesian_pinn_ode(prob, chain1,
dataset = dataset,
draw_samples = 1500,
l2std = [0.05],
phystd = [0.05],
priorsNNw = (0.0,3.0))

### solving the ode + estimating parameters, hence a dataset to optimize the parameters against and prior distributions for the ODE params are needed
fh_mcmc_chain2, fhsamples2, fhstats2 = ahmc_bayesian_pinn_ode(prob, chain1,
dataset = dataset,
draw_samples = 1500,
l2std = [0.05],
phystd = [0.05],
priorsNNw = (0.0,3.0),
param = [Normal(6.5,0.5), Normal(-3,0.5)])
```

## NOTES

A dataset is required for accurate parameter estimation while solving the equations.
In case you are only solving the equations for the solution, do not provide a dataset.

## Positional Arguments

* `prob`: DEProblem (out-of-place; the function signature should be `f(u,p,t)`).
* `chain`: Lux neural network which would be made the Bayesian PINN.

## Keyword Arguments

* `strategy`: The training strategy used to choose the points for the evaluations. By
default GridTraining is used with given physdt discretization.
* `init_params`: initial parameter values for the BPINN (ideally, different initializations
are preferred for multiple chains)
* `nchains`: number of chains you want to sample
* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are
~2/3 of draw samples)
* `l2std`: standard deviation of BPINN prediction against L2 losses/dataset
* `phystd`: standard deviation of BPINN prediction against the chosen underlying ODE system
* `phynewstd`: standard deviation of the new loss function term
* `priorsNNw`: Tuple of (mean, std) for the BPINN network parameters. Weights and biases of
the BPINN are Normal distributions by default.
* `param`: Vector of chosen ODE parameter distributions for inverse problems.
* `autodiff`: Boolean value for the choice of derivative backend (default is numerical)
* `physdt`: timestep for approximating the ODE in its time domain. (1/20.0 by default)
* `Kernel`: Choice of MCMC Sampling Algorithm (AdvancedHMC.jl implementations HMC/NUTS/HMCDA)
* `Integratorkwargs`: `Integrator`, `jitter_rate`, `tempering_rate`.
Refer: https://turinglang.org/AdvancedHMC.jl/stable/
* `Adaptorkwargs`: `Adaptor`, `Metric`, `targetacceptancerate`.
Refer: https://turinglang.org/AdvancedHMC.jl/stable/ Note: Target percentage (in decimal)
of iterations in which the proposals are accepted (0.8 by default)
* `MCMCkwargs`: A NamedTuple containing all the chosen MCMC kernel's (HMC/NUTS/HMCDA)
arguments, as follows:
* `n_leapfrog`: number of leapfrog steps for HMC
* `δ`: target acceptance probability for NUTS and HMCDA
* `λ`: target trajectory length for HMCDA
* `max_depth`: Maximum doubling tree depth (NUTS)
* `Δ_max`: Maximum divergence during doubling tree (NUTS)
Refer: https://turinglang.org/AdvancedHMC.jl/stable/
* `progress`: controls whether to show the progress meter or not.
* `verbose`: controls the verbosity. (Sample call args in AHMC)
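
For example, a NUTS run with explicit sampler settings might look like this (a sketch; the values are illustrative, and the keys follow the list above):

```julia
fh_mcmc_chain, fhsamples, fhstats = ahmc_bayesian_pinn_ode(prob, chain1;
    dataset = dataset, draw_samples = 1000,
    Kernel = NUTS,
    MCMCkwargs = (δ = 0.65, max_depth = 10, Δ_max = 1000.0),
    Adaptorkwargs = (Adaptor = StanHMCAdaptor,
        Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
    Integratorkwargs = (Integrator = Leapfrog,))
```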

!!! warning

AdvancedHMC.jl is still developing convenience structs, so changes might be needed on new
releases.
"""
function ahmc_bayesian_pinn_ode(
prob::SciMLBase.ODEProblem, chain; strategy = GridTraining, dataset = [nothing],
init_params = nothing, draw_samples = 1000, physdt = 1 / 20.0, l2std = [0.05],
phystd = [0.05], phynewstd = [0.05], priorsNNw = (0.0, 2.0), param = [], nchains = 1,
autodiff = false, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false, estim_collocate = false)
@assert !isinplace(prob) "The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."

chain isa AbstractLuxLayer || (chain = FromFluxAdaptor()(chain))

strategy = strategy == GridTraining ? strategy(physdt) : strategy

if dataset != [nothing] &&
(length(dataset) < 2 || !(dataset isa Vector{<:Vector{<:AbstractFloat}}))
error("Invalid dataset. dataset should be a timeseries (x̂, t) of type Vector{Vector{AbstractFloat}}.")
end

if dataset != [nothing] && param == []
println("Dataset is only needed for Parameter Estimation + Forward Problem, not for the forward-only case.")
elseif dataset == [nothing] && param != []
error("Dataset required for Parameter Estimation.")
end
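# NOTE: `dataset` is expected as [x̂₁, ..., x̂ₙ, t]: one noisy observation vector per
# state variable, with the shared time points as the last entry (see docstring example).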

initial_nnθ, chain, st = generate_ltd(chain, init_params)

@assert nchains≤Threads.nthreads() "number of chains is greater than available threads"
@assert nchains≥1 "number of chains must be at least 1"

# physdt's eltype is used because find_good_stepsize needs Float64
# Lux chain (a ComponentArray is used later, since vector_to_parameter needs a NamedTuple)
T = eltype(physdt)
initial_θ = getdata(ComponentArray{T}(initial_nnθ))

# adding ode parameter estimation
nparameters = length(initial_θ)
ninv = length(param)
priors = [
MvNormal(T(priorsNNw[1]) * ones(T, nparameters),
Diagonal(abs2.(T(priorsNNw[2]) .* ones(T, nparameters))))
]

# append ODE params to the full parameter vector
if ninv > 0
# initialise the ODE params at their prior means
initial_θ = vcat(initial_θ, [Distributions.params(param[i])[1] for i in 1:ninv])
priors = vcat(priors, param)
nparameters += ninv
end
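# at this point initial_θ is laid out as [NN weights and biases..., ODE params...];
# priors[1] is the MvNormal over the NN block, priors[2:end] are the ODE-param priors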

smodel = StatefulLuxLayer{true}(chain, nothing, st)
# dimensions = total number of params; initial_nnθ provides the Lux NamedTuple structure
ℓπ = LogTargetDensity(nparameters, prob, smodel, strategy, dataset, priors,
phystd, phynewstd, l2std, autodiff, physdt, ninv, initial_nnθ, estim_collocate)

if verbose
@printf("Current Physics Log-likelihood: %g\n", physloglikelihood(ℓπ, initial_θ))
@printf("Current Prior Log-likelihood: %g\n", priorweights(ℓπ, initial_θ))
@printf("Current SSE against dataset Log-likelihood: %g\n",
L2LossData(ℓπ, initial_θ))
if estim_collocate
@printf("Current gradient loss against dataset Log-likelihood: %g\n",
L2loss2(ℓπ, initial_θ))
end
end

Adaptor = Adaptorkwargs[:Adaptor]
Metric = Adaptorkwargs[:Metric]
targetacceptancerate = Adaptorkwargs[:targetacceptancerate]

# Define Hamiltonian system (nparameters ~ dimensionality of the sampling space)
metric = Metric(nparameters)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)

# parallel sampling option
if nchains != 1
# Cache to store the chains
chains = Vector{Any}(undef, nchains)
statsc = Vector{Any}(undef, nchains)
samplesc = Vector{Any}(undef, nchains)

Threads.@threads for i in 1:nchains
# each chain gets its own initial NN parameter values (better posterior exploration);
# loop-local bindings avoid racing on the shared `initial_θ` and `Kernel` across threads
initial_θ_i = vcat(
randn(eltype(initial_θ), nparameters - ninv),
initial_θ[(nparameters - ninv + 1):end]
)
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ_i)
integrator = integratorchoice(Integratorkwargs, initial_ϵ)
adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
samples, stats = sample(hamiltonian, kernel, initial_θ_i, draw_samples, adaptor;
progress = progress, verbose = verbose)

samplesc[i] = samples
statsc[i] = stats
mcmc_chain = Chains(reduce(hcat, samples)')
chains[i] = mcmc_chain
end

return chains, samplesc, statsc
else
initial_ϵ = find_good_stepsize(hamiltonian, initial_θ)
integrator = integratorchoice(Integratorkwargs, initial_ϵ)
adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric),
StepSizeAdaptor(targetacceptancerate, integrator))

MCMC_alg = kernelchoice(Kernel, MCMCkwargs)
kernel = AdvancedHMC.make_kernel(MCMC_alg, integrator)
samples, stats = sample(hamiltonian, kernel, initial_θ, draw_samples,
adaptor; progress = progress, verbose = verbose)

if verbose
println("Sampling Complete.")
@printf("Final Physics Log-likelihood: %g\n",
physloglikelihood(ℓπ, samples[end]))
@printf("Final Prior Log-likelihood: %g\n", priorweights(ℓπ, samples[end]))
@printf("Final SSE against dataset Log-likelihood: %g\n",
L2LossData(ℓπ, samples[end]))
if estim_collocate
@printf("Final gradient loss against dataset Log-likelihood: %g\n",
L2loss2(ℓπ, samples[end]))
end
end

# return a basic Chains object, the raw samples, and the sampler stats
matrix_samples = reshape(hcat(samples...), (length(samples[1]), length(samples), 1))
mcmc_chain = MCMCChains.Chains(matrix_samples)
return mcmc_chain, samples, stats
end
end