Skip to content

Commit 9a27327

Browse files
authored
Merge pull request #51 from marius311/modular_sampling
make sample_joint code more modular / customizable
2 parents 8f7a95f + c424cc1 commit 9a27327

File tree

6 files changed

+296
-211
lines changed

6 files changed

+296
-211
lines changed

src/CMBLensing.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ using Base: @kwdef, @propagate_inbounds, Bottom, OneTo, showarg, show_datatype,
1111
using Combinatorics
1212
using DataStructures
1313
using DelimitedFiles
14-
using Distributed: pmap, nworkers, myid, workers, addprocs, @everywhere, remotecall_wait, @spawnat, pgenerate, procs, @fetchfrom
14+
using Distributed: pmap, nworkers, myid, workers, addprocs, @everywhere, remotecall_wait,
15+
@spawnat, pgenerate, procs, @fetchfrom, default_worker_pool
1516
using FileIO
1617
using FFTW
1718
using InteractiveUtils
1819
using IterTools: flagfirst
1920
using JLD2
21+
using JLD2: jldopen, JLDWriteSession
2022
using KahanSummation
2123
using Loess
2224
using LinearAlgebra
@@ -86,7 +88,14 @@ export
8688
simulate, SymmetricFuncOp, symplectic_integrate, Taylens, toCℓ, toDℓ,
8789
ud_grade, unbatch, unmix, white_noise, Ð, Ł,
8890
ℓ², ℓ⁴, ∇, ∇², ∇ᵢ, ∇ⁱ
89-
91+
92+
# bunch of sampling-related exports
93+
export gibbs_initialize_f!, gibbs_initialize_ϕ!, gibbs_initialize_θ!,
94+
gibbs_sample_f!, gibbs_sample_ϕ!, gibbs_sample_slice_θ!,
95+
gibbs_mix!, gibbs_unmix!, gibbs_postprocess!,
96+
once_every, start_after_burnin, mass_matrix_ϕ, hmc_step
97+
98+
9099
# generic stuff
91100
include("util.jl")
92101
include("util_fft.jl")

src/dataset.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,35 @@ getproperty(ds::DS, k::Symbol) where {DS<:DataSet{<:DataSet}} =
66
hasfield(DS, k) ? getfield(ds, k) : getproperty(getfield(ds, :_super), k)
77
setproperty!(ds::DS, k::Symbol, v) where {DS<:DataSet{<:DataSet}} =
88
hasfield(DS, k) ? setfield!(ds, k, v) : setproperty!(getfield(ds, :_super), k, v)
9-
propertynames(ds::DS) where {DS′<:DataSet, DS<:DataSet{DS′}} =
10-
union(fieldnames(DS), fieldnames(DS′))
9+
propertynames(ds::DS) where {DS<:DataSet} = propertynames(DS)
10+
propertynames(::Type{DS}) where {DS′<:DataSet, DS<:DataSet{DS′}} =
11+
union(propertynames(DS′), fieldnames(DS))
12+
propertynames(::Type{DS}) where {DS<:DataSet{Nothing}} = fieldnames(DS)
1113

1214
function new_dataset(::Type{DS}; kwargs...) where {DS′<:DataSet, DS<:DataSet{DS′}}
13-
kw = filter(((k,_),)-> k in fieldnames(DS), kwargs)
14-
kw′ = filter(((k,_),)->!(k in fieldnames(DS)), kwargs)
15-
DS(_super=DS′(;kw′...); kw...)
15+
kw = Dict(k => v for (k,v) in kwargs if k in fieldnames(DS))
16+
kw′ = Dict(k => v for (k,v) in kwargs if !(k in fieldnames(DS)))
17+
DS(; kw..., _super=new_dataset(DS′; kw′...))
18+
end
19+
20+
function new_dataset(::Type{DS}; kwargs...) where {DS<:DataSet{Nothing}}
21+
DS(; kwargs...)
1622
end
1723

1824
copy(ds::DS) where {DS<:DataSet} =
1925
DS(((k==:_super ? copy(v) : v) for (k,v) in pairs(fields(ds)))...)
2026

2127
hash(ds::DataSet, h::UInt64) = foldr(hash, (typeof(ds), fieldvalues(ds)...), init=h)
2228

29+
function show(io::IO, ds::DataSet)
30+
println(io, typeof(ds), ": ")
31+
ds_dict = OrderedDict(k => getproperty(ds,k) for k in propertynames(ds) if k!=Symbol("_super"))
32+
for line in split(sprint(show, MIME"text/plain"(), ds_dict, context = (:limit => true)), "\n")[2:end]
33+
println(io, line)
34+
end
35+
end
36+
37+
2338
# needed until fix to https://github.com/FluxML/Zygote.jl/issues/685
2439
Zygote.grad_mut(ds::DataSet) = Ref{Any}((;(propertynames(ds) .=> nothing)...))
2540

src/healpix.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,42 +15,43 @@ function xy_to_θϕ((x,y))
1515
θ, ϕ
1616
end
1717

18-
function healpix_to_flat(healpix_map::Vector{T}, proj::ProjLambert{T}; rot=(0,0,0)) where {T}
18+
function healpix_to_flat(healpix_map::Vector{T}, proj::ProjLambert{T}; rots=[((0,90,0),)]) where {T}
1919

2020
Nside_sphere = hp.npix2nside(length(healpix_map))
21-
@unpack Δx = proj
21+
@unpack Δx, Ny, Nx = proj
2222

2323
# compute projection coordinate mapping
24-
xs = ys = Δx*((-Nside÷2:Nside÷2-1) .+ 0.5)
24+
ys = @. Δx * ((-Nx÷2:Nx÷2-1) + 0.5) # x/y switch here intentional
25+
xs = @. Δx * ((-Ny÷2:Ny÷2-1) + 0.5)
2526
xys = tuple.(xs,ys')[:]
2627
θϕs = xy_to_θϕ.(xys)
2728
(θs, ϕs) = first.(θϕs), last.(θϕs)
2829

29-
# rotate the pole to the equator to match Healpy's azeqview convention, in
30-
# addition to applying the user rotation
31-
R = hp.Rotator((0,90,0), eulertype="ZYX") * hp.Rotator(rot, eulertype="ZYX")
30+
# the default rots=[(0,90,0)] makes it so you get a view of the
31+
# equator, to match Helapy's azeqview convention
32+
R = prod([hp.Rotator(rot..., eulertype="ZYX") for rot in rots])
3233
(θs′, ϕs′) = eachrow(R.get_inverse()(θs, ϕs))
3334

3435
# interpolate map
35-
FlatMap(reshape(hp.get_interp_val(healpix_map, θs′, ϕs′), Nside, Nside), proj)
36+
FlatMap(reshape(hp.get_interp_val(healpix_map, θs′, ϕs′), Ny, Nx), proj)
3637

3738
end
3839

39-
function healpix_pixel_centers_to_flat(f::FlatField, Nside_sphere; rots=[(0,90,0)], healpix_pixels=0:(12*Nside_sphere^2-1))
40+
function healpix_pixel_centers_to_flat(f::FlatField, Nside_sphere; rots=[((0,90,0),)], healpix_pixels=0:(12*Nside_sphere^2-1))
4041

4142
@unpack Δx, Ny, Nx = f
4243

4344
(θs, ϕs) = hp.pix2ang(Nside_sphere, healpix_pixels)
4445

45-
# rotate the pole to the equator to match Healpy's azeqview convention, in
46-
# addition to applying the user rotation
47-
R = prod([hp.Rotator(rot, eulertype="ZYX") for rot in rots])
46+
# the default rots=[(0,90,0)] makes it so you get a view of the
47+
# equator, to match Helapy's azeqview convention
48+
R = prod([hp.Rotator(rot..., eulertype="ZYX") for rot in rots])
4849
(θs′, ϕs′) = eachrow(R(θs, ϕs))
4950

5051
# compute projection coordinate mapping
5152
# (using Ny for xs and vice-versa intentional)
5253
xys = @. θϕ_to_xy(tuple(θs′, ϕs′))
53-
xs = @. first(xys) / Δx + Ny÷2 + 0.5
54+
xs = @. first(xys) / Δx + Ny÷2 + 0.5 # x/y switch here intentional
5455
ys = @. last(xys) / Δx + Nx÷2 + 0.5
5556

5657
(xs, ys)

0 commit comments

Comments
 (0)