11
22# interface with MuseInference.jl
33
4- using . MuseInference: AbstractMuseProblem, MuseResult
4+ using . MuseInference: AbstractMuseProblem, MuseResult, Transformedθ, UnTransformedθ
55using . MuseInference. AbstractDifferentiation
6- import . MuseInference: logLike, ∇θ_logLike, sample_x_z, ẑ_at_θ, muse!, standardizeθ
76
87export CMBLensingMuseProblem
98
10- struct CMBLensingMuseProblem{DS<: DataSet ,DS_SIM<: DataSet } <: AbstractMuseProblem
9+ @kwdef struct CMBLensingMuseProblem{DS<: DataSet ,DS_SIM<: DataSet } <: AbstractMuseProblem
1110 ds :: DS
12- ds_for_sims :: DS_SIM
13- parameterization
14- MAP_joint_kwargs
15- θ_fixed
16- x
17- latent_vars
18- autodiff
11+ ds_for_sims :: DS_SIM = ds
12+ parameterization = 0
13+ MAP_joint_kwargs = (;)
14+ θ_fixed = (;)
15+ x = ds. d
16+ latent_vars = nothing
17+ autodiff = AD. HigherOrderBackend ((AD. ForwardDiffBackend (tag= false ), AD. ZygoteBackend ()))
18+ transform_θ = identity
19+ inv_transform_θ = identity
1920end
21+ CMBLensingMuseProblem (ds, ds_for_sims= ds; kwargs... ) = CMBLensingMuseProblem (;ds, ds_for_sims, kwargs... )
2022
21- function CMBLensingMuseProblem (
22- ds,
23- ds_for_sims = ds;
24- parameterization = 0 ,
25- MAP_joint_kwargs = (;),
26- θ_fixed = (;),
27- latent_vars = nothing ,
28- autodiff = AD. HigherOrderBackend ((AD. ForwardDiffBackend (tag= false ), AD. ZygoteBackend ())),
29- )
30- parameterization == 0 || error (" only parameterization=0 (unlensed parameterization) currently implemented" )
31- CMBLensingMuseProblem (ds, ds_for_sims, parameterization, MAP_joint_kwargs, θ_fixed, ds. d, latent_vars, autodiff)
32- end
3323
3424mergeθ (prob:: CMBLensingMuseProblem , θ) = isempty (prob. θ_fixed) ? θ : (;prob. θ_fixed... , θ... )
3525
36- function standardizeθ (prob:: CMBLensingMuseProblem , θ)
26+ function MuseInference . standardizeθ (prob:: CMBLensingMuseProblem , θ)
3727 θ isa Union{NamedTuple,ComponentVector} || error (" θ should be a NamedTuple or ComponentVector" )
3828 1f0 * ComponentVector (θ) # ensure component vector and float
3929end
4030
41- function MuseInference. logLike (prob:: CMBLensingMuseProblem , d, z, θ)
31+ MuseInference. transform_θ (prob:: CMBLensingMuseProblem , θ) = prob. transform_θ (θ)
32+ MuseInference. inv_transform_θ (prob:: CMBLensingMuseProblem , θ) = prob. inv_transform_θ (θ)
33+
34+ function MuseInference. logLike (prob:: CMBLensingMuseProblem , d, z, θ, :: UnTransformed θ)
4235 logpdf (prob. ds; z... , θ = mergeθ (prob, θ), d)
4336end
37+ function MuseInference. logLike (prob:: CMBLensingMuseProblem , d, z, θ, :: Transformed θ)
38+ MuseInference. logLike (prob, d, z, MuseInference. inv_transform_θ (prob, θ), UnTransformedθ ())
39+ end
4440
45- function ∇θ_logLike (prob:: CMBLensingMuseProblem , d, z, θ)
46- AD. gradient (prob. autodiff, θ -> logLike (prob, d, z, θ), θ)[1 ]
41+ function MuseInference . ∇θ_logLike (prob:: CMBLensingMuseProblem , d, z, θ, θ_space)
42+ AD. gradient (prob. autodiff, θ -> MuseInference . logLike (prob, d, z, θ, θ_space ), θ)[1 ]
4743end
4844
49- function sample_x_z (prob:: CMBLensingMuseProblem , rng:: AbstractRNG , θ)
45+ function MuseInference . sample_x_z (prob:: CMBLensingMuseProblem , rng:: AbstractRNG , θ)
5046 sim = simulate (rng, prob. ds_for_sims, θ = mergeθ (prob, θ))
5147 if prob. latent_vars == nothing
5248 # this is a guess which might not work for everything necessarily
@@ -58,18 +54,18 @@ function sample_x_z(prob::CMBLensingMuseProblem, rng::AbstractRNG, θ)
5854 (;x, z)
5955end
6056
61- function ẑ_at_θ (prob:: CMBLensingMuseProblem , d, zguess, θ; ∇z_logLike_atol= nothing )
57+ function MuseInference . ẑ_at_θ (prob:: CMBLensingMuseProblem , d, zguess, θ; ∇z_logLike_atol= nothing )
6258 @unpack ds = prob
6359 Ωstart = delete (NamedTuple (zguess), :f )
6460 MAP = MAP_joint (mergeθ (prob, θ), @set (ds. d= d), Ωstart; fstart= zguess. f, prob. MAP_joint_kwargs... )
6561 LenseBasis (FieldTuple (;delete (MAP, :history )... )), MAP. history
6662end
6763
68- function ẑ_at_θ (prob:: CMBLensingMuseProblem{<:NoLensingDataSet} , d, (f₀,), θ; ∇z_logLike_atol= nothing )
64+ function MuseInference . ẑ_at_θ (prob:: CMBLensingMuseProblem{<:NoLensingDataSet} , d, (f₀,), θ; ∇z_logLike_atol= nothing )
6965 @unpack ds = prob
7066 LenseBasis (FieldTuple (f= argmaxf_logpdf (I, mergeθ (prob, θ), @set (ds. d= d); fstart= f₀, prob. MAP_joint_kwargs... ))), nothing
7167end
7268
73- function muse! (result:: MuseResult , ds:: DataSet , θ₀= nothing ; parameterization= 0 , MAP_joint_kwargs= (;), kwargs... )
69+ function MuseInference . muse! (result:: MuseResult , ds:: DataSet , θ₀= nothing ; parameterization= 0 , MAP_joint_kwargs= (;), kwargs... )
7470 muse! (result, CMBLensingMuseProblem (ds; parameterization, MAP_joint_kwargs), θ₀; kwargs... )
7571end
0 commit comments