Skip to content

Commit c201b24

Browse files
committed
Merge branch 'master' of git+ssh://github.com/marius311/CMBLensing.jl
2 parents 26db2f6 + 1282314 commit c201b24

File tree

6 files changed

+134
-76
lines changed

6 files changed

+134
-76
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,5 @@ Setfield = "0.6.0"
7575
StaticArrays = "0.12.1"
7676
StatsBase = "0.32.0"
7777
Strided = "0.3.3"
78-
Zygote = "0.4.14"
78+
Zygote = "0.4.14 - 0.4.15"
7979
julia = "1.3"

src/CMBLensing.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,4 +128,10 @@ include("autodiff.jl")
128128
is_gpu_backed(x) = false
129129
@init @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("gpu.jl")
130130

131+
# misc init
132+
# see https://github.com/timholy/ProgressMeter.jl/issues/71 and links therein
133+
@init if ProgressMeter.@isdefined ijulia_behavior
134+
ProgressMeter.ijulia_behavior(:clear)
135+
end
136+
131137
end

src/maximization.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ Keyword arguments:
9292
* `cgtol` — Conjugrate gradient tolerance (will stop at `cgtol` or `Ncg`, whichever is first)
9393
* `αtol` — Absolute tolerance on $\alpha$ in the linesearch in the $\phi$ quasi-Newton-Rhapson step, $x^\prime = x - \alpha H^{-1} g$
9494
* `αmax` — Maximum value for $\alpha$ in the linesearch
95-
* `progress` — `false`, `:summary`, or `:verbose`, to control progress output
95+
* `progress` — whether to show progress bar
9696
9797
Returns a tuple `(f, ϕ, tr)` where `f` is the best-fit (or quasi-sample) field,
9898
`ϕ` is the lensing potential, and `tr` contains info about the run.
@@ -109,10 +109,9 @@ function MAP_joint(
109109
αmax = 0.5,
110110
cache_function = nothing,
111111
callback = nothing,
112-
interruptable = false,
113-
progress = false)
112+
interruptable::Bool = false,
113+
progress::Bool = false)
114114

115-
@assert progress in [false,:summary,:verbose]
116115
if !(isa(quasi_sample,Bool) || isa(quasi_sample,Int))
117116
throw(ArgumentError("quasi_sample should be true, false, or an Int."))
118117
end
@@ -138,7 +137,9 @@ function MAP_joint(
138137
Hϕ⁻¹ = (Nϕ == nothing) ?: pinv(pinv(Cϕ) + pinv(Nϕ))
139138

140139
try
141-
@showprogress (progress==:summary ? 1 : Inf) "MAP_joint: " for i=1:nsteps
140+
pbar = Progress(nsteps, (progress ? 0 : Inf), "MAP_joint: ")
141+
142+
for i=1:nsteps
142143

143144
# ==== f step ====
144145

@@ -160,9 +161,7 @@ function MAP_joint(
160161
lnPcur = lnP(:mix,f°,ϕ,ds)
161162

162163
# ==== show progress ====
163-
if (progress==:verbose)
164-
@printf("(step=%i, χ²=%.2f, Ncg=%i%s)\n", i, -2lnPcur, length(hist), (α==0 ? "" : @sprintf(", α=%.6f",α)))
165-
end
164+
next!(pbar, showvalues=[("step",i), ("χ²",-2lnPcur), ("Ncg",length(hist)), ("α",α)])
166165
push!(tr,@namedtuple(i,lnPcur,hist,ϕ,f,α,ϕstep))
167166
if callback != nothing
168167
callback(f, ϕ, tr)

src/plotting.jl

Lines changed: 108 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -22,98 +22,136 @@ end
2222

2323
plotsize₀ = 4
2424

25-
pretty_name(s::Symbol) = pretty_name(Val.(Symbol.(split(string(s),"")))...)
26-
pretty_name(::Val{s},::Val{:x}) where {s} = "$s Map"
27-
pretty_name(::Val{s},::Val{:l}) where {s} = "$s Fourier"
28-
pretty_name(::Val{:T},::Val{:x}) where {s} = "Map"
29-
pretty_name(::Val{:T},::Val{:l}) where {s} = "Fourier"
30-
31-
# generic plotting some components of a FlatField
32-
function _plot(f::FlatField{P}, ax, k, title, vlim; units=:deg, ticklabels=true, axeslabels=false, kwargs...) where {N,θ,P<:Flat{N,θ}}
33-
if string(k)[2] == 'x'
34-
x = θ*N/Dict(:deg=>60,:arcmin=>1)[units]/2
35-
elseif string(k)[2] == 'l'
36-
x = fieldinfo(f).nyq
37-
else
38-
throw(ArgumentError("Invalid `which`: $k"))
39-
end
40-
extent = [-x,x,-x,x]
41-
(title == nothing) && (title="$(pretty_name(k)) ($(N)x$(N) @ $(θ)')")
42-
(vlim == nothing) && (vlim=:sym)
43-
_plot(Array(f[k]); ax=ax, extent=extent, title=title, vlim=vlim, kwargs...)
44-
if ticklabels
45-
if string(k)[2] == 'x'
46-
@pydef mutable struct MyFmt <: pyimport(:matplotlib).ticker.ScalarFormatter
47-
__call__(self,v,p=nothing) = py"super"(MyFmt,self).__call__(v,p)*Dict(:deg=>"°",:arcmin=>"")[units]
48-
end
49-
ax.xaxis.set_major_formatter(MyFmt())
50-
ax.yaxis.set_major_formatter(MyFmt())
51-
if axeslabels
52-
ax.set_xlabel("RA")
53-
ax.set_ylabel("Dec")
54-
end
55-
else
56-
ax.set_xlabel(raw"$\ell_x$")
57-
ax.set_ylabel(raw"$\ell_y$")
58-
ax.tick_params(axis="x", rotation=45)
59-
end
60-
ax.tick_params(labeltop=false, labelbottom=true)
61-
else
62-
ax.tick_params(labeltop=false, labelleft=false)
63-
end
64-
end
25+
pretty_name(s) = pretty_name(Val.(Symbol.(split(string(s),"")))...)
26+
pretty_name(::Val{s}, b::Val) where {s} = "$s "*pretty_name(b)
27+
pretty_name(::Val{:x}) = "Map"
28+
pretty_name(::Val{:l}) = "Fourier"
6529

66-
# plotting a map
67-
function _plot(m::AbstractMatrix{<:Real}; ax=gca(), title=nothing, vlim=:sym, cmap="RdBu_r", vscale=:linear, cbar=true, kwargs...)
30+
function _plot(f, ax, k, title, vlim, vscale, cmap; cbar=true, units=:deg, ticklabels=true, axeslabels=false, kwargs...)
6831

69-
# some logic to automatically get upper/lower limits
70-
if vlim==:sym
71-
vmax = quantile(abs.(m[@. !isnan(m)][:]),0.999)
32+
@unpack Nside, θpix = fieldinfo(f)
33+
ismap = endswith(string(k), "x")
34+
35+
# default values
36+
if title == nothing
37+
if f isa FlatS0
38+
title = pretty_name(string(k)[2])
39+
else
40+
title = pretty_name(k)
41+
end
42+
title *= " ($(Nside)x$(Nside) @ $(θpix)')"
43+
end
44+
if vlim == nothing
45+
vlim = ismap ? :sym : :asym
46+
end
47+
if vscale == nothing
48+
vscale = ismap ? :linear : :log
49+
end
50+
if cmap == nothing
51+
if ismap
52+
cmap = get_cmap("RdBu_r")
53+
else
54+
cmap = get_cmap("viridis")
55+
cmap.set_bad("lightgray")
56+
end
57+
end
58+
59+
# build array
60+
if ismap
61+
arr = Array(f[k])
62+
else
63+
arr = abs.(ifftshift(unfold(Array(f[k]))))
64+
end
65+
if vscale == :log
66+
arr[arr .== 0] .= NaN
67+
end
68+
69+
# auto vlim's
70+
if vlim==:sym
71+
vmax = quantile(abs.(arr[@. !isnan(arr)][:]),0.999)
7272
vmin = -vmax
7373
elseif vlim==:asym
74-
vmin, vmax = (quantile(m[@. !isnan(m)][:],q) for q=(0.001,0.999))
74+
vmin, vmax = (quantile(arr[@. !isnan(arr)][:],q) for q=(0.001,0.999))
7575
elseif isa(vlim,Tuple)
7676
vmin, vmax = vlim
7777
else
7878
vmax = vlim
7979
vmin = -vmax
8080
end
81-
82-
m = Float64.(m)
83-
m[isinf.(m)] .= NaN
84-
85-
cax = ax.matshow(clamp.(m,vmin,vmax); vmin=vmin, vmax=vmax, cmap=cmap, rasterized=true, kwargs...)
86-
cbar && gcf().colorbar(cax,ax=ax)
87-
title!=nothing && ax.set_title(title, y=1)
88-
ax
89-
end
90-
91-
# plotting fourier coefficients
92-
function _plot(m::AbstractMatrix{<:Complex}; vscale=:log, kwargs...)
93-
dat = ifftshift(unfold(m))
94-
if vscale==:log
95-
dat .= log10.(abs.(dat))
81+
82+
# make the plot
83+
if ismap
84+
extent = [-1,1,-1,1] .* θpix*Nside/Dict(:deg=>60,:arcmin=>1)[units]/2
85+
else
86+
extent = [-1,1,-1,1] .* fieldinfo(f).nyq
87+
end
88+
norm = vscale == :log ? matplotlib.colors.LogNorm() : nothing
89+
cax = ax.matshow(
90+
arr;
91+
vmin=vmin, vmax=vmax, extent=extent,
92+
cmap=cmap, rasterized=true, norm=norm,
93+
kwargs...
94+
)
95+
96+
# annonate
97+
if cbar
98+
colorbar(cax,ax=ax)
99+
end
100+
ax.set_title(title, y=1)
101+
if ticklabels
102+
if ismap
103+
@pydef mutable struct MyFmt <: pyimport(:matplotlib).ticker.ScalarFormatter
104+
__call__(self,v,p=nothing) = py"super"(MyFmt,self).__call__(v,p)*Dict(:deg=>"°",:arcmin=>"")[units]
105+
end
106+
ax.xaxis.set_major_formatter(MyFmt())
107+
ax.yaxis.set_major_formatter(MyFmt())
108+
if axeslabels
109+
ax.set_xlabel("RA")
110+
ax.set_ylabel("Dec")
111+
end
112+
else
113+
ax.set_xlabel(raw"$\ell_x$")
114+
ax.set_ylabel(raw"$\ell_y$")
115+
ax.tick_params(axis="x", rotation=45)
116+
end
117+
ax.tick_params(labeltop=false, labelbottom=true)
118+
else
119+
ax.tick_params(labeltop=false, labelleft=false)
96120
end
97-
_plot(real.(dat); vlim=(nothing,nothing), cmap=nothing, kwargs...)
121+
98122
end
99123

100124

101-
102125
@doc doc"""
103126
plot(f::Field; kwargs...)
104127
plot(fs::VecOrMat{\<:Field}; kwarg...)
105128
106129
Plotting fields.
107130
"""
108131
plot(f::Field; kwargs...) = plot([f]; kwargs...)
109-
function plot(fs::AbstractVecOrMat{F}; plotsize=plotsize₀, which=default_which(fs), title=nothing, vlim=nothing, return_all=false, kwargs...) where {F<:Field}
132+
plot(D::DiagOp; kwargs...) =
133+
plot([diag(D)]; which=permutedims([x for x in propertynames(diag(D)) if string(x)[end] in "xl"]), kwargs...)
134+
135+
function plot(
136+
fs::AbstractVecOrMat{F};
137+
plotsize = plotsize₀,
138+
which = default_which(fs),
139+
title = nothing,
140+
vlim = nothing,
141+
vscale = nothing,
142+
cmap = nothing,
143+
return_all = false,
144+
kwargs...) where {F<:Field}
145+
110146
(m,n) = size(tuple.(fs, which)[:,:])
111147
fig,axs = subplots(m, n; figsize=plotsize.*[1.4*n,m], squeeze=false)
112148
axs = getindex.(Ref(axs), 1:m, (1:n)') # see https://github.com/JuliaPy/PyCall.jl/pull/487#issuecomment-456998345
113-
_plot.(fs,axs,which,title,vlim; kwargs...)
149+
_plot.(fs,axs,which,title,vlim,vscale,cmap; kwargs...)
114150
tight_layout(w_pad=-10)
115151
return_all ? (fig,axs,which) : isjuno ? fig : nothing
152+
116153
end
154+
117155
default_which(::AbstractVecOrMat{<:FlatS0}) = [:Ix]
118156
default_which(::AbstractVecOrMat{<:FlatS2}) = [:Ex :Bx]
119157
default_which(::AbstractVecOrMat{<:FlatS02}) = [:Ix :Ex :Bx]
@@ -126,6 +164,8 @@ function default_which(fs::AbstractVecOrMat{<:Field})
126164
end
127165

128166

167+
### animations of FlatFields
168+
129169
@doc doc"""
130170
animate(fields::Vector{\<:Vector{\<:Field}}; interval=50, motionblur=false, kwargs...)
131171
@@ -168,7 +208,9 @@ for plot in (:plot, :loglog, :semilogx, :semilogy)
168208

169209
@eval function ($plot)(f::Function, m::Loess.LoessModel, args...; kwargs...)
170210
l, = ($plot)(m.xs, f.(m.ys), ".", args...; kwargs...)
171-
xs′ = range(first(m.xs),last(m.xs),length=10*length(m.xs))
211+
xs′ = vcat(map(1:length(m.xs)-1) do i
212+
collect(range(m.xs[i],m.xs[i+1],length=10))[1:end-1]
213+
end..., [last(m.xs)])
172214
($plot)(xs′, f.(m.(xs′)), args...; c=l.get_color(), kwargs...)
173215
end
174216

src/sampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ function sample_joint(
193193
end
194194

195195
# seed
196-
@everywhere seed_for_storage!($storage)
196+
@everywhere seed_for_storage!((Array,$storage))
197197

198198
# initialize chains
199199
if (filename != nothing) && isfile(filename)

src/util.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,13 +353,16 @@ end
353353

354354
@doc doc"""
355355
seed_for_storage!(storage[, seed])
356+
seed_for_storage!((storage1, storage2, ...)[, seed])
356357
357358
Set the global random seed for the RNG which controls `storage`-type.
358359
"""
359360
seed_for_storage!(::Type{<:Array}, seed=nothing) =
360361
Random.seed!((seed == nothing ? () : (seed,))...)
361362
seed_for_storage!(storage::Any, seed=nothing) =
362363
error("Don't know how to set seed for storage=$storage")
364+
seed_for_storage!(storages::Tuple, seed=nothing) =
365+
seed_for_storage!.(storages, seed)
363366

364367

365368
### parallel utility function
@@ -414,6 +417,10 @@ init_GPU_workers(n=nothing) = init_GPU_workers(Val(PARALLEL_WORKER_TYPE), n)
414417

415418
function init_GPU_workers(::Val{:MPI}, n=nothing; stdout_to_master=false, stderr_to_master=false)
416419

420+
if !CuArrays.functional()
421+
return
422+
end
423+
417424
!MPI.Initialized() && MPI.Init()
418425
size = MPI.Comm_size(MPI.COMM_WORLD)
419426
rank = MPI.Comm_rank(MPI.COMM_WORLD)
@@ -433,6 +440,10 @@ init_GPU_workers(n=nothing) = init_GPU_workers(Val(PARALLEL_WORKER_TYPE), n)
433440

434441
function init_GPU_workers(::Val{:procs}, n=nothing)
435442

443+
if !CuArrays.functional()
444+
return
445+
end
446+
436447
if n == nothing
437448
n = length(devices())
438449
end

0 commit comments

Comments
 (0)