Skip to content

Commit f1ec67d

Browse files
committed
changes from hack session
1 parent fed2d97 commit f1ec67d

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

src/Solus.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ An ensemble of `inputs` and their corresponding `outputs` from the forward model
4646
4747
$(DocStringExtensions.FIELDS)
4848
"""
49-
struct Ensemble{T,O}
49+
struct Ensemble{I,O}
5050
"""
5151
A matrix of inputs: each column corresponds to a vector of ``θ``s
5252
"""
53-
inputs::Matrix{T}
53+
inputs::Vector{I}
5454
"""
5555
Result of forward model for each column of `inputs`.
5656
"""
@@ -63,19 +63,26 @@ end
6363
Peform an iteration of NEKI, returning a new `Ensemble` object.
6464
"""
6565
function neki_iter(prob::SolusProblem, ens::Ensemble)
66-
covθ = cov(ens.inputs; dims=2)
66+
covθ = cov(ens.inputs)
6767
covθ += (tr(covθ)*1e-15)I
6868

6969
m = mean(ens.outputs)
70-
CG = [dot(u-prob.obs, v-m, prob.space) for u in ens.outputs, v in ens.outputs] #compute mean-field matrix
71-
72-
Δt = 0.1 / norm(CG)
73-
implicit = lu( I + 1 * Δt .* covθ * cov(prob.prior)) # todo: incorporate means
74-
rhsv = ens.inputs*CG
75-
rhs = ens.inputs - rhsv*Δt
7670

77-
inputs = (implicit \ rhs) + rand(MvNormal(Δt .* covθ), size(ens.inputs, 2))
78-
outputs = [prob.forwardmodel(θ) for θ in eachslice(ens.inputs,dims=2)]
71+
CG = [dot(Gθk - m, Gθj - prob.obs, prob.space) for Gθj in ens.outputs, Gθk in ens.outputs]
72+
73+
Δt = 0.1 / norm(CG) # use better constants
74+
implicit = lu( I + Δt .* (covθ / cov(prob.prior)) ) # todo: incorporate means
75+
noise = MvNormal(covθ)
76+
77+
inputs = map(enumerate(ens.inputs)) do (j, θj)
78+
X = sum(enumerate(ens.inputs)) do (k, θk)
79+
CG[k,j]*θk
80+
end
81+
rhs = θj .- Δt .* X
82+
(implicit \ rhs) .+ sqrt(Δt)*rand(noise)
83+
end
84+
85+
outputs = map(prob.forwardmodel, inputs)
7986
Ensemble(inputs, outputs)
8087
end
8188

src/spaces.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,11 @@ end
2020
norm(x, ::DefaultSpace) = norm(x)
2121
dot(x,y, ::DefaultSpace) = dot(x,y)
2222

23+
24+
struct CovarianceSpace{M} <: HilbertSpace
25+
Γ::M
26+
end
27+
norm(x, c::CovarianceSpace) = x'*(c.Γ\x)
28+
dot(x,y, c::CovarianceSpace) = x'*(c.Γ\y)
29+
2330
# TODO: CovarianceSpace/PDSpace with a positive definite matrix

0 commit comments

Comments
 (0)