Skip to content

Commit 2d822aa

Browse files
committed
Use adaptive k-sampling in full solve command
1 parent 0ffc702 commit 2d822aa

File tree

3 files changed

+77
-3
lines changed

3 files changed

+77
-3
lines changed

src/SymBoltz.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ include("plot.jl")
4343

4444
export RMΛ, ΛCDM, w0waCDM, QCDM, GRΛCDM, BDΛCDM
4545
export CosmologyProblem, CosmologySolution
46-
export solve, solvebg, solvept, remake, issuccess, parameter_updater
46+
export solve, solvebg, solvept, solvept_adaptive, remake, issuccess, parameter_updater
4747
export parameters_Planck18
4848
export spectrum_primordial, spectrum_matter, spectrum_matter_nonlinear, spectrum_cmb, correlation_function, variance_matter, stddev_matter, los_integrate, source_grid, source_grid_adaptive, sound_horizon, SphericalBesselCache
4949
export express_derivatives

src/solve.jl

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct CosmologyProblem{Tbg <: ODEProblem, Tpt <: Union{ODEProblem, Nothing}, Tb
3131
bgspline::Tbgspline
3232
end
3333

34-
struct CosmologySolution{Tbg <: ODESolution, Tpts <: Union{Nothing, EnsembleSolution}, Tks <: Union{Nothing, AbstractVector}, Th <: Number}
34+
struct CosmologySolution{Tbg <: ODESolution, Tpts <: Union{Nothing, EnsembleSolution, AbstractArray{<:ODESolution}}, Tks <: Union{Nothing, AbstractVector}, Th <: Number}
3535
prob::CosmologyProblem # problem which is solved
3636
bg::Tbg # background solution
3737
ks::Tks # perturbation wavenumbers
@@ -272,19 +272,22 @@ end
272272
bgopts = (alg = DEFAULT_BGALG, reltol = 1e-9, abstol = 1e-9),
273273
ptopts = (alg = DEFAULT_PTALG, reltol = 1e-8, abstol = 1e-8),
274274
shootopts = (alg = DEFAULT_SHOOTALG, abstol = 1e-5),
275+
adaptive = nothing,
275276
thread = true, verbose = false, kwargs...
276277
)
277278
278279
Solve the cosmological problem `prob` up to the perturbative level with wavenumbers `ks` (or only to the background level if it is empty).
279280
The options `bgopts` and `ptopts` are passed to the background and perturbations ODE `solve()` calls,
280281
and `shootopts` to the shooting method nonlinear `solve()`.
282+
If `adaptive`, the wavenumbers `ks` is treated as an initial grid of wavenumbers which is adaptively refined based on `isapprox(...; kwargs...)`.
281283
If `threads`, integration over independent perturbation modes are parallellized.
282284
"""
283285
function solve(
284286
prob::CosmologyProblem, ks::Union{Nothing, AbstractArray} = nothing;
285287
bgopts = (alg = DEFAULT_BGALG, reltol = 1e-9, abstol = 1e-9),
286288
ptopts = (alg = DEFAULT_PTALG, reltol = 1e-8, abstol = 1e-8),
287289
shootopts = (alg = DEFAULT_SHOOTALG, abstol = 1e-5),
290+
adaptive = nothing,
288291
thread = true, verbose = false, kwargs...
289292
)
290293
if !isempty(prob.shoot)
@@ -299,7 +302,11 @@ function solve(
299302
ptsol = nothing
300303
else
301304
ks = k_dimensionless.(ks, h)
302-
ptsol = solvept(prob.pt, bgsol, ks, prob.bgspline; thread, verbose, ptopts..., kwargs...)
305+
if !isnothing(adaptive)
306+
ks, ptsol = solvept_adaptive(prob.pt, bgsol, ks, prob.bgspline; adaptive...)
307+
else
308+
ptsol = solvept(prob.pt, bgsol, ks, prob.bgspline; thread, verbose, ptopts..., kwargs...)
309+
end
303310
end
304311

305312
return CosmologySolution(prob, bgsol, ks, ptsol, h)
@@ -449,6 +456,66 @@ function solvept(ptprob::ODEProblem; alg = DEFAULT_PTALG, reltol = 1e-8, abstol
449456
return solve(ptprob, alg; reltol, abstol, kwargs...)
450457
end
451458

459+
function solvept_adaptive(ptprob::ODEProblem, bgsol::ODESolution, ks::AbstractArray, bgsplinepar, Ss = nothing; alg = DEFAULT_PTALG, reltol = 1e-8, abstol = 1e-8, sort = true, verbose = false, kwargs...)
460+
!issorted(ks) && throw(error("ks = $ks are not sorted in ascending order"))
461+
462+
ptprob0, ptprobgen = setuppt(ptprob, bgsol, bgsplinepar)
463+
solveptk(k) = solvept(ptprobgen(ptprob0, k))
464+
465+
Sfunc(ptsol) = ptsol(range(ptsol.t[begin], ptsol.t[end]; length = 200); idxs = Ss).u # TODO: customize N points
466+
467+
ks = copy(ks) # don't modify input array
468+
469+
savelock = ReentrantLock()
470+
function refine(i1, i2)
471+
k1 = ks[i1]
472+
k2 = ks[i2]
473+
ptsol1 = ptsols[i1]
474+
ptsol2 = ptsols[i2]
475+
476+
k2 k1 || error("shit")
477+
478+
k = (k1 + k2) / 2
479+
ptsol = solveptk(k)
480+
S = Sfunc(ptsol)
481+
482+
i = 0 # define in outer scope, then set inside lock block
483+
@lock savelock begin
484+
push!(ks, k)
485+
push!(ptsols, ptsol)
486+
push!(Scache, S)
487+
i = length(ks) # copy, since first refine below can modify i before call to second refine
488+
verbose && println("Refined k-grid between [$k1, $k2] on thread $(threadid()) to $i total points")
489+
end
490+
491+
S1 = Scache[i1]
492+
S2 = Scache[i2]
493+
Sint = (S1 .+ S2) ./ 2 # linear interpolation
494+
495+
# check if interpolation is close enough for all sources
496+
# (equivalent to finding the source grid of each source separately)
497+
@sync if !isapprox(S, Sint; kwargs...) # TODO: several S
498+
@spawn refine(i1, i) # refine left subinterval
499+
@spawn refine(i, i2) # refine right subinterval
500+
end
501+
end
502+
503+
ptsols = tmap(solveptk, ks)
504+
Scache = tmap(Sfunc, ptsols)
505+
@threads for i in 1:length(ks)-1
506+
refine(i, i+1)
507+
end
508+
509+
# sort according to k
510+
if sort
511+
is = sortperm(ks)
512+
ks = ks[is]
513+
ptsols = ptsols[is]
514+
end
515+
516+
return ks, ptsols
517+
end
518+
452519
function time_today(prob::CosmologyProblem)
453520
getτ0 = SymBoltz.getsym(prob.bg, :τ0)
454521
bgprob = prob.bg

test/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,3 +452,10 @@ using SpecialFunctions: zeta as ζ
452452
end
453453
end
454454
end
455+
456+
@testset "Adaptive k-solving" begin
457+
bgsol = solvebg(prob.bg)
458+
ks = [1.0, 1000.0]
459+
Ss =
460+
ks, ptsols = solvept_adaptive(prob.pt, bgsol, ks, prob.bgspline, Ss; rtol = 1e-1);
461+
end

0 commit comments

Comments
 (0)