Skip to content

Commit d79184d

Browse files
committed
muse implicit diff working
1 parent 89ce8f5 commit d79184d

File tree

5 files changed

+73
-34
lines changed

5 files changed

+73
-34
lines changed

src/base_fields.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ lastindex(f::BaseField, i::Int) = lastindex(f.arr, i)
3636
@propagate_inbounds getindex(f::BaseField, I::Union{Int,Colon,AbstractArray}...) = getindex(f.arr, I...)
3737
@propagate_inbounds setindex!(f::BaseField, X, I::Union{Int,Colon,AbstractArray}...) = (setindex!(f.arr, X, I...); f)
3838
similar(f::BaseField{B}, ::Type{T}) where {B,T} = BaseField{B}(similar(f.arr, T), f.metadata)
39+
similar(f::BaseField{B}, ::Type{T}, dims::Base.DimOrInd...) where {B,T} = similar(f.arr, T, dims...)
3940
copy(f::BaseField{B}) where {B} = BaseField{B}(copy(f.arr), f.metadata)
41+
copyto!(dst::AbstractArray, src::BaseField) = copyto!(dst, src.arr)
4042
(==)(f₁::BaseField, f₂::BaseField) = strict_compatible_metadata(f₁,f₂) && (f₁.arr == f₂.arr)
4143

4244

@@ -46,7 +48,9 @@ function promote(f₁::BaseField{B₁}, f₂::BaseField{B₂}) where {B₁,B₂}
4648
B = typeof(promote_basis_generic(B₁(), B₂()))
4749
B(f₁), B(f₂)
4850
end
49-
51+
# allow very basic arithmetic with BaseField & AbstractArray
52+
promote(f::BaseField{B}, x::AbstractArray) where {B} = (f, BaseField{B}(reshape(x, size(f.arr)), f.proj))
53+
promote(x::AbstractArray, f::BaseField{B}) where {B} = reverse(promote(f, x))
5054

5155
## broadcasting
5256

@@ -61,6 +65,7 @@ BroadcastStyle(::Type{F}) where {B,M,T,A,F<:BaseField{B,M,T,A}} =
6165
BroadcastStyle(::BaseFieldStyle{S₁,B₁}, ::BaseFieldStyle{S₂,B₂}) where {S₁,B₁,S₂,B₂} =
6266
BaseFieldStyle{typeof(result_style(S₁(), S₂())), typeof(promote_basis_strict(B₁(),B₂()))}()
6367
BroadcastStyle(S::BaseFieldStyle, ::DefaultArrayStyle{0}) = S
68+
BaseFieldStyle{S,B}(::Val{2}) where {S,B} = DefaultArrayStyle{2}()
6469

6570
# with the Broadcasted object created, we now compute the answer
6671
function materialize(bc::Broadcasted{BaseFieldStyle{S,B}}) where {S,B}
@@ -101,10 +106,13 @@ function materialize!(dst::BaseField{B}, bc::Broadcasted{BaseFieldStyle{S,B′}}
101106

102107
end
103108

104-
# the default preprocessing, which just unwraps the underlying array.
105-
# this doesn't dispatch on the first argument, but custom BaseFields
106-
# are free to override this and dispatch on it if they need
107-
preprocess(::Any, f::BaseField) = f.arr
109+
# if broadcasting into a BaseField, the first method here is hit with
110+
# dest::Tuple{BaseFieldStyle,M}, in which case just unwrap the array,
111+
# since it will be fed into a downstream regular broadcast
112+
preprocess(::Tuple{BaseFieldStyle{S,B},M}, f::BaseField) where {S,B,M} = f.arr
113+
# if broadcasting into an Array (ie dropping the BaseField wrapper) we
114+
# need to return the vector representation
115+
preprocess(::AbstractArray, f::BaseField) = view(f.arr, :)
108116

109117
# we re-wrap each Broadcasted object as we go through preprocessing
110118
# because some array types do special things here (e.g. CUDA wraps
@@ -135,8 +143,7 @@ function strict_compatible_metadata(f₁::BaseField, f₂::BaseField)
135143
end
136144

137145
## mapping
138-
139-
# this comes up in Zygote.broadcast_forward, and the generic falls back to a regular Array
146+
# map over entries in the array like a true AbstractArray
140147
map(func, f::BaseField{B}) where {B} = BaseField{B}(map(func, f.arr), f.metadata)
141148

142149

@@ -169,4 +176,4 @@ getproperty(f::BaseField{B}, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B)))...}) where
169176
BaseField{B₀}(_reshape_batch(view(getfield(f,:arr), pol_slice(f, pol_index(B(), k))...)), getfield(f,:metadata))
170177
getproperty(f::BaseS02{Basis3Prod{𝐈,B₂,B₀}}, ::Val{:P}) where {B₂,B₀} =
171178
BaseField{Basis2Prod{B₂,B₀}}(view(getfield(f,:arr), pol_slice(f, 2:3)...), getfield(f,:metadata))
172-
getproperty(f::BaseS2, ::Val{:P}) = f
179+
getproperty(f::BaseS2, ::Val{:P}) = f

src/field_tuples.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ typealias_def(::Type{<:FieldTuple{FS,T}}) where {FS<:Tuple,T} =
2828
### array interface
2929
size(f::FieldTuple) = (mapreduce(length, +, f.fs, init=0),)
3030
copy(f::FieldTuple) = FieldTuple(map(copy,f.fs))
31+
copyto!(dst::AbstractArray, src::FieldTuple) = copyto!(dst, src[:]) # todo: memory optimization possible
3132
iterate(ft::FieldTuple, args...) = iterate(ft.fs, args...)
3233
getindex(f::FieldTuple, i::Union{Int,UnitRange}) = getindex(f.fs, i)
3334
fill!(ft::FieldTuple, x) = (map(f->fill!(f,x), ft.fs); ft)
3435
get_storage(f::FieldTuple) = only(unique(map(get_storage, f.fs)))
3536
adapt_structure(to, f::FieldTuple) = FieldTuple(map(f->adapt(to,f),f.fs))
3637
similar(ft::FieldTuple) = FieldTuple(map(similar,ft.fs))
3738
similar(ft::FieldTuple, ::Type{T}) where {T<:Number} = FieldTuple(map(f->similar(f,T),ft.fs))
39+
similar(ft::FieldTuple, ::Type{T}, dims::Base.DimOrInd...) where {B,T} = similar(ft.fs[1].arr, T, dims...) # todo: make work for heterogenous arrays?
3840
similar(ft::FieldTuple, Nbatch::Int) = FieldTuple(map(f->similar(f,Nbatch),ft.fs))
3941
sum(f::FieldTuple; dims=:) = dims == (:) ? sum(sum, f.fs) : error("sum(::FieldTuple, dims=$dims not supported")
4042

@@ -54,6 +56,7 @@ function BroadcastStyle(::FieldTupleStyle{S₁,Names}, ::FieldTupleStyle{S₂,Na
5456
FieldTupleStyle{Tuple{map_tupleargs((s₁,s₂)->typeof(result_style(s₁(),s₂())), S₁, S₂)...}, Names}()
5557
end
5658
BroadcastStyle(S::FieldTupleStyle, ::DefaultArrayStyle{0}) = S
59+
FieldTupleStyle{S,Names}(::Val{2}) where {S,Names} = DefaultArrayStyle{2}()
5760

5861

5962
@generated function materialize(bc::Broadcasted{FieldTupleStyle{S,Names}}) where {S,Names}
@@ -73,13 +76,29 @@ end
7376
struct FieldTupleComponent{i} end
7477

7578
preprocess(::Tuple{<:Any,FieldTupleComponent{i}}, ft::FieldTuple) where {i} = ft.fs[i]
79+
preprocess(::AbstractArray, ft::FieldTuple) = vcat((view(f.arr, :) for f in ft.fs)...)
7680

7781

82+
### mapping
83+
# map over entries in the component fields like a true AbstractArray
84+
map(func, ft::FieldTuple) = FieldTuple(map(f -> map(func, f), ft.fs))
85+
7886
### promotion
7987
function promote(ft1::FieldTuple, ft2::FieldTuple)
8088
fts = map(promote, ft1.fs, ft2.fs)
8189
FieldTuple(map(first,fts)), FieldTuple(map(last,fts))
8290
end
91+
# allow very basic arithmetic with FieldTuple & AbstractArray
92+
function promote(ft::FieldTuple, x::AbstractVector)
93+
lens = map(length, ft.fs)
94+
offsets = typeof(lens)((cumsum([1; lens...])[1:end-1]...,))
95+
x_ft = FieldTuple(map(ft.fs, offsets, lens) do f, offset, len
96+
promote(f, view(x, offset:offset+len-1))[2]
97+
end)
98+
(ft, x_ft)
99+
end
100+
promote(x::AbstractVector, ft::FieldTuple) = reverse(promote(ft, x))
101+
83102

84103
### conversion
85104
Basis(ft::FieldTuple) = ft
@@ -120,4 +139,4 @@ tr(L::Diagonal{<:Union{Real,Complex}, <:FieldTuple}) = reduce(+, map(tr∘Diagon
120139
batch_length(ft::FieldTuple) = only(unique(map(batch_length, ft.fs)))
121140
batch_index(ft::FieldTuple, I) = FieldTuple(map(f -> batch_index(f, I), ft.fs))
122141
getindex(ft::FieldTuple, k::Symbol) = ft.fs[k]
123-
haskey(ft::FieldTuple, k::Symbol) = haskey(ft.fs, k)
142+
haskey(ft::FieldTuple, k::Symbol) = haskey(ft.fs, k)

src/generic.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,16 @@ show_vector(io::IO, f::Field) = !isempty(f) && show_vector(io, f[:])
330330
Base.has_offset_axes(::Field) = false # needed for Diagonal(::Field) if the Field is implicitly-sized
331331

332332

333-
# addition/subtraction works between any fields and scalars, promotion is done
334-
# automatically if fields are in different bases
335-
for op in (:+,:-), (T1,T2,promote) in ((:Field,:Scalar,false),(:Scalar,:Field,false),(:Field,:Field,true))
333+
# addition/subtraction works between fields, scalars, and
334+
# abstractarrays. promotion is done automatically for fields in
335+
# different bases are wrapped assuming they're the same field type
336+
for op in (:+,:-), (T1,T2,promote) in [
337+
(:Field, :Scalar, false),
338+
(:Scalar, :Field, false),
339+
(:Field, :Field, true),
340+
(:Field, :AbstractArray, true),
341+
(:AbstractArray, :Field, true)
342+
]
336343
@eval ($op)(a::$T1, b::$T2) = broadcast($op, ($promote ? promote(a,b) : (a,b))...)
337344
end
338345

src/muse.jl

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
# interface with MuseInference.jl
33

44
using .MuseInference: AbstractMuseProblem, MuseResult
5-
import .MuseInference: ∇θ_logLike, sample_x_z, ẑ_at_θ, muse!, standardizeθ
5+
using .MuseInference.AbstractDifferentiation
6+
import .MuseInference: logLike, ∇θ_logLike, sample_x_z, ẑ_at_θ, muse!, standardizeθ
67

78
export CMBLensingMuseProblem
89

@@ -14,10 +15,20 @@ struct CMBLensingMuseProblem{DS<:DataSet,DS_SIM<:DataSet} <: AbstractMuseProblem
1415
θ_fixed
1516
x
1617
latent_vars
18+
autodiff
1719
end
1820

19-
function CMBLensingMuseProblem(ds, ds_for_sims=ds; parameterization=0, MAP_joint_kwargs=(;), θ_fixed=(;), latent_vars=nothing)
20-
CMBLensingMuseProblem(ds, ds_for_sims, parameterization, MAP_joint_kwargs, θ_fixed, ds.d, latent_vars)
21+
function CMBLensingMuseProblem(
22+
ds,
23+
ds_for_sims = ds;
24+
parameterization = 0,
25+
MAP_joint_kwargs = (;),
26+
θ_fixed = (;),
27+
latent_vars = nothing,
28+
autodiff = AD.HigherOrderBackend((AD.ForwardDiffBackend(tag=false), AD.ZygoteBackend())),
29+
)
30+
parameterization == 0 || error("only parameterization=0 (unlensed parameterization) currently implemented")
31+
CMBLensingMuseProblem(ds, ds_for_sims, parameterization, MAP_joint_kwargs, θ_fixed, ds.d, latent_vars, autodiff)
2132
end
2233

2334
mergeθ(prob::CMBLensingMuseProblem, θ) = isempty(prob.θ_fixed) ? θ : (;prob.θ_fixed..., θ...)
@@ -27,26 +38,21 @@ function standardizeθ(prob::CMBLensingMuseProblem, θ)
2738
1f0 * ComponentVector(θ) # ensure component vector and float
2839
end
2940

41+
function MuseInference.logLike(prob::CMBLensingMuseProblem, d, z, θ)
42+
logpdf(prob.ds; z..., θ = mergeθ(prob, θ), d)
43+
end
44+
3045
function ∇θ_logLike(prob::CMBLensingMuseProblem, d, z, θ)
31-
@unpack ds, parameterization = prob
32-
@set! ds.d = d
33-
if parameterization == 0
34-
gradient-> logpdf(ds; z..., θ = mergeθ(prob, θ)), θ)[1]
35-
elseif parameterization == :mix
36-
= mix(ds; z..., θ = mergeθ(prob, θ))
37-
gradient-> logpdf(Mixed(ds); z°..., θ = mergeθ(prob, θ)), θ)[1]
38-
else
39-
error("parameterization should be 0 or :mix")
40-
end
46+
AD.gradient(prob.autodiff, θ -> logLike(prob, d, z, θ), θ)[1]
4147
end
4248

4349
function sample_x_z(prob::CMBLensingMuseProblem, rng::AbstractRNG, θ)
4450
sim = simulate(rng, prob.ds_for_sims, θ = mergeθ(prob, θ))
4551
if prob.latent_vars == nothing
4652
# this is a guess which might not work for everything necessarily
47-
z = FieldTuple(delete(sim, (:f̃, :d, )))
53+
z = LenseBasis(FieldTuple(delete(sim, (:f̃, :d, ))) )
4854
else
49-
z = FieldTuple(select(sim, prob.latent_vars))
55+
z = LenseBasis(FieldTuple(select(sim, prob.latent_vars)))
5056
end
5157
x = sim.d
5258
(;x, z)
@@ -56,12 +62,12 @@ function ẑ_at_θ(prob::CMBLensingMuseProblem, d, zguess, θ; ∇z_logLike_atol
5662
@unpack ds = prob
5763
Ωstart = delete(NamedTuple(zguess), :f)
5864
MAP = MAP_joint(mergeθ(prob, θ), @set(ds.d=d), Ωstart; fstart=zguess.f, prob.MAP_joint_kwargs...)
59-
FieldTuple(;delete(MAP, :history)...), MAP.history
65+
LenseBasis(FieldTuple(;delete(MAP, :history)...)), MAP.history
6066
end
6167

6268
function ẑ_at_θ(prob::CMBLensingMuseProblem{<:NoLensingDataSet}, d, (f₀,), θ; ∇z_logLike_atol=nothing)
6369
@unpack ds = prob
64-
FieldTuple(f=argmaxf_logpdf(I, mergeθ(prob, θ), @set(ds.d=d); fstart=f₀, prob.MAP_joint_kwargs...)), nothing
70+
LenseBasis(FieldTuple(f=argmaxf_logpdf(I, mergeθ(prob, θ), @set(ds.d=d); fstart=f₀, prob.MAP_joint_kwargs...))), nothing
6571
end
6672

6773
function muse!(result::MuseResult, ds::DataSet, θ₀=nothing; parameterization=0, MAP_joint_kwargs=(;), kwargs...)

src/proj_lambert.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,17 @@ promote_metadata_generic(metadata₁::ProjLambert, metadata₂::ProjLambert) =
131131
# return `Broadcasted` objects which are spliced into the final
132132
# broadcast, thus avoiding allocating any temporary arrays.
133133

134-
function preprocess((_,proj)::Tuple{<:Any,<:ProjLambert{T,V}}, r::Real) where {T,V}
134+
function preprocess((_,proj)::Tuple{<:BaseFieldStyle,<:ProjLambert{T,V}}, r::Real) where {T,V}
135135
r isa BatchedReal ? adapt(V, reshape(r.vals, 1, 1, 1, :)) : r
136136
end
137137
# need custom adjoint here bc Δ can come back batched from the
138138
# backward pass even though r was not batched on the forward pass
139-
@adjoint function preprocess(m::Tuple{<:Any,<:ProjLambert{T,V}}, r::Real) where {T,V}
139+
@adjoint function preprocess(m::Tuple{<:BaseFieldStyle,<:ProjLambert{T,V}}, r::Real) where {T,V}
140140
preprocess(m, r), Δ -> (nothing, Δ isa AbstractArray ? batch(real.(Δ[:])) : Δ)
141141
end
142142

143143

144-
function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ∇d::∇diag) where {S,B}
144+
function preprocess((_,proj)::Tuple{<:BaseFieldStyle{S,B},<:ProjLambert}, ∇d::∇diag) where {S,B}
145145

146146
(B <: Union{Fourier,QUFourier,IQUFourier}) ||
147147
error("Can't broadcast ∇[$(∇d.coord)] as a $(typealias(B)), its not diagonal in this basis.")
@@ -156,15 +156,15 @@ function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ∇d::
156156
end
157157
end
158158

159-
function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ::²diag) where {S,B}
159+
function preprocess((_,proj)::Tuple{<:BaseFieldStyle{S,B},<:ProjLambert}, ::²diag) where {S,B}
160160

161161
(B <: Union{Fourier,<:Basis2Prod{<:Any,Fourier},<:Basis3Prod{<:Any,<:Any,Fourier}}) ||
162162
error("Can't broadcast a BandPass as a $(typealias(B)), its not diagonal in this basis.")
163163

164164
broadcasted(+, broadcasted(^, proj.ℓx', 2), broadcasted(^, proj.ℓy, 2))
165165
end
166166

167-
function preprocess((_,proj)::Tuple{<:Any,<:ProjLambert}, bp::BandPass)
167+
function preprocess((_,proj)::Tuple{<:BaseFieldStyle,<:ProjLambert}, bp::BandPass)
168168
Cℓ_to_2D(bp.Wℓ, proj)
169169
end
170170

0 commit comments

Comments
 (0)