
Commit 409216b

Merge branch 'ultradeep_Alens'

2 parents f4c5b4f + 2690a4a

File tree

6 files changed: +105 −65 lines

src/CMBLensing.jl
src/chains.jl
src/flat_batch.jl
src/gpu.jl
src/numerical_algorithms.jl
src/sampling.jl

src/CMBLensing.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -45,7 +45,7 @@ using Zygote: unbroadcast, Numeric, @adjoint, @nograd


 import Adapt: adapt_structure
-import Base: +, -, *, \, /, ^, ~, , <, <=, |, &, ==,
+import Base: +, -, *, \, /, ^, ~, , <, <=, |, &, ==, !,
     abs, adjoint, all, any, axes, broadcast, broadcastable, BroadcastStyle, conj, copy, convert,
     copy, copyto!, eltype, eps, fill!, getindex, getproperty, hash, hcat, hvcat, inv, isfinite,
     iterate, keys, lastindex, length, literal_pow, mapreduce, materialize!,
```

src/chains.jl

Lines changed: 35 additions & 30 deletions

```diff
@@ -2,47 +2,52 @@ import Base: getindex, lastindex


 @doc doc"""
-    load_chains(filename; burnin=0, thin=1, join=false)
-
-Load a single chain or multiple parallel chains which were written to a file by
-[`sample_joint`](@ref).
+    load_chains(filename; burnin=0, burnin_chunks=0, thin=1, join=false, unbatch=true)
+
+Load a single chain or multiple parallel chains which were written to
+a file by [`sample_joint`](@ref).

 Keyword arguments:

-* `burnin` — Remove this many samples from the start of each chain.
+* `burnin` — Remove this many samples from the start of each chain, or
+  if negative, keep only this many samples at the end of each chain.
+* `burnin_chunks` — Same as `burnin`, but in terms of chain "chunks"
+  stored in the chain file, rather than in terms of samples.
 * `thin` — If `thin` is an integer, thin the chain by this factor. If
-  `thin == :hasmaps`, return only samples which have maps saved. If thin is a
-  `Function`, filter the chain by this function (e.g. `thin=haskey(:g)` on Julia 1.5+)
-* `unbatch` — If true, [unbatch](@ref) the chains if they are batched.
+  `thin == :hasmaps`, return only samples which have maps saved. If
+  `thin` is a `Function`, filter the chain by this function (e.g.
+  `thin=haskey(:g)` on Julia 1.5+)
+* `unbatch` — If true, [unbatch](@ref) the chains if they are batched.
 * `join` — If true, concatenate all the chains together.
 * `skip_missing_chunks` — Skip missing chunks in the chain instead of
   terminating the chain there.

-The object returned by this function is a `Chain` or `Chains` object, which
-simply wraps an `Array` of `Dicts` or an `Array` of `Array` of `Dicts`,
-respectively (each sample is a `Dict`). The wrapper object has some extra
-indexing properties for convenience:
+The object returned by this function is a `Chain` or `Chains` object,
+which simply wraps an `Array` of `Dicts` or an `Array` of `Array` of
+`Dicts`, respectively (each sample is a `Dict`). The wrapper object
+has some extra indexing properties for convenience:

-* It can be indexed as if it were a single multidimensional object, e.g.
-  `chains[1,:,:accept]` would return the `:accept` key of all samples in the
-  first chain.
-* Leading colons can be dropped, i.e. `chains[:,:,:accept]` is the same as
-  `chains[:accept]`.
-* If some samples are missing a particular key, `missing` is returned for those
-  samples instead of an error.
-* The recursion goes arbitrarily deep into the objects it finds. E.g., since
-  sampled parameters are stored in a `NamedTuple` like `(Aϕ=1.3,)` in the `θ`
-  key of each sample `Dict`, you can do `chain[:θ,:Aϕ]` to get all `Aϕ` samples
-  as a vector.
+* It can be indexed as if it were a single multidimensional object,
+  e.g. `chains[1,:,:accept]` would return the `:accept` key of all
+  samples in the first chain.
+* Leading colons can be dropped, i.e. `chains[:,:,:accept]` is the
+  same as `chains[:accept]`.
+* If some samples are missing a particular key, `missing` is returned
+  for those samples instead of an error.
+* The recursion goes arbitrarily deep into the objects it finds. E.g.,
+  since sampled parameters are stored in a `NamedTuple` like
+  `(Aϕ=1.3,)` in the `θ` key of each sample `Dict`, you can do
+  `chain[:θ,:Aϕ]` to get all `Aϕ` samples as a vector.


 """
-function load_chains(filename; burnin=0, thin=1, join=false, unbatch=true, dropmaps=false)
+function load_chains(filename; burnin=0, thin=1, join=false, unbatch=true, dropmaps=false, burnin_chunks=0)
     chains = jldopen(filename) do io
         ks = keys(io)
-        chunk_ks = [k for k in ks if startswith(k,"chunks_")]
-        for (isfirst,k) in flagfirst(sort(chunk_ks, by=k->parse(Int,k[8:end])))
+        chunk_ks = sort([k for k in ks if startswith(k,"chunks_")], by=k->parse(Int,k[8:end]))
+        chunk_ks = chunk_ks[burnin_chunks>=0 ? (burnin_chunks+1:end) : (end+burnin_chunks+1:end)]
+        for (isfirst,k) in flagfirst(chunk_ks)
             if isfirst
                 chains = read(io,k)
             else
@@ -55,13 +60,13 @@ function load_chains(filename; burnin=0, thin=1, join=false, unbatch=true, dropm
         chains
     end
     if thin isa Int
-        chains = [chain[(1+burnin):thin:end] for chain in chains]
+        chains = [chain[burnin>=0 ? ((1+burnin):thin:end) : (end+(1+burnin):thin:end)] for chain in chains]
     elseif thin == :hasmaps
         chains = [[samp for samp in chain[(1+burnin):end] if :ϕ in keys(samp)] for chain in chains]
     elseif thin isa Function
        chains = [filter(thin,chain) for chain in chains]
    else
-        error("`thin` should be an Int or :hasmaps")
+        error("`thin` should be an Int, :hasmaps, or a filter function")
    end
    chains = wrap_chains(chains)
    if unbatch
@@ -121,8 +126,8 @@ _getindex(x::Union{Dict,NamedTuple}, k::Symbol) = haskey(x,k) ? getindex(x, k) :
 _getindex(x, k) = getindex(x, k)


-wrap_chains(chains::Vector{<:Vector{<:Dict}}) = Chains(Chain.(chains))
-wrap_chains(chain::Vector{<:Dict}) = Chain(chain)
+wrap_chains(chains::Vector{<:Vector}) = Chains(Chain.(chains))
+wrap_chains(chain::Vector) = Chain(chain)


 # batching
```
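The new keywords compose with the indexing conventions described in the docstring. A minimal usage sketch (the filename and sampled-parameter names here are hypothetical):

```julia
using CMBLensing

# drop the first 100 samples of each chain, then keep every 5th sample
chains = load_chains("chains.jld2", burnin=100, thin=5)

# new: negative burnin keeps only the last 500 samples of each chain
chains = load_chains("chains.jld2", burnin=-500)

# new: skip the first 2 saved chunks entirely, before any per-sample thinning
chains = load_chains("chains.jld2", burnin_chunks=2)

# recursive indexing: all Aϕ samples in the first chain, as a vector
Aϕs = chains[1, :, :θ, :Aϕ]
```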

src/flat_batch.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -102,7 +102,7 @@ for op in [:+, :-, :*, :/, :<, :<=, :&, :|, :(==)]
         ($op)(a::Real, b::BatchedReal) = batch(broadcast(($op), a, b.vals))
     end
 end
-for op in [:-, :sqrt, :one, :zero, :isfinite, :eps]
+for op in [:-, :!, :sqrt, :one, :zero, :isfinite, :eps]
     @eval ($op)(br::BatchedReal) = batch(broadcast(($op),br.vals))
 end
 for op in [:any, :all]
```
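This is also why `!` was added to the `import Base` list in src/CMBLensing.jl above: a method can only be added to `Base.!` after it is imported. A self-contained sketch of the metaprogramming pattern, using a stand-in `MyBatched` type rather than the package's actual `BatchedReal`:

```julia
# stand-in for the package's BatchedReal: one value per batch index
struct MyBatched{T}
    vals::Vector{T}
end
batch(vals::Vector) = MyBatched(vals)

import Base: !, -, sqrt, isfinite

# same pattern as the diff: each unary op broadcasts over the batch
# and re-wraps the result
for op in [:-, :!, :sqrt, :isfinite]
    @eval ($op)(br::MyBatched) = batch(broadcast(($op), br.vals))
end

(!batch([true, false])).vals   # [false, true]
(-batch([1.0, 2.0])).vals      # [-1.0, -2.0]
```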

src/gpu.jl

Lines changed: 10 additions & 25 deletions

```diff
@@ -128,34 +128,19 @@ gc = () -> (GC.gc(true); CUDA.reclaim())

 Assuming you submitted a SLURM job and got several GPUs, possibly across several
 nodes, this assigns each Julia worker process a unique GPU using `CUDA.device!`.
-Assumes the SLURM variables `SLURM_STEP_GPUS` and `GPU_DEVICE_ORDINAL` are
-defined on the workers.
 """
 function assign_GPU_workers()
-    @everywhere @eval using CUDA, Distributed
-    topo = @eval Main pmap(workers()) do _
-        hostname = gethostname()
-        virtgpus = parse.(Int,split(ENV["GPU_DEVICE_ORDINAL"],","))
-        if "SLURM_STEP_GPUS" in keys(ENV)
-            physgpus = parse.(Int,split(ENV["SLURM_STEP_GPUS"],","))
-        else
-            @warn "SLURM_STEP_GPUS not defined, assign_GPU_workers may fail."
-            # SLURM_STEP_GPUS seems not correctly set on all systems. this
-            # will work if you requested a full node's worth of GPUs at least
-            physgpus = virtgpus
-        end
-        if Set(virtgpus)!=Set(deviceid.(devices()))
-            @warn "Virtual GPUs not same as CUDA.devices(), using latter"
-            virtgpus = deviceid.(devices())
-        end
-        (i=myid(), hostname=hostname, virtgpus=virtgpus, physgpus=physgpus)
-    end
+    @everywhere @eval Main using CUDA, Distributed
+    accessible_gpus = @eval Main Dict(pmap(workers()) do _
+        ds = CUDA.devices()
+        myid() => Dict(CUDA.deviceid.(ds) .=> CUDA.uuid.(ds))
+    end)
     claimed = Set()
-    assignments = Dict(map(topo) do (i,hostname,virtgpus,physgpus)
-        for (virtgpu,physgpu) in zip(virtgpus,physgpus)
-            if !((hostname,physgpu) in claimed)
-                push!(claimed,(hostname,physgpu))
-                return i => virtgpu
+    assignments = Dict(map(workers()) do myid
+        for (gpu_id, gpu_uuid) in accessible_gpus[myid]
+            if !(gpu_uuid in claimed)
+                push!(claimed, gpu_uuid)
+                return myid => gpu_id
             end
         end
     end)
```
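The rewrite drops the SLURM environment variables entirely: each worker reports the devices it can see as `deviceid => uuid` pairs, and devices are then claimed greedily by UUID, which identifies a physical GPU uniquely even when two workers on the same node see it under the same local device id. A self-contained sketch of just that claiming logic, with mocked-up data in place of real CUDA queries:

```julia
# two workers on one node, both seeing the same two physical GPUs
# (worker id => Dict of local device id => device UUID)
accessible_gpus = Dict(
    2 => Dict(0 => "uuid-a", 1 => "uuid-b"),
    3 => Dict(0 => "uuid-a", 1 => "uuid-b"),
)

claimed = Set{String}()
assignments = Dict(map(sort(collect(keys(accessible_gpus)))) do id
    # give this worker the first device whose UUID is still unclaimed
    for (gpu_id, gpu_uuid) in accessible_gpus[id]
        if !(gpu_uuid in claimed)
            push!(claimed, gpu_uuid)
            return id => gpu_id
        end
    end
end)
# each worker ends up on a distinct physical GPU, despite both seeing
# the same local device ids
```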

src/numerical_algorithms.jl

Lines changed: 29 additions & 0 deletions

```diff
@@ -205,3 +205,32 @@ function gmres(A, b; Pl=I, maxiter)
     view(K, :, 1:n) * α

 end
+
+"""
+    finite_second_derivative(x)
+
+Second derivative of a vector `x` via finite differences, including at end points.
+"""
+function finite_second_derivative(x)
+    map(eachindex(x)) do i
+        if i==1
+            x[3]-2x[2]+x[1]
+        elseif i==length(x)
+            x[end]-2x[end-1]+x[end-2]
+        else
+            x[i+1]-2x[i]+x[i-1]
+        end
+    end
+end
+
+"""
+    longest_run_of_trues(x)
+
+The slice corresponding to the longest run of `true`s in the vector `x`.
+"""
+function longest_run_of_trues(x)
+    next_true = findnext.(Ref(.!x), eachindex(x))
+    next_true[isnothing.(next_true)] .= 0
+    (len,start) = findmax(next_true .- eachindex(x))
+    start:start+len
+end
```
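These two helpers feed the new `require_convex` option of `grid_and_sample` below: `longest_run_of_trues(finite_second_derivative(lnPs) .< 0)` picks out the longest stretch where the gridded log-PDF is concave. A quick worked check of `finite_second_derivative` on a discrete parabola, where the second derivative should be constant (assuming the definition above is loaded):

```julia
x = [0.0, 1.0, 4.0, 9.0, 16.0]   # x[i] = (i-1)^2
finite_second_derivative(x)
# -> [2.0, 2.0, 2.0, 2.0, 2.0]
# interior points use the centered stencil x[i+1] - 2x[i] + x[i-1];
# the end points reuse the one-sided stencils, so the output has the
# same length as the input
```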

src/sampling.jl

Lines changed: 29 additions & 8 deletions

```diff
@@ -85,23 +85,38 @@ function grid_and_sample(lnPs::Vector{<:BatchedReal}, xs::AbstractVector; kwargs
     ((batch(getindex.(batches,i)) for i=1:3)...,)
 end

-function grid_and_sample(lnPs::Vector, xs::AbstractVector; progress=false, nsamples=1, span=0.25, rtol=1e-5)
+function grid_and_sample(lnPs::Vector, xs::AbstractVector; progress=false, nsamples=1, span=0.25, require_convex=false)

+    # trim leading/trailing zero-probability regions
+    support = findnext(isfinite,lnPs,1):findprev(isfinite,lnPs,length(lnPs))
+    xs = xs[support]
+    lnPs = lnPs[support]
+
+    if require_convex
+        support = longest_run_of_trues(finite_second_derivative(lnPs) .< 0)
+        xs = xs[support]
+        lnPs = lnPs[support]
+    end
+
+    # interpolate PDF
     xmin, xmax = first(xs), last(xs)
     lnPs = lnPs .- maximum(lnPs)
     ilnP = loess(xs, lnPs, span=span)

     # normalize the PDF. note the smoothing is done of the log PDF.
-    A = @ondemand(QuadGK.quadgk)(exp∘ilnP, xmin, xmax)[1]
-    lnPs .-= log(A)
-    ilnP = loess(xs, lnPs, span=span)
+    cdf(x) = @ondemand(QuadGK.quadgk)(nan2zero∘exp∘ilnP,xmin,x,rtol=1e-5)[1]
+    logA = nan2zero(log(cdf(xmax)))
+    lnPs = (ilnP.ys .-= logA)
+    ilnP.bs[:,1] .-= logA

     # draw samples via inverse transform sampling
-    # (the `+ eps()` is a workaround since Loess.predict seems to NaN sometimes when
-    # evaluated right at the lower bound)
     θsamples = @showprogress (progress ? 1 : Inf) map(1:nsamples) do i
         r = rand()
-        fzero((x->@ondemand(QuadGK.quadgk)(exp∘ilnP,xmin+sqrt(eps()),x,rtol=rtol)[1]-r),xmin+sqrt(eps()),xmax,rtol=rtol)
+        if (cdf(xmin)-r)*(cdf(xmax)-r) >= 0
+            first(lnPs) > last(lnPs) ? xmin : xmax
+        else
+            fzero(x->cdf(x)-r, xmin, xmax, xatol=(xmax-xmin)*1e-3)
+        end
     end

     (nsamples==1 ? θsamples[1] : θsamples), ilnP, lnPs
```
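The new branch is ordinary inverse transform sampling, now with a guard for when the uniform draw `r` falls outside the CDF's range on the trimmed support, in which case the nearer endpoint is returned instead of erroring inside the root finder. A self-contained sketch of the idea on a standard normal, using `QuadGK` and `Roots` directly rather than the package's `@ondemand`/`fzero` wrappers:

```julia
using QuadGK: quadgk
using Roots: find_zero

# inverse transform sampling on a truncated support [xmin, xmax]:
# draw r ~ U(0,1) and solve cdf(x) == r for x
lnP(x) = -x^2/2                      # unnormalized log-PDF
xmin, xmax = -4.0, 4.0
A = quadgk(x -> exp(lnP(x)), xmin, xmax)[1]          # normalization
cdf(x) = quadgk(t -> exp(lnP(t))/A, xmin, x)[1]

r = rand()
sample = if (cdf(xmin) - r) * (cdf(xmax) - r) >= 0
    # r not bracketed by [cdf(xmin), cdf(xmax)] (possible when the
    # support was trimmed); clamp to the nearer endpoint
    r < cdf(xmin) ? xmin : xmax
else
    find_zero(x -> cdf(x) - r, (xmin, xmax))   # bracketed root find
end
```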
```diff
@@ -213,8 +228,10 @@ function sample_joint(

     θstarts = if θstart == :prior
         [map(range->batch((first(range) .+ rand(D) .* (last(range) - first(range)))...), θrange) for i=1:nchains]
-    elseif (θstart isa NamedTuple)
+    elseif θstart isa NamedTuple
         fill(θstart, nchains)
+    elseif θstart isa Vector{<:NamedTuple}
+        θstart
     else
         error("`θstart` should be either `nothing` to randomly sample the starting value or a NamedTuple giving the starting point.")
     end
@@ -227,6 +244,8 @@ function sample_joint(
         fill(batch(zero(diag(ds().Cϕ)), D), nchains)
     elseif ϕstart isa Field
         fill(ϕstart, nchains)
+    elseif ϕstart isa Vector{<:Field}
+        ϕstart
     elseif ϕstart in [:quasi_sample, :best_fit]
         pmap(θstarts) do θstart
             MAP_joint(adapt(storage,ds(;θstart...)), progress=(progress==:verbose ? :summary : false), Nϕ=adapt(storage,Nϕ), quasi_sample=(ϕstart==:quasi_sample); MAP_kwargs...).ϕ
```
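These branches let each chain be initialized (or resumed) individually. A hypothetical call, sketched under the assumption that `ds`, `ϕ₁`, and `ϕ₂` are already defined and that the vector lengths match `nchains`:

```julia
sample_joint(ds;
    nchains = 2,
    θstart  = [(Aϕ=0.9,), (Aϕ=1.1,)],  # Vector{<:NamedTuple}: one θ per chain
    ϕstart  = [ϕ₁, ϕ₂],                # Vector{<:Field}: one ϕ per chain
)
```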
```diff
@@ -268,9 +287,11 @@ function sample_joint(

     for chunks_index = (chunks_index+1):(nsamps_per_chain÷nchunk+1)

+        println("starting")
         last_chunks = pmap(last.(last_chunks)) do state

             @unpack i,ϕ°,f,θ = state
+            @show i
             f,ϕ°,ds,Nϕ = (adapt(storage, x) for x in (f,ϕ°,dsₐ,Nϕₐ))
             dsθ = ds(θ)
             ϕ = dsθ.G\ϕ°
```
