Skip to content

Commit 749b69e

Browse files
committed
allow θ transform in CMBLensingMuseProblem
1 parent 5feed16 commit 749b69e

File tree

1 file changed

+26
-30
lines changed

1 file changed

+26
-30
lines changed

src/muse.jl

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,48 @@
11

22
# interface with MuseInference.jl
33

4-
using .MuseInference: AbstractMuseProblem, MuseResult
4+
using .MuseInference: AbstractMuseProblem, MuseResult, Transformedθ, UnTransformedθ
55
using .MuseInference.AbstractDifferentiation
6-
import .MuseInference: logLike, ∇θ_logLike, sample_x_z, ẑ_at_θ, muse!, standardizeθ
76

87
export CMBLensingMuseProblem
98

10-
struct CMBLensingMuseProblem{DS<:DataSet,DS_SIM<:DataSet} <: AbstractMuseProblem
9+
@kwdef struct CMBLensingMuseProblem{DS<:DataSet,DS_SIM<:DataSet} <: AbstractMuseProblem
1110
ds :: DS
12-
ds_for_sims :: DS_SIM
13-
parameterization
14-
MAP_joint_kwargs
15-
θ_fixed
16-
x
17-
latent_vars
18-
autodiff
11+
ds_for_sims :: DS_SIM = ds
12+
parameterization = 0
13+
MAP_joint_kwargs = (;)
14+
θ_fixed = (;)
15+
x = ds.d
16+
latent_vars = nothing
17+
autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(tag=false), AD.ZygoteBackend()))
18+
transform_θ = identity
19+
inv_transform_θ = identity
1920
end
21+
CMBLensingMuseProblem(ds, ds_for_sims=ds; kwargs...) = CMBLensingMuseProblem(;ds, ds_for_sims, kwargs...)
2022

21-
function CMBLensingMuseProblem(
22-
ds,
23-
ds_for_sims = ds;
24-
parameterization = 0,
25-
MAP_joint_kwargs = (;),
26-
θ_fixed = (;),
27-
latent_vars = nothing,
28-
autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(tag=false), AD.ZygoteBackend())),
29-
)
30-
parameterization == 0 || error("only parameterization=0 (unlensed parameterization) currently implemented")
31-
CMBLensingMuseProblem(ds, ds_for_sims, parameterization, MAP_joint_kwargs, θ_fixed, ds.d, latent_vars, autodiff)
32-
end
3323

3424
mergeθ(prob::CMBLensingMuseProblem, θ) = isempty(prob.θ_fixed) ? θ : (;prob.θ_fixed..., θ...)
3525

36-
function standardizeθ(prob::CMBLensingMuseProblem, θ)
26+
function MuseInference.standardizeθ(prob::CMBLensingMuseProblem, θ)
3727
θ isa Union{NamedTuple,ComponentVector} || error("θ should be a NamedTuple or ComponentVector")
3828
1f0 * ComponentVector(θ) # ensure component vector and float
3929
end
4030

41-
function MuseInference.logLike(prob::CMBLensingMuseProblem, d, z, θ)
31+
MuseInference.transform_θ(prob::CMBLensingMuseProblem, θ) = prob.transform_θ(θ)
32+
MuseInference.inv_transform_θ(prob::CMBLensingMuseProblem, θ) = prob.inv_transform_θ(θ)
33+
34+
function MuseInference.logLike(prob::CMBLensingMuseProblem, d, z, θ, ::UnTransformedθ)
4235
logpdf(prob.ds; z..., θ = mergeθ(prob, θ), d)
4336
end
37+
function MuseInference.logLike(prob::CMBLensingMuseProblem, d, z, θ, ::Transformedθ)
38+
MuseInference.logLike(prob, d, z, MuseInference.inv_transform_θ(prob, θ), UnTransformedθ())
39+
end
4440

45-
function ∇θ_logLike(prob::CMBLensingMuseProblem, d, z, θ)
46-
AD.gradient(prob.autodiff, θ -> logLike(prob, d, z, θ), θ)[1]
41+
function MuseInference.∇θ_logLike(prob::CMBLensingMuseProblem, d, z, θ, θ_space)
42+
AD.gradient(prob.autodiff, θ -> MuseInference.logLike(prob, d, z, θ, θ_space), θ)[1]
4743
end
4844

49-
function sample_x_z(prob::CMBLensingMuseProblem, rng::AbstractRNG, θ)
45+
function MuseInference.sample_x_z(prob::CMBLensingMuseProblem, rng::AbstractRNG, θ)
5046
sim = simulate(rng, prob.ds_for_sims, θ = mergeθ(prob, θ))
5147
if prob.latent_vars == nothing
5248
# this is a guess which might not work for everything necessarily
@@ -58,18 +54,18 @@ function sample_x_z(prob::CMBLensingMuseProblem, rng::AbstractRNG, θ)
5854
(;x, z)
5955
end
6056

61-
function ẑ_at_θ(prob::CMBLensingMuseProblem, d, zguess, θ; ∇z_logLike_atol=nothing)
57+
function MuseInference.ẑ_at_θ(prob::CMBLensingMuseProblem, d, zguess, θ; ∇z_logLike_atol=nothing)
6258
@unpack ds = prob
6359
Ωstart = delete(NamedTuple(zguess), :f)
6460
MAP = MAP_joint(mergeθ(prob, θ), @set(ds.d=d), Ωstart; fstart=zguess.f, prob.MAP_joint_kwargs...)
6561
LenseBasis(FieldTuple(;delete(MAP, :history)...)), MAP.history
6662
end
6763

68-
function ẑ_at_θ(prob::CMBLensingMuseProblem{<:NoLensingDataSet}, d, (f₀,), θ; ∇z_logLike_atol=nothing)
64+
function MuseInference.ẑ_at_θ(prob::CMBLensingMuseProblem{<:NoLensingDataSet}, d, (f₀,), θ; ∇z_logLike_atol=nothing)
6965
@unpack ds = prob
7066
LenseBasis(FieldTuple(f=argmaxf_logpdf(I, mergeθ(prob, θ), @set(ds.d=d); fstart=f₀, prob.MAP_joint_kwargs...))), nothing
7167
end
7268

73-
function muse!(result::MuseResult, ds::DataSet, θ₀=nothing; parameterization=0, MAP_joint_kwargs=(;), kwargs...)
69+
function MuseInference.muse!(result::MuseResult, ds::DataSet, θ₀=nothing; parameterization=0, MAP_joint_kwargs=(;), kwargs...)
7470
muse!(result, CMBLensingMuseProblem(ds; parameterization, MAP_joint_kwargs), θ₀; kwargs...)
7571
end

0 commit comments

Comments
 (0)