@@ -2,47 +2,52 @@ import Base: getindex, lastindex
22
33
44@doc doc"""
5- load_chains(filename; burnin=0, thin=1, join=false)
6-
7- Load a single chain or multiple parallel chains which were written to a file by
8- [`sample_joint`](@ref).
5+ load_chains(filename; burnin=0, burnin_chunks=0, thin=1, join=false, unbatch=true )
6+
7+ Load a single chain or multiple parallel chains which were written to
8+ a file by [`sample_joint`](@ref).
99
1010Keyword arguments:
1111
12- * `burnin` — Remove this many samples from the start of each chain.
12+ * `burnin` — Remove this many samples from the start of each chain, or
13+ if negative, keep only this many samples at the end of each chain.
14+ * `burnin_chunks` — Same as burnin, but in terms of chain "chunks"
15+ stored in the chain file, rather than in terms of samples.
1316* `thin` — If `thin` is an integer, thin the chain by this factor. If
14- `thin == :hasmaps`, return only samples which have maps saved. If thin is a
15- `Function`, filter the chain by this function (e.g. `thin=haskey(:g)` on Julia 1.5+)
16- * `unbatch` — If true, [unbatch](@ref) the chains if they are batched.
17+ `thin == :hasmaps`, return only samples which have maps saved. If
18+ thin is a `Function`, filter the chain by this function (e.g.
19+ `thin=haskey(:g)` on Julia 1.5+)
20+ * `unbatch` — If true, [unbatch](@ref) the chains if they are batched.
1721* `join` — If true, concatenate all the chains together.
1822* `skip_missing_chunks` — Skip missing chunks in the chain instead of
1923 terminating the chain there.
2024
2125
22- The object returned by this function is a `Chain` or `Chains` object, which
23- simply wraps an `Array` of `Dicts` or an `Array` of `Array` of `Dicts`,
24- respectively (each sample is a `Dict`). The wrapper object has some extra
25- indexing properties for convenience:
26+ The object returned by this function is a `Chain` or `Chains` object,
27+ which simply wraps an `Array` of `Dicts` or an `Array` of `Array` of
28+ `Dicts`, respectively (each sample is a `Dict`). The wrapper object
29+ has some extra indexing properties for convenience:
2630
27- * It can be indexed as if it were a single multidimensional object, e.g.
28- `chains[1,:,:accept]` would return the `:accept` key of all samples in the
29- first chain.
30- * Leading colons can be dropped, i.e. `chains[:,:,:accept]` is the same as
31- `chains[:accept]`.
32- * If some samples are missing a particular key, `missing` is returned for those
33- samples insted of an error.
34- * The recursion goes arbitrarily deep into the objects it finds. E.g., since
35- sampled parameters are stored in a `NamedTuple` like `(Aϕ=1.3,)` in the `θ`
36- key of each sample `Dict`, you can do `chain[:θ,:Aϕ]` to get all `Aϕ` samples
37- as a vector.
31+ * It can be indexed as if it were a single multidimensional object,
32+ e.g. `chains[1,:,:accept]` would return the `:accept` key of all
33+ samples in the first chain.
34+ * Leading colons can be dropped, i.e. `chains[:,:,:accept]` is the
35+ same as `chains[:accept]`.
36+ * If some samples are missing a particular key, `missing` is returned
37+ for those samples insted of an error.
38+ * The recursion goes arbitrarily deep into the objects it finds. E.g.,
39+ since sampled parameters are stored in a `NamedTuple` like
40+ `(Aϕ=1.3,)` in the `θ` key of each sample `Dict`, you can do
41+ `chain[:θ,:Aϕ]` to get all `Aϕ` samples as a vector.
3842
3943
4044"""
41- function load_chains (filename; burnin= 0 , thin= 1 , join= false , unbatch= true , dropmaps= false )
45+ function load_chains (filename; burnin= 0 , thin= 1 , join= false , unbatch= true , dropmaps= false , burnin_chunks = 0 )
4246 chains = jldopen (filename) do io
4347 ks = keys (io)
44- chunk_ks = [k for k in ks if startswith (k," chunks_" )]
45- for (isfirst,k) in flagfirst (sort (chunk_ks, by= k-> parse (Int,k[8 : end ])))
48+ chunk_ks = sort ([k for k in ks if startswith (k," chunks_" )], by= k-> parse (Int,k[8 : end ]))
49+ chunk_ks = chunk_ks[burnin_chunks>= 0 ? (burnin_chunks+ 1 : end) : (end + burnin_chunks+ 1 : end)]
50+ for (isfirst,k) in flagfirst (chunk_ks)
4651 if isfirst
4752 chains = read (io,k)
4853 else
@@ -55,13 +60,13 @@ function load_chains(filename; burnin=0, thin=1, join=false, unbatch=true, dropm
5560 chains
5661 end
5762 if thin isa Int
58- chains = [chain[( 1 + burnin): thin: end ] for chain in chains]
63+ chains = [chain[burnin >= 0 ? (( 1 + burnin): thin: end) : ( end + ( 1 + burnin) : thin : end) ] for chain in chains]
5964 elseif thin == :hasmaps
6065 chains = [[samp for samp in chain[(1 + burnin): end ] if :ϕ in keys (samp)] for chain in chains]
6166 elseif thin isa Function
6267 chains = [filter (thin,chain) for chain in chains]
6368 else
64- error (" `thin` should be an Int or :hasmaps" )
69+ error (" `thin` should be an Int, :hasmaps, or a filter function " )
6570 end
6671 chains = wrap_chains (chains)
6772 if unbatch
@@ -121,8 +126,8 @@ _getindex(x::Union{Dict,NamedTuple}, k::Symbol) = haskey(x,k) ? getindex(x, k) :
121126_getindex (x, k) = getindex (x, k)
122127
123128
124- wrap_chains (chains:: Vector{<:Vector{<:Dict} } ) = Chains (Chain .(chains))
125- wrap_chains (chain:: Vector{<:Dict} ) = Chain (chain)
129+ wrap_chains (chains:: Vector{<:Vector} ) = Chains (Chain .(chains))
130+ wrap_chains (chain:: Vector ) = Chain (chain)
126131
127132
128133# batching
0 commit comments