Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions exp/compare_resampling_methods.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using BinomialSynapses, Plots
using Statistics

function f1(sim, time)
return abs2(mean(sim.fstate.model.N) - only(sim.hmodel.N)),
abs2(mean(sim.fstate.model.p) - only(sim.hmodel.p)),
abs2(mean(sim.fstate.model.q) - only(sim.hmodel.q)),
abs2(mean(sim.fstate.model.τ) - only(sim.hmodel.τ)),
abs2(mean(sim.fstate.model.σ) - only(sim.hmodel.σ))
end

f2(data) = nothing

function make_sim(rm)
return NestedFilterSimulation(
10, 0.85, 1.0, 0.2, 0.2, # ground truth parameters
1:20, # parameter ranges for filter
LinRange(0.05, 0.95, 100), # .
LinRange(0.10, 2.00, 100), # .
LinRange(0.05, 2.00, 100), # .
LinRange(0.05, 2.00, 100), # .
2048, 512, # outer and inner number of particles
12; # jittering kernel width
resampling_method = rm
)
end

function run_and_record!(rec, rm)
sim = make_sim(rm)
run!(sim; T = 1000, recording = rec)
return rec
end

function run_and_record(rm)
sim = make_sim(rm)
rec = Recording(f1, f2, sim)
run!(sim; T = 1000, recording = rec)
return rec
end

rec1 = run_and_record(Multinomial())
rec2 = run_and_record(Stratified())

for i in 1:99
run_and_record!(rec1, Multinomial())
run_and_record!(rec2, Stratified())
end

pop!(rec1.data)
pop!(rec2.data)

dat1 = reduce(hcat, mean(collect, reshape(rec1.data, 1000, 100); dims=2)[:, 1])
dat2 = reduce(hcat, mean(collect, reshape(rec2.data, 1000, 100); dims=2)[:, 1])

names = ["N", "p", "q", "τ", "σ"]
for (i, name) in enumerate(names)
plt = plot(; ylabel = "E(Estimated $name - True $name)²", xlabel = "Time")
plot!(plt, dat1[i, :]; label = "Multinomial resampling")
plot!(plt, dat2[i, :]; label = "Stratified resampling")
savefig(plt, "exp/resampling_$name.svg")
end
48 changes: 48 additions & 0 deletions exp/resampling_N.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
46 changes: 46 additions & 0 deletions exp/resampling_p.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
46 changes: 46 additions & 0 deletions exp/resampling_q.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
46 changes: 46 additions & 0 deletions exp/resampling_σ.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
46 changes: 46 additions & 0 deletions exp/resampling_τ.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 4 additions & 3 deletions src/BinomialSynapses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ export Timestep, FixedTimestep, RandomTimestep, DeterministicTrain
include("emission.jl")
export emit

include("resample.jl")
export outer_resample!, indices!, resample!
export Multinomial, Stratified

include("likelihood.jl")
export likelihood, likelihood_resample!

include("jitter.jl")
export jitter!

include("resample.jl")
export outer_resample!, indices!, resample!

include("filter.jl")
export NestedParticleFilter, NestedParticleState, update!

Expand Down
9 changes: 6 additions & 3 deletions src/filter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@

Construct a nested particle filter with a given jittering kernel width parameter.
"""
struct NestedParticleFilter
struct NestedParticleFilter{T <: ResamplingMethod}
jittering_width::Int
resampling_method::T
end

NestedParticleFilter(jw::Int) = NestedParticleFilter(jw, Multinomial())

"""
NestedParticleState(state, model)

Expand Down Expand Up @@ -54,7 +57,7 @@ function update!(

jitter!(model, filter.jittering_width)
propagate!(state, model, observation.dt)
u = likelihood_resample!(state, model, observation)
outer_resample!(state, model, u)
u = likelihood_resample!(state, model, observation, filter.resampling_method)
outer_resample!(state, model, u, filter.resampling_method)
return filterstate
end
14 changes: 10 additions & 4 deletions src/likelihood.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ gauss(x, μ, σ) = exp(-((x-μ)/σ)^2/2)
function likelihood_indices(
k,
model::AbstractBinomialModel,
obs
obs,
rm::ResamplingMethod
)
v = gauss.(obs, model.q .* k, model.σ)
u, idx = indices!(v)
u, idx = indices!(v, rm)

# normalization of u
T = eltype(model.σ)
Expand All @@ -37,8 +38,13 @@ end

Return the likelihood of an observation conditioned on the current state and model ensemble and at the same time resample the state ensemble (inner particles).
"""
function likelihood_resample!(state::BinomialState, model, obs::BinomialObservation)
u, idx = likelihood_indices(state.k, model, obs.EPSP)
function likelihood_resample!(
state::BinomialState,
model,
obs::BinomialObservation,
rm::ResamplingMethod
)
u, idx = likelihood_indices(state.k, model, obs.EPSP, rm)
resample!(state, idx)
return u
end
188 changes: 179 additions & 9 deletions src/resample.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
function outer_indices!(u::AbstractVector)
abstract type ResamplingMethod end
struct Multinomial <: ResamplingMethod end
struct Stratified <: ResamplingMethod end

function outer_indices!(u::AbstractVector, rm::ResamplingMethod)
uu = Array(u)
usum, idx = outer_indices!(uu)
usum, idx = outer_indices!(uu, rm)
return usum, cu(idx)
end

function outer_indices!(u::Vector)
function outer_indices!(u::Vector, ::Multinomial)
M_out = length(u)
usum = zero(eltype(u))

Expand Down Expand Up @@ -41,13 +45,44 @@ function outer_indices!(u::Vector)
return usum, idx
end


function outer_indices!(u::Vector, ::Stratified)
M_out = length(u)
usum = zero(eltype(u))

# compute cumulative sum, overwriting u
@inbounds for i in 1:M_out
usum += u[i]
u[i] = usum
end

# shift cumulative sums to the right by one
@inbounds for i in M_out:-1:2
u[i] = u[i-1]
end
u[1] = 0

idx = zeros(Int, M_out)
bindex = M_out # bin index
@inbounds for i in M_out:-1:1
rsample = (i - 1 + rand(eltype(u))) * usum / M_out
# checking bindex >= 1 is redundant since
# ucum[1] = 0
while rsample < u[bindex]
bindex -= 1
end
idx[i] = bindex
end
return usum, idx
end

"""
outer_resample!(state, model, u)

Resample the outer particles of the `state` and `model` ensemble based on their likelihoods `u`.
"""
function outer_resample!(state, model, u)
usum, idx = indices!(u)
function outer_resample!(state, model, u, resampling_method)
_, idx = indices!(u, resampling_method)
resample!(state, idx)
resample!(model, idx)
return state, model
Expand All @@ -60,9 +95,11 @@ Return index table and total likelihoods from likelihood table `v`.
This function modifies v; after execution, v will be the cumulative sum of the original v
along the last dimension.
"""
indices!(v::AnyCuVector) = outer_indices!(v)
indices!(v::AnyCuVector, rm::ResamplingMethod) = outer_indices!(v, rm)
indices!(v::AnyCuVector, rm::Multinomial) = outer_indices!(v, rm)
indices!(v::AnyCuVector, rm::Stratified) = outer_indices!(v, rm)

function indices!(v::AnyCuArray)
function indices!(v::AnyCuArray, ::Multinomial)
function kernel!(
u, v, idx, r,
Rout, M_out, M_in
Expand Down Expand Up @@ -155,10 +192,97 @@ function indices!(v::AnyCuArray)
return u, idx
end


function indices!(v::AnyCuArray, ::Stratified)
function kernel!(
u, v, idx, r,
Rout, M_out, M_in
)
# grid-stride loop
tid = threadIdx().x
window = (blockDim().x - 1i32) * gridDim().x
offset = (blockIdx().x - 1i32) * blockDim().x
while offset < M_out
id = tid + offset
vsum = 0f0
for j in 1:M_in
if id <= M_out
@inbounds i = Rout[id]
if u[i] < 0 # prevents visiting same `i' more than once
@inbounds vsum = v[i, j] += vsum
@inbounds r[i, j] = (j - 1 + rand(Float32)) / M_in
end
end
end
if id <= M_out
@inbounds i = Rout[id]

if u[i] < 0 # prevents visiting same `i' more than once
# compute average likelihood across inner particles
# (with normalization constant that was omitted from v for speed)
@inbounds u[i] = vsum

# O(n) binning algorithm for sorted samples
bindex = 1 # bin index
for j in 1:M_in
# scale random numbers (this is equivalent to normalizing v)
@inbounds rsample = r[i, j] * vsum
# checking bindex <= M_in - 1 not necessary since
# v[i, M_in] = vsum
@inbounds while rsample > v[i, bindex]
bindex += 1
end
@inbounds idx[i, j] = bindex
end
end
end

offset += window
end
return nothing
end

# initializations:

# indices
idx = CuArray{Int}(undef, size(v)...)

# outer likelihoods
# Initialize to -1 in order to track which elements have been written to.
# Since likelihoods are nonnegative, negative elements have never been visited.
u = CUDA.fill(-one(Float32), size(v)[1:end-1]...)

# random numbers
r = CuArray{Float32}(undef, size(v)...)

Rout = CartesianIndices(u) # indices for first n-1 dimensions
M_out = length(u)
M_in = last(size(v))

kernel = @cuda launch=false kernel!(
u, v, idx, r,
Rout, M_out, M_in
)
config = launch_configuration(kernel.fun)
threads = max(32, min(config.threads, M_out))
blocks = cld(M_out, threads)
kernel(
u, v,
idx,
r,
Rout, M_out, M_in
;
threads=threads, blocks=blocks
)
return u, idx
end

# CPU fallbacks:
indices!(v::AbstractVector) = outer_indices!(v)
indices!(v::AbstractVector, rm::ResamplingMethod) = outer_indices!(v, rm)
indices!(v::AbstractVector, rm::Multinomial) = outer_indices!(v, rm)
indices!(v::AbstractVector, rm::Stratified) = outer_indices!(v, rm)

function indices!(v::AbstractArray{T}) where T
function indices!(v::AbstractArray{T}, ::Multinomial) where T
# initializations:

# indices
Expand Down Expand Up @@ -212,6 +336,52 @@ function indices!(v::AbstractArray{T}) where T
return u, idx
end

function indices!(v::AbstractArray{T}, ::Stratified) where T
# initializations:

# indices
idx = Array{Int}(undef, size(v)...)

# outer likelihoods
# Initialize to -1 in order to track which elements have been written to.
# Since likelihoods are nonnegative, negative elements have never been visited.
u = zeros(T, size(v)[1:end-1]...)

# random numbers
r = Array{T}(undef, size(v)...)

Rout = CartesianIndices(u) # indices for first n-1 dimensions
M_out = length(u)
M_in = last(size(v))

for id in 1:M_out
vsum = zero(T)
@inbounds i = Rout[id]
for j in 1:M_in
@inbounds v[i, j] += vsum
@inbounds vsum = v[i, j]
@inbounds r[i, j] = (j - 1 + rand(Float32)) / M_in
end
# compute average likelihood across inner particles
# (with normalization constant that was omitted from v for speed)
@inbounds u[i] = vsum

# O(n) binning algorithm for sorted samples
bindex = 1 # bin index
for j in 1:M_in
# scale random numbers (this is equivalent to normalizing v)
@inbounds rsample = r[i, j] * vsum
# checking bindex <= M_in - 1 not necessary since
# v[i, M_in] = vsum
@inbounds while rsample > v[i, bindex]
bindex += 1
end
@inbounds idx[i, j] = bindex
end
end
return u, idx
end

function resample!(in::AnyCuArray, out::AnyCuArray, idx::AnyCuArray)
function kernel(in, out, idx, Ra, R1, R2, R3)
i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
Expand Down
5 changes: 3 additions & 2 deletions src/simulate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ function NestedFilterSimulation(
Nrng, prng, qrng, σrng, τrng,
m_out, m_in, width;
timestep::Timestep = RandomTimestep(Exponential(0.121)),
device::Symbol = :gpu
device::Symbol = :gpu,
resampling_method::ResamplingMethod = Multinomial()
)
hmodel = ScalarBinomialModel(N, p, q, σ, τ)
filter = NestedParticleFilter(width)
filter = NestedParticleFilter(width, resampling_method)
hstate = ScalarBinomialState(N, 0)
fstate = NestedParticleState(
m_out, m_in,
Expand Down
Loading