Skip to content

Commit 42e9a63

Browse files
committed
working GPU posterior hessians
1 parent 01a9225 commit 42e9a63

File tree

2 files changed

+59
-69
lines changed

2 files changed

+59
-69
lines changed

src/gpu.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,29 @@ end
9090
# prevents unnecessary CuArray views in some cases
9191
Base.view(arr::CuArray{T,2}, I, J, K, ::typeof(..)) where {T} = view(arr, I, J, K)
9292
Base.view(arr::CuArray{T,3}, I, J, K, ::typeof(..)) where {T} = view(arr, I, J, K)
93+
94+
95+
## ForwardDiff through FFTs
96+
# these definitions needed bc the CUDA.jl definitions supersede the
97+
# AbstractArray ones in autodiff.jl
98+
99+
for P in [AbstractFFTs.Plan, AbstractFFTs.ScaledPlan]
100+
for op in [:(Base.:*), :(Base.:\)]
101+
@eval function ($op)(plan::$P, arr::CuArray{<:Union{Dual{T},Complex{<:Dual{T}}}}) where {T}
102+
arr_of_duals(T, apply_plan($op, plan, arr)...)
103+
end
104+
end
105+
end
106+
107+
AbstractFFTs.plan_fft(arr::CuArray{<:Complex{<:Dual}}, region) = plan_fft(complex.(value.(real.(arr)), value.(imag.(arr))), region)
108+
AbstractFFTs.plan_rfft(arr::CuArray{<:Dual}, region; kws...) = plan_rfft(value.(arr), region; kws...)
109+
110+
# until something like https://github.com/JuliaDiff/ForwardDiff.jl/pull/619
111+
function ForwardDiff.extract_gradient!(::Type{T}, result::CuArray, dual::Dual) where {T}
112+
result[:] .= partials(T, dual)
113+
return result
114+
end
115+
function ForwardDiff.extract_gradient_chunk!(::Type{T}, result::CuArray, dual, index, chunksize) where {T}
116+
result[index:index+chunksize-1] .= partials.(T, dual, 1:chunksize)
117+
return result
118+
end

src/proj_lambert.jl

Lines changed: 33 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ end
355355

356356

357357
### creating covariance operators
358-
# fixed covariances
358+
359+
## fixed covariances
359360
Cℓ_to_Cov(pol::Symbol, args...; kwargs...) = Cℓ_to_Cov(Val(pol), args...; kwargs...)
360361
function Cℓ_to_Cov(::Val{:I}, proj::ProjLambert, Cℓ::Cℓs; units=proj.Ωpix)
361362
Diagonal(LambertFourier(Cℓ_to_2D(Cℓ,proj), proj) / units)
@@ -367,85 +368,48 @@ function Cℓ_to_Cov(::Val{:IP}, proj::ProjLambert, CℓTT, CℓEE, CℓBB, Cℓ
367368
ΣTT, ΣEE, ΣBB, ΣTE = [Cℓ_to_Cov(:I,proj,Cℓ; kwargs...) for Cℓ in (CℓTT,CℓEE,CℓBB,CℓTE)]
368369
BlockDiagIEB(@SMatrix([ΣTT ΣTE; ΣTE ΣEE]), ΣBB)
369370
end
370-
# ParamDependentOp covariances scaled by amplitudes in different ℓ-bins
371-
function Cℓ_to_Cov(::Val{:I}, proj::ProjLambert{T}, (Cℓ, ℓedges, θname)::Tuple; kwargs...) where {T}
372-
# we need an @eval here since we want to dynamically select a
373-
# keyword argument name, θname. the @eval happens into Main rather
374-
# than CMBLensing as a workaround for
375-
# https://discourse.julialang.org/t/closure-not-shipping-to-remote-workers-except-from-main/38831
376-
C₀ = diag(Cℓ_to_Cov(:I, proj, Cℓ; kwargs...))
377-
@eval Main let ℓedges=$((T.(ℓedges))...,), C₀=$C₀
378-
$ParamDependentOp(function (;$θname=ones($T,length(ℓedges)-1),_...)
379-
As = $preprocess.(Ref((nothing,C₀.metadata)), $T.($ensure1d($θname)))
380-
CℓI = $Zygote.ignore() do
381-
copy(C₀.Il) .* one.(first(As))# gets batching right
382-
end
383-
$Diagonal($LambertFourier($bandpower_rescale!(ℓedges, C₀.ℓmag, CℓI, As...), C₀.metadata))
384-
end)
385-
end
371+
372+
## ParamDependentOp covariances scaled by amplitudes in different ℓ-bins
373+
# note we need an @eval below since we want to dynamically select a
374+
# keyword argument name, θname. the @eval happens into Main rather
375+
# than CMBLensing as a workaround for
376+
# https://discourse.julialang.org/t/closure-not-shipping-to-remote-workers-except-from-main/38831
377+
function Cℓ_to_Cov(::Val{:I}, proj::ProjLambert{T,V}, (Cℓ, ℓedges, θname)::Tuple; kwargs...) where {T,V}
378+
C₀ = Cℓ_to_Cov(:I, proj, Cℓ; kwargs...)
379+
ℓbin_indices = findbin.(Ref(adapt(proj.storage, ℓedges)), proj.ℓmag)
380+
Cov(θ) = Diagonal(LambertFourier(bandpower_rescale(C₀.diag.arr, ℓbin_indices, θ), proj))
381+
ParamDependentOp(@eval Main let Cov=$Cov
382+
(;$θname=$(ones(T,length(ℓedges)-1)), _...) -> Cov($θname)
383+
end)
386384
end
385+
387386
function Cℓ_to_Cov(::Val{:P}, proj::ProjLambert{T}, (CℓEE, ℓedges, θname)::Tuple, CℓBB::Cℓs; kwargs...) where {T}
388-
C₀ = diag(Cℓ_to_Cov(:P, proj, CℓEE, CℓBB; kwargs...))
389-
@eval Main let ℓedges=$((T.(ℓedges))...,), C₀=$C₀
390-
ParamDependentOp(function (;$θname=ones($T,length(ℓedges)-1),_...)
391-
AEs = $preprocess.(Ref((nothing,C₀.metadata)), $T.($ensure1d($θname)))
392-
CℓE, CℓB = $Zygote.ignore() do
393-
copy(C₀.El) .* one.(first(AEs)), copy(C₀.Bl) .* one.(first(AEs)) # gets batching right
394-
end
395-
Diagonal(LambertEBFourier($bandpower_rescale!(ℓedges, C₀.ℓmag, CℓE, AEs...), CℓB, C₀.metadata))
396-
end)
397-
end
387+
C₀ = Cℓ_to_Cov(:P, proj, CℓEE, CℓBB; kwargs...)
388+
ℓbin_indices = findbin.(Ref(adapt(proj.storage, ℓedges)), proj.ℓmag)
389+
Cov(θ) = Diagonal(LambertEBFourier(bandpower_rescale(C₀.diag.El, ℓbin_indices, θ), one(eltype(θ)) .* C₀.diag.Bl, proj))
390+
ParamDependentOp(@eval Main let Cov=$Cov
391+
(;$θname=$(ones(T,length(ℓedges)-1)), _...) -> Cov($θname)
392+
end)
398393
end
399-
# this is written weird because the stuff inside the broadcast! needs
400-
# to work as a GPU kernel
401-
function bandpower_rescale!(ℓedges, ℓ, Cℓ, A...)
402-
length(A)==length(ℓedges)-1 || error("Expected $(length(ℓedges)-1) bandpower parameters, got $(length(A)).")
403-
eltype(A[1]) <: Real || error("Bandpower parameters must be real numbers.")
404-
if length(A)>30
405-
# if more than 30 bandpowers, we need to chunk the rescaling
406-
# because of a maximum argument limit of CUDA kernels
407-
for p in partition(1:length(A), 30)
408-
bandpower_rescale!(ℓedges[p.start:(p.stop+1)], ℓ, Cℓ, A[p]...)
409-
end
410-
else
411-
broadcast!(Cℓ, ℓ, Cℓ, A...) do ℓ, Cℓ, A...
412-
for i=1:length(ℓedges)-1
413-
(ℓedges[i] << ℓedges[i+1]) && return A[i] * Cℓ
414-
end
415-
return Cℓ
416-
end
417-
end
418-
Cℓ
419-
end
420-
# cant reliably get Zygote's gradients to work through these
421-
# broadcasts, which on GPU use ForwardDiff, so write the adjoint by
422-
# hand for now. likely more performant, in any case.
423-
@adjoint function bandpower_rescale!(ℓedges, ℓ, Cℓ, A...)
424-
back = let Cℓ = copy(Cℓ) # need copy bc Cℓ mutated on forward pass
425-
function (Δ)
426-
= map(1:length(A)) do i
427-
sum(
428-
real,
429-
broadcast(Δ, ℓ, Cℓ) do Δ, ℓ, Cℓ
430-
(ℓedges[i] << ℓedges[i+1]) ? Δ * Cℓ : zero(Δ)
431-
end,
432-
dims = ndims(Δ)==4 ? (1,2) : (:)
433-
)
434-
end
435-
(nothing, nothing, nothing, Ā...)
436-
end
437-
end
438-
bandpower_rescale!(ℓedges, ℓ, Cℓ, A...), back
394+
395+
# helper function for scaling the covariances in ℓ-bins
396+
function findbin(ℓedges, ℓ; out_of_range=length(ℓedges))
397+
(ℓ<ℓedges[1] ||>=ℓedges[end]) ? out_of_range : findfirst(>(ℓ), ℓedges)::Int - 1
398+
end
399+
function bandpower_rescale(arr::A, ℓbin_indices, amplitudes) where {T<:Real, A<:AbstractArray{T}}
400+
amplitudes_arr = adapt(basetype(A), [T.(amplitudes); 1])
401+
return amplitudes_arr[ℓbin_indices] .* arr
439402
end
403+
404+
405+
### Covariance back to Cℓs
440406
function cov_to_Cℓ(C::DiagOp{<:LambertS0}; kwargs...)
441407
@unpack Nx, Ny, Δx = diag(C)
442408
α = Nx*Ny/Δx^2
443409
get_Cℓ(sqrt.(diag(C)); kwargs...)*sqrt(α)
444410
end
445411

446412

447-
448-
449413
### spin adjoints
450414
function *(a::SpinAdjoint{F}, b::F) where {B<:Union{Map,Basis2Prod{<:Any,Map},Basis3Prod{<:Any,<:Any,Map}},F<:LambertField{B}}
451415
r = sum(a.f.arr .* b.arr, dims=3)

0 commit comments

Comments
 (0)