Skip to content

Commit 2690a4a

Browse files
committed
add require_convex option to grid_and_sample
1 parent c94ed3d commit 2690a4a

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

src/numerical_algorithms.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,32 @@ function gmres(A, b; Pl=I, maxiter)
205205
view(K, :, 1:n) * α
206206

207207
end
208+
209+
"""
210+
finite_second_derivative(x)
211+
212+
Second derivative of a vector `x` via finite differences, including at end points.
213+
"""
214+
function finite_second_derivative(x)
215+
map(eachindex(x)) do i
216+
if i==1
217+
x[3]-2x[2]+x[1]
218+
elseif i==length(x)
219+
x[end]-2x[end-1]+x[end-2]
220+
else
221+
x[i+1]-2x[i]+x[i-1]
222+
end
223+
end
224+
end
225+
226+
"""
227+
longest_run_of_trues(x)
228+
229+
The slice corresponding to the longest run of `true`s in the vector `x`.
230+
"""
231+
function longest_run_of_trues(x)
232+
next_true = findnext.(Ref(.!x), eachindex(x))
233+
next_true[isnothing.(next_true)] .= 0
234+
(len,start) = findmax(next_true .- eachindex(x))
235+
start:start+len
236+
end

src/sampling.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,18 @@ function grid_and_sample(lnPs::Vector{<:BatchedReal}, xs::AbstractVector; kwargs
8585
((batch(getindex.(batches,i)) for i=1:3)...,)
8686
end
8787

88-
function grid_and_sample(lnPs::Vector, xs::AbstractVector; progress=false, nsamples=1, span=0.25)
88+
function grid_and_sample(lnPs::Vector, xs::AbstractVector; progress=false, nsamples=1, span=0.25, require_convex=false)
8989

9090
# trim leading/trailing zero-probability regions
9191
support = findnext(isfinite,lnPs,1):findprev(isfinite,lnPs,length(lnPs))
9292
xs = xs[support]
9393
lnPs = lnPs[support]
94+
95+
if require_convex
96+
support = longest_run_of_trues(finite_second_derivative(lnPs) .< 0)
97+
xs = xs[support]
98+
lnPs = lnPs[support]
99+
end
94100

95101
# interpolate PDF
96102
xmin, xmax = first(xs), last(xs)
@@ -222,8 +228,10 @@ function sample_joint(
222228

223229
θstarts = if θstart == :prior
224230
[map(range->batch((first(range) .+ rand(D) .* (last(range) - first(range)))...), θrange) for i=1:nchains]
225-
elseif (θstart isa NamedTuple)
231+
elseif θstart isa NamedTuple
226232
fill(θstart, nchains)
233+
elseif θstart isa Vector{<:NamedTuple}
234+
θstart
227235
else
228236
error("`θstart` should be either `nothing` to randomly sample the starting value or a NamedTuple giving the starting point.")
229237
end
@@ -236,6 +244,8 @@ function sample_joint(
236244
fill(batch(zero(diag(ds().Cϕ)), D), nchains)
237245
elseif ϕstart isa Field
238246
fill(ϕstart, nchains)
247+
elseif ϕstart isa Vector{<:Field}
248+
ϕstart
239249
elseif ϕstart in [:quasi_sample, :best_fit]
240250
pmap(θstarts) do θstart
241251
MAP_joint(adapt(storage,ds(;θstart...)), progress=(progress==:verbose ? :summary : false), Nϕ=adapt(storage,Nϕ), quasi_sample=(ϕstart==:quasi_sample); MAP_kwargs...).ϕ
@@ -277,9 +287,11 @@ function sample_joint(
277287

278288
for chunks_index = (chunks_index+1):(nsamps_per_chain÷nchunk+1)
279289

290+
println("starting")
280291
last_chunks = pmap(last.(last_chunks)) do state
281292

282293
@unpack i,ϕ°,f,θ = state
294+
@show i
283295
f,ϕ°,ds,Nϕ = (adapt(storage, x) for x in (f,ϕ°,dsₐ,Nϕₐ))
284296
dsθ = ds(θ)
285297
ϕ = dsθ.G\ϕ°

0 commit comments

Comments
 (0)