Skip to content

Commit e8c9365

Browse files
authored
Merge pull request #2 from climate-machine/sb/solus
rename to Solus.jl, add NEKI
2 parents 2d0f44f + d622824 commit e8c9365

File tree

7 files changed

+153
-153
lines changed

7 files changed

+153
-153
lines changed

Manifest.toml

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
1616
version = "0.8.10"
1717

1818
[[BinaryProvider]]
19-
deps = ["Libdl", "Pkg", "SHA", "Test"]
20-
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
19+
deps = ["Libdl", "SHA"]
20+
git-tree-sha1 = "c7361ce8a2129f20b0e05a89f7070820cfed6648"
2121
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
22-
version = "0.5.3"
22+
version = "0.5.4"
2323

2424
[[Compat]]
2525
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -51,6 +51,12 @@ git-tree-sha1 = "dec0ebacfbc3a2126c614ab5e903c9ef063688d0"
5151
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
5252
version = "0.17.0"
5353

54+
[[DocStringExtensions]]
55+
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
56+
git-tree-sha1 = "4d30e889c9f106a51ffa4791a88ffd4765bf20c3"
57+
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
58+
version = "0.7.0"
59+
5460
[[InteractiveUtils]]
5561
deps = ["Markdown"]
5662
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -73,25 +79,25 @@ deps = ["Base64"]
7379
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
7480

7581
[[Missings]]
76-
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
77-
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
82+
deps = ["SparseArrays", "Test"]
83+
git-tree-sha1 = "f0719736664b4358aa9ec173077d4285775f8007"
7884
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
79-
version = "0.4.0"
85+
version = "0.4.1"
8086

8187
[[Mmap]]
8288
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
8389

8490
[[OrderedCollections]]
8591
deps = ["Random", "Serialization", "Test"]
86-
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
92+
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
8793
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
88-
version = "1.0.2"
94+
version = "1.1.0"
8995

9096
[[PDMats]]
9197
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
92-
git-tree-sha1 = "b6c91fc0ab970c0563cbbe69af18d741a49ce551"
98+
git-tree-sha1 = "8b68513175b2dc4023a564cb0e917ce90e74fd69"
9399
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
94-
version = "0.9.6"
100+
version = "0.9.7"
95101

96102
[[Pkg]]
97103
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -155,10 +161,10 @@ deps = ["LinearAlgebra", "SparseArrays"]
155161
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
156162

157163
[[StatsBase]]
158-
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
159-
git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94"
164+
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
165+
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
160166
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
161-
version = "0.29.0"
167+
version = "0.30.0"
162168

163169
[[StatsFuns]]
164170
deps = ["Rmath", "SpecialFunctions", "Test"]

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
name = "EnsembleKalmanInversion"
1+
name = "Solus"
22
uuid = "16b9c28e-565c-11e9-13a9-53e2645fa95d"
33
authors = ["Simon Byrne <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
911
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# EnsembleKalmanInversion.jl
1+
# Solus.jl
22

33
[![Travis Build Status](https://travis-ci.org/climate-machine/EnsembleKalmanInversion.jl.svg?branch=master)](https://travis-ci.org/climate-machine/EnsembleKalmanInversion.jl)
44
[![Appveyor Build Status](https://ci.appveyor.com/api/projects/status/s7l958ngcd3efv9t/branch/master?svg=true)](https://ci.appveyor.com/project/simonbyrne/ensemblekalmaninversion-jl/branch/master)

src/EnsembleKalmanInversion.jl

Lines changed: 0 additions & 104 deletions
This file was deleted.

src/Solus.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
module Solus
2+
3+
using Distributions, Statistics, LinearAlgebra, DocStringExtensions
4+
5+
include("spaces.jl")
6+
7+
"""
8+
SolusProblem
9+
10+
An uncertainty quantification problem.
11+
12+
# Fields
13+
14+
$(DocStringExtensions.FIELDS)
15+
"""
16+
struct SolusProblem{P,M,O,S<:HilbertSpace}
17+
"""
18+
The prior distribution for the parameters ``θ``. Currently only `MvNormal` objects are supported.
19+
"""
20+
prior::P
21+
22+
"""
23+
A forward model maps the parameter ``θ`` to a predicted observation.
24+
"""
25+
forwardmodel::M
26+
27+
"""
28+
The observed data.
29+
"""
30+
obs::O
31+
32+
"""
33+
The space on which the forward model output and observation exist. `DefaultSpace()` is the default.
34+
"""
35+
space::S
36+
end
37+
SolusProblem(prior, forwardmodel, obs) = SolusProblem(prior, forwardmodel, obs, DefaultSpace())
38+
39+
struct FlatPrior
40+
end
41+
42+
Distributions.logpdf(::FlatPrior, x) = 0.0
43+
44+
function neglogposteriordensity(s::SolusProblem, θ)
45+
-logpdf(s.prior, θ) + norm(s.forwardmodel(θ) - s.obs, s.space)
46+
end
47+
48+
49+
"""
50+
Ensemble
51+
52+
An ensemble of `inputs` and their corresponding `outputs` from the forward model.
53+
54+
# Fields
55+
56+
$(DocStringExtensions.FIELDS)
57+
"""
58+
struct Ensemble{I,O}
59+
"""
60+
A matrix of inputs: each column corresponds to a vector of ``θ``s
61+
"""
62+
inputs::Vector{I}
63+
"""
64+
Result of forward model for each column of `inputs`.
65+
"""
66+
outputs::Vector{O}
67+
end
68+
69+
"""
70+
neki_iter(prob::SolusProblem, ens::Ensemble)
71+
72+
Peform an iteration of NEKI, returning a new `Ensemble` object.
73+
"""
74+
function neki_iter(prob::SolusProblem, ens::Ensemble)
75+
covθ = cov(ens.inputs)
76+
covθ += (tr(covθ)*1e-15)I
77+
78+
m = mean(ens.outputs)
79+
80+
CG = [dot(Gθk - m, Gθj - prob.obs, prob.space) for Gθj in ens.outputs, Gθk in ens.outputs]
81+
82+
Δt = 0.1 / norm(CG) # use better constants
83+
implicit = lu( I + Δt .* (covθ / cov(prob.prior)) ) # todo: incorporate means
84+
noise = MvNormal(covθ)
85+
86+
inputs = map(enumerate(ens.inputs)) do (j, θj)
87+
X = sum(enumerate(ens.inputs)) do (k, θk)
88+
CG[k,j]*θk
89+
end
90+
rhs = θj .- Δt .* X
91+
(implicit \ rhs) .+ sqrt(Δt)*rand(noise)
92+
end
93+
94+
outputs = map(prob.forwardmodel, inputs)
95+
Ensemble(inputs, outputs)
96+
end
97+
98+
99+
end # module

src/spaces.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import LinearAlgebra: norm, dot
2+
3+
"""
4+
HilbertSpace
5+
6+
A `Space <: HilbertSpace` defines methods for
7+
- `norm(x, ::Space)`
8+
- `dot(x, y, ::Space)`
9+
"""
10+
abstract type HilbertSpace
11+
end
12+
13+
"""
14+
DefaultSpace
15+
16+
Uses the default `norm` and `dot` implementations (e.g. `Array`s and `Array`s of `Array`s).
17+
"""
18+
struct DefaultSpace <: HilbertSpace
19+
end
20+
norm(x, ::DefaultSpace) = norm(x)
21+
dot(x,y, ::DefaultSpace) = dot(x,y)
22+
23+
24+
struct CovarianceSpace{M} <: HilbertSpace
25+
Γ::M
26+
end
27+
norm(x, c::CovarianceSpace) = x'*(c.Γ\x)
28+
dot(x,y, c::CovarianceSpace) = x'*(c.Γ\y)
29+
30+
# TODO: CovarianceSpace/PDSpace with a positive definite matrix

test/runtests.jl

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,2 @@
11
using Test
2-
using EnsembleKalmanInversion, Distributions
3-
4-
let
5-
iter_max = 5
6-
J = 100
7-
n = 5
8-
r = 0.1
9-
# r = 0.0
10-
11-
A = rand(Normal(0, 2), (n,n))
12-
u_t = rand(Normal(0, 3), n)
13-
g_t = A * u_t
14-
15-
cov = rand(Normal(0, 2), (n,n))
16-
cov = cov' * cov
17-
cov *= r^2
18-
19-
u_init = rand(Normal(0, 2), (n,J))
20-
g_ens = A * u_init
21-
22-
eki = EKI(g_t, cov, u_init)
23-
24-
iter = 0
25-
while iter < iter_max
26-
g_ens = A * eki.u[end]
27-
EnsembleKalmanInversion.update!(eki, g_ens)
28-
if EnsembleKalmanInversion.residual(eki) < 0.01
29-
break
30-
end
31-
iter += 1
32-
end
33-
34-
@test u_t vec(EnsembleKalmanInversion.get_u(eki)) atol=0.01
35-
end
2+
using Solus

0 commit comments

Comments
 (0)