Skip to content

Commit 22174a7

Browse files
committed
basic working model
0 parents  commit 22174a7

File tree

7 files changed

+391
-0
lines changed

7 files changed

+391
-0
lines changed

.appveyor.yml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
environment:
2+
matrix:
3+
- julia_version: 0.7
4+
- julia_version: 1
5+
- julia_version: nightly
6+
7+
platform:
8+
- x86 # 32-bit
9+
- x64 # 64-bit
10+
11+
# # Uncomment the following lines to allow failures on nightly julia
12+
# # (tests will run but not make your overall status red)
13+
# matrix:
14+
# allow_failures:
15+
# - julia_version: nightly
16+
17+
branches:
18+
only:
19+
- master
20+
- /release-.*/
21+
22+
notifications:
23+
- provider: Email
24+
on_build_success: false
25+
on_build_failure: false
26+
on_build_status_changed: false
27+
28+
install:
29+
- ps: iex ((new-object net.webclient).DownloadString("https://raw.githubusercontent.com/JuliaCI/Appveyor.jl/version-1/bin/install.ps1"))
30+
31+
build_script:
32+
- echo "%JL_BUILD_SCRIPT%"
33+
- C:\julia\bin\julia -e "%JL_BUILD_SCRIPT%"
34+
35+
test_script:
36+
- echo "%JL_TEST_SCRIPT%"
37+
- C:\julia\bin\julia -e "%JL_TEST_SCRIPT%"
38+
39+
# # Uncomment to support code coverage upload. Should only be enabled for packages
40+
# # which would have coverage gaps without running on Windows
41+
# on_success:
42+
# - echo "%JL_CODECOV_SCRIPT%"
43+
# - C:\julia\bin\julia -e "%JL_CODECOV_SCRIPT%"

.travis.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Documentation: http://docs.travis-ci.com/user/languages/julia/
2+
language: julia
3+
os:
4+
- linux
5+
- osx
6+
julia:
7+
- 1.0
8+
- 1.1
9+
- nightly
10+
notifications:
11+
email: false

Manifest.toml

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
[[Arpack]]
4+
deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Random", "SparseArrays", "Test"]
5+
git-tree-sha1 = "1ce1ce9984683f0b6a587d5bdbc688ecb480096f"
6+
uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
7+
version = "0.3.0"
8+
9+
[[Base64]]
10+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
11+
12+
[[BinDeps]]
13+
deps = ["Compat", "Libdl", "SHA", "URIParser"]
14+
git-tree-sha1 = "12093ca6cdd0ee547c39b1870e0c9c3f154d9ca9"
15+
uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
16+
version = "0.8.10"
17+
18+
[[BinaryProvider]]
19+
deps = ["Libdl", "Pkg", "SHA", "Test"]
20+
git-tree-sha1 = "055eb2690182ebc31087859c3dd8598371d3ef9e"
21+
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
22+
version = "0.5.3"
23+
24+
[[Compat]]
25+
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
26+
git-tree-sha1 = "84aa74986c5b9b898b0d1acaf3258741ee64754f"
27+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
28+
version = "2.1.0"
29+
30+
[[DataStructures]]
31+
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
32+
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
33+
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
34+
version = "0.15.0"
35+
36+
[[Dates]]
37+
deps = ["Printf"]
38+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
39+
40+
[[DelimitedFiles]]
41+
deps = ["Mmap"]
42+
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
43+
44+
[[Distributed]]
45+
deps = ["Random", "Serialization", "Sockets"]
46+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
47+
48+
[[Distributions]]
49+
deps = ["Distributed", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test"]
50+
git-tree-sha1 = "dec0ebacfbc3a2126c614ab5e903c9ef063688d0"
51+
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
52+
version = "0.17.0"
53+
54+
[[InteractiveUtils]]
55+
deps = ["Markdown"]
56+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
57+
58+
[[LibGit2]]
59+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
60+
61+
[[Libdl]]
62+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
63+
64+
[[LinearAlgebra]]
65+
deps = ["Libdl"]
66+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
67+
68+
[[Logging]]
69+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
70+
71+
[[Markdown]]
72+
deps = ["Base64"]
73+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
74+
75+
[[Missings]]
76+
deps = ["Dates", "InteractiveUtils", "SparseArrays", "Test"]
77+
git-tree-sha1 = "d1d2585677f2bd93a97cfeb8faa7a0de0f982042"
78+
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
79+
version = "0.4.0"
80+
81+
[[Mmap]]
82+
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
83+
84+
[[OrderedCollections]]
85+
deps = ["Random", "Serialization", "Test"]
86+
git-tree-sha1 = "85619a3f3e17bb4761fe1b1fd47f0e979f964d5b"
87+
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
88+
version = "1.0.2"
89+
90+
[[PDMats]]
91+
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
92+
git-tree-sha1 = "b6c91fc0ab970c0563cbbe69af18d741a49ce551"
93+
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
94+
version = "0.9.6"
95+
96+
[[Pkg]]
97+
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
98+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
99+
100+
[[Printf]]
101+
deps = ["Unicode"]
102+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
103+
104+
[[QuadGK]]
105+
deps = ["DataStructures", "LinearAlgebra", "Test"]
106+
git-tree-sha1 = "3ce467a8e76c6030d4c3786e7d3a73442017cdc0"
107+
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
108+
version = "2.0.3"
109+
110+
[[REPL]]
111+
deps = ["InteractiveUtils", "Markdown", "Sockets"]
112+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
113+
114+
[[Random]]
115+
deps = ["Serialization"]
116+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
117+
118+
[[Rmath]]
119+
deps = ["BinaryProvider", "Libdl", "Random", "Statistics", "Test"]
120+
git-tree-sha1 = "9a6c758cdf73036c3239b0afbea790def1dabff9"
121+
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
122+
version = "0.5.0"
123+
124+
[[SHA]]
125+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
126+
127+
[[Serialization]]
128+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
129+
130+
[[SharedArrays]]
131+
deps = ["Distributed", "Mmap", "Random", "Serialization"]
132+
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
133+
134+
[[Sockets]]
135+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
136+
137+
[[SortingAlgorithms]]
138+
deps = ["DataStructures", "Random", "Test"]
139+
git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
140+
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
141+
version = "0.3.1"
142+
143+
[[SparseArrays]]
144+
deps = ["LinearAlgebra", "Random"]
145+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
146+
147+
[[SpecialFunctions]]
148+
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
149+
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
150+
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
151+
version = "0.7.2"
152+
153+
[[Statistics]]
154+
deps = ["LinearAlgebra", "SparseArrays"]
155+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
156+
157+
[[StatsBase]]
158+
deps = ["DataStructures", "DelimitedFiles", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "Test"]
159+
git-tree-sha1 = "435707791dc85a67d98d671c1c3fcf1b20b00f94"
160+
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
161+
version = "0.29.0"
162+
163+
[[StatsFuns]]
164+
deps = ["Rmath", "SpecialFunctions", "Test"]
165+
git-tree-sha1 = "b3a4e86aa13c732b8a8c0ba0c3d3264f55e6bb3e"
166+
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
167+
version = "0.8.0"
168+
169+
[[SuiteSparse]]
170+
deps = ["Libdl", "LinearAlgebra", "SparseArrays"]
171+
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
172+
173+
[[Test]]
174+
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
175+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
176+
177+
[[URIParser]]
178+
deps = ["Test", "Unicode"]
179+
git-tree-sha1 = "6ddf8244220dfda2f17539fa8c9de20d6c575b69"
180+
uuid = "30578b45-9adc-5946-b283-645ec420af67"
181+
version = "0.4.0"
182+
183+
[[UUIDs]]
184+
deps = ["Random", "SHA"]
185+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
186+
187+
[[Unicode]]
188+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

Project.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
name = "EnsembleKalmanInversion"
2+
uuid = "16b9c28e-565c-11e9-13a9-53e2645fa95d"
3+
authors = ["Simon Byrne <[email protected]>"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# EnsembleKalmanInversion.jl

src/EnsembleKalmanInversion.jl

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
module EnsembleKalmanInversion
2+
3+
export EKI
4+
5+
using Distributions, LinearAlgebra
6+
7+
struct EKI
8+
"""
9+
truth[k] is the measurement of the kth output
10+
"""
11+
truth::Vector{Float64}
12+
13+
"""
14+
cov[k1,k2] is the covariance of the (k1,k2) outputs
15+
"""
16+
cov::Matrix{Float64}
17+
18+
"""
19+
u[i][k,j] is the kth parameter of the jth ensemble for the ith iteration
20+
"""
21+
u::Vector{Matrix{Float64}}
22+
23+
"""
24+
size of the ensemble
25+
"""
26+
J::Int
27+
28+
"""
29+
g[i][k,j] is the kth output of the jth ensemble for the ith iteration
30+
"""
31+
g::Vector{Matrix{Float64}}
32+
33+
"""
34+
err[i] is the error of the ith iteration
35+
"""
36+
err::Vector{Float64}
37+
end
38+
39+
function EKI(truth::Vector, cov::Matrix, u_init::Matrix)
40+
J = size(u_init,2)
41+
g = Matrix{Float64}[]
42+
err = Float64[]
43+
EKI(truth, cov, [u_init], J, g, err)
44+
end
45+
46+
function get_u(eki)
47+
return mean(eki.u[end], dims=2)
48+
end
49+
50+
function get_g(eki)
51+
return mean(eki.g[end], dims=2)
52+
end
53+
54+
55+
function residual(eki)
56+
g = get_g(eki)
57+
diff = g - eki.truth
58+
err = dot(diff, eki.cov \ diff )
59+
push!(eki.err, err)
60+
return err
61+
end
62+
63+
64+
function update!(eki::EKI, g)
65+
66+
u = eki.u[end]
67+
g_t = eki.truth
68+
cov = eki.cov
69+
J = eki.J
70+
71+
us = size(u,1) # parameters x particles
72+
ps = size(g,1) # data x particles
73+
74+
u_bar = zeros(us)
75+
p_bar = zeros(ps)
76+
c_up = zeros((us, ps))
77+
c_pp = zeros((ps, ps))
78+
79+
for j in 1:J
80+
u_hat = u[:,j]
81+
p_hat = g[:,j]
82+
83+
u_bar .+= u_hat
84+
p_bar .+= p_hat
85+
86+
c_up .+= [u_hat[a] * p_hat[b] for a=1:us, b=1:ps]
87+
c_pp .+= [p_hat[a] * p_hat[b] for a=1:ps, b=1:ps]
88+
end
89+
90+
u_bar = u_bar / J
91+
p_bar = p_bar / J
92+
c_up = c_up / J - [u_bar[a] * p_bar[b] for a=1:us, b=1:ps]
93+
c_pp = c_pp / J - [p_bar[a] * p_bar[b] for a=1:ps, b=1:ps]
94+
95+
noise = rand(MvNormal(zeros(ps), cov), J)
96+
y = g_t .+ noise
97+
tmp = (c_pp + cov) \ (y - g)
98+
u += c_up * tmp
99+
100+
push!(eki.u, u)
101+
push!(eki.g, g)
102+
end
103+
104+
end # module

test/runtests.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using Test
2+
using EnsembleKalmanInversion, Distributions
3+
4+
let
5+
iter_max = 5
6+
J = 100
7+
n = 5
8+
r = 0.1
9+
# r = 0.0
10+
11+
A = rand(Normal(0, 2), (n,n))
12+
u_t = rand(Normal(0, 3), n)
13+
g_t = A * u_t
14+
15+
cov = rand(Normal(0, 2), (n,n))
16+
cov = cov' * cov
17+
cov *= r^2
18+
19+
u_init = rand(Normal(0, 2), (n,J))
20+
g_ens = A * u_init
21+
22+
eki = EKI(g_t, cov, u_init)
23+
24+
iter = 0
25+
while iter < iter_max
26+
g_ens = A * eki.u[end]
27+
EnsembleKalmanInversion.update!(eki, g_ens)
28+
if EnsembleKalmanInversion.residual(eki) < 0.01
29+
break
30+
end
31+
iter += 1
32+
end
33+
34+
@test u_t vec(EnsembleKalmanInversion.get_u(eki)) atol=0.01
35+
end

0 commit comments

Comments
 (0)