Commit f2f34a5: howto custom Likelihood
1 parent 0697a44 commit f2f34a5

File tree

8 files changed: +457 −0 lines

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ makedocs(;
     ],
     "How to" => [
         ".. use GPU" => "tutorials/lux_gpu.md",
+        ".. specify log-Likelihood" => "tutorials/logden_user.md",
         ".. model independent parameters" => "tutorials/blocks_corr.md",
         ".. model site-global corr" => "tutorials/corr_site_global.md",
     ],

docs/src/tutorials/logden_user.md

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
# How to specify a custom LogLikelihood of the observations

``` @meta
CurrentModule = HybridVariationalInference
```

This guide shows how the user can specify a customized log-density function.

## Motivation

The log-likelihood function assigns a cost to the mismatch between predictions and
observations. It often needs to be customized to the specific inversion.

This guide walks through the specification of such a function and inspects
the differences between two log-likelihood functions.
Specifically, it assumes observation errors to be independently distributed
according to a LogNormal distribution with a specified fixed relative error,
and compares the results to an inversion that assumes independently normally
distributed observation errors.

First load the necessary packages.

``` julia
using HybridVariationalInference
using ComponentArrays: ComponentArrays as CA
using Bijectors
using SimpleChains
using MLUtils
using JLD2
using Random
using Statistics # mean, std used in the comparison plots below
using CairoMakie
using PairPlots # scatterplot matrices
```

This tutorial reuses and modifies the fitted object saved at the end of the
[Basic workflow without GPU](@ref) tutorial, which used a log-likelihood
function that assumes independently normally distributed observation errors.

``` julia
fname = "intermediate/basic_cpu_results.jld2"
print(abspath(fname))
prob = probo_normal = load(fname, "probo");
```

## Write the LogLikelihood Function

The function signature corresponds to that of [`neg_logden_indep_normal`](@ref):

`neg_log_den_user(y_pred, y_obs, y_unc; kwargs...)`

It takes predictions, `y_pred`, observations, `y_obs`, and uncertainty
parameters, `y_unc`, and returns the negative logarithm of the likelihood
up to an additive constant.

All of the arguments are vectors of the same length, specifying predictions and
observations for one site.
If `y_pred` and `y_obs` are given as matrices of several column vectors, their
summed log-likelihood is computed.

The density of a LogNormal distribution is

$$
\frac{1}{x \sqrt{2 \pi \sigma^2}} \exp\left( -\frac{(\ln x - \mu)^2}{2 \sigma^2} \right)
$$

where $x$ is the observation, $\mu$ is the log of the prediction, and $\sigma$ is the scale
parameter, related to the relative error $c_v$ by $\sigma = \sqrt{\ln(c_v^2 + 1)}$.

Taking the log:

$$
-\ln x - \frac{1}{2} \ln \sigma^2 - \frac{1}{2} \ln (2 \pi) - \frac{(\ln x - \mu)^2}{2 \sigma^2}
$$

Negating and dropping the terms $\frac{1}{2} \ln (2 \pi)$ and $\frac{1}{2} \ln \sigma^2$,
which are constant for fixed $\sigma^2$, yields

$$
\ln x + \frac{(\ln x - \mu)^2}{2 \sigma^2}
$$

``` julia
function neg_logden_lognormalep_lognormal(y_pred, y_obs::AbstractArray{ET}, y_unc;
        σ2 = log(abs2(ET(0.02)) + ET(1))) where ET
    # the default σ2 corresponds to a fixed relative error of c_v = 2%
    lnx = log.(CA.getdata(y_obs))
    μ = log.(CA.getdata(y_pred))
    nlogL = sum(lnx .+ abs2.(lnx .- μ) ./ (ET(2) .* σ2))
    #nlogL = sum(lnx .+ (log.(σ2) .+ abs2.(lnx .- μ) ./ σ2) ./ ET(2)) # nonconstant σ2
    return nlogL
end
```
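
As a quick sanity check (with made-up numbers, not the tutorial data), the returned
cost should be smallest when predictions match the observations. This implementation
ignores `y_unc`, so `nothing` can be passed:

``` julia
y_check = [1.0, 2.0, 3.0]   # hypothetical observations
cost_perfect = neg_logden_lognormalep_lognormal(y_check, y_check, nothing)
cost_biased = neg_logden_lognormalep_lognormal(1.1 .* y_check, y_check, nothing)
cost_perfect < cost_biased  # true: biased predictions incur a higher cost
```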

If information on observation-specific relative errors were available,
we could pass it via the DataLoader as `y_unc`, rather than
assuming a constant relative error across observations; a sketch of such a
variant follows.
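
The following is a minimal sketch, not part of the package: the function name
`neg_logden_lognormal_relerr` is hypothetical, and it assumes `y_unc` holds the
per-observation relative errors $c_v$. It uses the non-constant-$\sigma^2$ form
from the commented line above, where the $\frac{1}{2} \ln \sigma^2$ term no
longer drops out as a constant.

``` julia
# Sketch (assumption): y_unc carries per-observation relative errors c_v,
# so σ² = ln(c_v² + 1) varies across observations.
function neg_logden_lognormal_relerr(y_pred, y_obs::AbstractArray{ET}, y_unc) where ET
    σ2 = log.(abs2.(CA.getdata(y_unc)) .+ ET(1))
    lnx = log.(CA.getdata(y_obs))
    μ = log.(CA.getdata(y_pred))
    # per observation: ln x + (ln σ² + (ln x - μ)² / σ²) / 2
    nlogL = sum(lnx .+ (log.(σ2) .+ abs2.(lnx .- μ) ./ σ2) ./ ET(2))
    return nlogL
end
```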

## Update the problem and redo the inversion

`HybridProblem` has the keyword argument `py` to specify the negative log-likelihood function.

``` julia
prob_lognormal = HybridProblem(prob; py = neg_logden_lognormalep_lognormal)

using OptimizationOptimisers
import Zygote

solver = HybridPosteriorSolver(; alg = Adam(0.02), n_MC = 3)

(; probo) = solve(prob_lognormal, solver;
    callback = callback_loss(100), # output during fitting
    epochs = 20,
); probo_lognormal = probo;
```

## Compare results between assumptions of observation error

First, draw a sample from the inversion assuming normally distributed and a sample
from the inversion assuming lognormally distributed observation errors.

``` julia
n_sample_pred = 400
(y_normal, θsP_normal, θsMs_normal) = (; y, θsP, θsMs) = predict_hvi(
    Random.default_rng(), probo_normal; n_sample_pred)
(y_lognormal, θsP_lognormal, θsMs_lognormal) = (; y, θsP, θsMs) = predict_hvi(
    Random.default_rng(), probo_lognormal; n_sample_pred)
```
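
The indexing used in the plots below suggests (an inference from the plotting
code, not documented behavior) the layout `y[i_out, i_site, i_sample]` for
predictions and `θsMs[i_site, i_par, i_sample]` for site parameters; a quick check:

``` julia
# sizes inferred from the indexing in the plots below (assumption)
size(y_normal), size(θsMs_normal), size(θsP_normal)
```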

Get the original observations from the DataLoader of the problem and
compute the residuals.

``` julia
train_loader = get_hybridproblem_train_dataloader(probo_normal; scenario = ())
y_o = train_loader.data[3]
resid_normal = y_o .- y_normal
resid_lognormal = y_o .- y_lognormal
```

And compare plots of some of the results.

``` julia
i_out = 4
i_site = 1
fig = Figure()
ax = Axis(fig[1, 1], xlabel = "observation error (y_obs - y_pred)", ylabel = "probability density")
#hist!(ax, resid_normal[i_out, i_site, :], label = "normal", normalization = :pdf)
density!(ax, resid_normal[i_out, i_site, :], alpha = 0.8, label = "normal")
density!(ax, resid_lognormal[i_out, i_site, :], alpha = 0.8, label = "lognormal")
axislegend(ax, unique = true)
fig
```

![](logden_user_files/figure-commonmark/cell-8-output-1.png)

The density plot of the observation residuals does not show a lognormal shape.
The synthetic observations used were actually normally
distributed around the predictions with the true parameters.

How does the wrong assumption about the observation error influence the parameter
posterior?

``` julia
fig = Figure()
ax = Axis(fig[1, 1], xlabel = "global parameter K2", ylabel = "probability density")
density!(ax, θsP_normal[:K2, :], alpha = 0.8, label = "normal")
density!(ax, θsP_lognormal[:K2, :], alpha = 0.8, label = "lognormal")
axislegend(ax, unique = true)
fig
```

![](logden_user_files/figure-commonmark/cell-9-output-1.png)

The marginal posterior of the global parameter is also similar, with a slight
shift toward lower values.

``` julia
i_site = 1
θln = vcat(θsP_lognormal, θsMs_lognormal[i_site, :, :])
θln_nt = NamedTuple(Symbol("$(k)_lognormal") => CA.getdata(θln[k, :]) for k in keys(θln[:, 1]))
#θn = vcat(θsP_normal, θsMs_normal[i_site, :, :])
#θn_nt = NamedTuple(Symbol("$(k)_normal") => CA.getdata(θn[k, :]) for k in keys(θn[:, 1]))
#ntc = (; θn_nt..., θln_nt...)
plt = pairplot(θln_nt)
```

![](logden_user_files/figure-commonmark/cell-10-output-1.png)

The corner plot of the parameter estimates looks similar and shows
correlations between the site parameters, $r_1$ and $K_1$.

``` julia
i_out = 4
fig = Figure()
ax = Axis(fig[1, 1], xlabel = "mean(y)", ylabel = "sd(y)")
ymean_normal = [mean(y_normal[i_out, s, :]) for s in axes(y_normal, 2)]
ysd_normal = [std(y_normal[i_out, s, :]) for s in axes(y_normal, 2)]
scatter!(ax, ymean_normal, ysd_normal, label = "normal")
ymean_lognormal = [mean(y_lognormal[i_out, s, :]) for s in axes(y_lognormal, 2)]
ysd_lognormal = [std(y_lognormal[i_out, s, :]) for s in axes(y_lognormal, 2)]
scatter!(ax, ymean_lognormal, ysd_lognormal, label = "lognormal")
axislegend(ax, unique = true)
fig
```

![](logden_user_files/figure-commonmark/cell-11-output-1.png)

The predicted uncertainty of the fourth observation across sites
is of similar magnitude for both inversions,
and still shows a (although weaker) pattern of decreasing uncertainty with
increasing value.

``` julia
plot_sd_vs_mean = (par) -> begin
    fig = Figure()
    ax = Axis(fig[1, 1], xlabel = "mean($par)", ylabel = "sd($par)")
    θmean_normal = [mean(θsMs_normal[s, par, :]) for s in axes(θsMs_normal, 1)]
    θsd_normal = [std(θsMs_normal[s, par, :]) for s in axes(θsMs_normal, 1)]
    scatter!(ax, θmean_normal, θsd_normal, label = "normal")
    θmean_lognormal = [mean(θsMs_lognormal[s, par, :]) for s in axes(θsMs_lognormal, 1)]
    θsd_lognormal = [std(θsMs_lognormal[s, par, :]) for s in axes(θsMs_lognormal, 1)]
    scatter!(ax, θmean_lognormal, θsd_lognormal, label = "lognormal")
    axislegend(ax, unique = true)
    fig
end
plot_sd_vs_mean(:K1)
```

![](logden_user_files/figure-commonmark/cell-12-output-1.png)

For the assumed fixed relative error, the uncertainty in the model
parameter $K_1$ across sites is similar to the uncertainty obtained with the normal log-likelihood.
