diff --git a/src/CMBLensing.jl b/src/CMBLensing.jl index 40bd43f6..0b76ae1a 100644 --- a/src/CMBLensing.jl +++ b/src/CMBLensing.jl @@ -106,6 +106,7 @@ include("flat_batch.jl") include("masking.jl") include("taylens.jl") include("bilinearlens.jl") +include("biquadlens.jl") # plotting isjuno = false diff --git a/src/bilinearlens.jl b/src/bilinearlens.jl index 35fdad68..b512799b 100644 --- a/src/bilinearlens.jl +++ b/src/bilinearlens.jl @@ -81,7 +81,7 @@ function BilinearLens(ϕ::FlatS0) function compute_sparse_repr(is_gpu_backed::Val{false}) K = Vector{Int32}(getK(Nside)) M = similar(K) - V = similar(K,Float32) + V = similar(K,T) for I in 1:length(ĩs) compute_row!(I, ĩs[I], j̃s[I], M, V) end @@ -92,7 +92,7 @@ function BilinearLens(ϕ::FlatS0) function compute_sparse_repr(is_gpu_backed::Val{true}) K = CuVector{Cint}(getK(Nside)) M = similar(K) - V = similar(K,Float32) + V = similar(K,T) cuda(ĩs, j̃s, M, V; threads=256) do ĩs, j̃s, M, V index = threadIdx().x stride = blockDim().x @@ -104,7 +104,7 @@ function BilinearLens(ϕ::FlatS0) if !Base.isdefined(CUSPARSE,:CuSparseMatrixCOO) error("To use BilinearLens on GPU, run `using Pkg; pkg\"add https://github.com/marius311/CUDA.jl#coo\"` and restart Julia.") end - switch2csr(CUSPARSE.CuSparseMatrixCOO{Float32}(K,M,V,(Nside^2,Nside^2))) + switch2csr(CUSPARSE.CuSparseMatrixCOO{T}(K,M,V,(Nside^2,Nside^2))) end @@ -129,7 +129,7 @@ getϕ(Lϕ::BilinearLens) = Lϕ.ϕ # applying various forms of the operator function *(Lϕ::BilinearLens, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}} - Lϕ.sparse_repr==I && return f + Lϕ.sparse_repr===I && return f Łf = Ł(f) f̃ = similar(Łf) ds = (D == 1 ? ((),) : tuple.(1:D)) @@ -140,7 +140,7 @@ function *(Lϕ::BilinearLens, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}} end function *(Lϕ::Adjoint{<:Any,<:BilinearLens}, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}} - parent(Lϕ).sparse_repr==I && return f + parent(Lϕ).sparse_repr===I && return f Łf = Ł(f) f̃ = similar(Łf) ds = (D == 1 ? 
((),) : tuple.(1:D)) diff --git a/src/biquadlens.jl b/src/biquadlens.jl new file mode 100644 index 00000000..494945a2 --- /dev/null +++ b/src/biquadlens.jl @@ -0,0 +1,215 @@ + +export BiquadLens + +@doc doc""" + + BiquadLens(ϕ) + +BiquadLens is a lensing operator that computes lensing with biquadratic +interpolation. The action of the operator, as well as its adjoint, inverse, +inverse-adjoint, and gradient w.r.t. ϕ can all be computed. The log-determinant +of the operation is non-zero and can't be computed. + +Internally, BiquadLens forms a sparse matrix with the interpolation weights, +which can be applied and adjoint-ed extremely fast (e.g. at least an order of +magnitude faster than LenseFlow). Inverse and inverse-adjoint lensing is +somewhat slower as it is implemented with several steps of the [preconditioned +generalized minimal residual](https://en.wikipedia.org/wiki/Generalized_minimal_residual_method) +algorithm, taking anti-lensing as the preconditioner. + +!!! warning + + Due to [this bug](https://github.com/JuliaLang/PackageCompiler.jl/issues/379) + in PackageCompiler, currently you have to run `using SparseArrays` by hand + in your Julia session before `BiquadLens` is available. 
+ +""" +mutable struct BiquadLens{Φ,S} <: ImplicitOp{Basis,Spin,Pix} + ϕ :: Φ + sparse_repr :: S + anti_lensing_sparse_repr :: Union{S, Nothing} +end + +function BiquadLens(ϕ::FlatS0) + + # if ϕ == 0 then just return identity operator + if norm(ϕ) == 0 + return BiquadLens(ϕ,I,I) + end + + @unpack Nside,Δx,T = fieldinfo(ϕ) + + # the (i,j)-th pixel is deflected to (ĩs[i],j̃s[j]) + j̃s,ĩs = getindex.((∇*ϕ)./Δx, :Ix) + ĩs .= ĩs .+ (1:Nside) + j̃s .= (j̃s' .+ (1:Nside))' + + # sub2ind converts a 2D index to 1D index, including wrapping at edges + indexwrap(i) = mod(i - 1, Nside) + 1 + sub2ind(i,j) = Base._sub2ind((Nside,Nside),indexwrap(i),indexwrap(j)) + + # compute the 9 non-zero entries in L[I,:] (ie the Ith row of the sparse + # lensing representation, L) and add these to the sparse constructor + # matrices, M, and V, accordingly. this function is split off so it can be + # called directly or used as a CUDA kernel + function compute_row!(I, ĩ, j̃, M, V) + + # (i,j) indices of the 9 nearest neighbors + x₋,x₀,x₊ = floor(Int,ĩ) .+ (-1, 0, 1) + y₋,y₀,y₊ = floor(Int,j̃) .+ (-1, 0, 1) + + # 1-D indices of the 9 nearest neighbors + M[9I-8:9I] .= @SVector[ + sub2ind(x₋, y₋), + sub2ind(x₀, y₋), + sub2ind(x₊, y₋), + sub2ind(x₋, y₀), + sub2ind(x₀, y₀), + sub2ind(x₊, y₀), + sub2ind(x₋, y₊), + sub2ind(x₀, y₊), + sub2ind(x₊, y₊), + ] + + # weights of these neighbors in the biquadratic interpolation + Δx₋, Δx₀, Δx₊ = ((x₋,x₀,x₊) .- ĩ) + Δy₋, Δy₀, Δy₊ = ((y₋,y₀,y₊) .- j̃) + A = @SMatrix[ + 1 Δx₋ Δx₋^2 Δy₋ Δy₋^2 Δx₋*Δy₋ Δx₋*Δy₋^2 Δx₋^2*Δy₋; + 1 Δx₀ Δx₀^2 Δy₋ Δy₋^2 Δx₀*Δy₋ Δx₀*Δy₋^2 Δx₀^2*Δy₋; + 1 Δx₊ Δx₊^2 Δy₋ Δy₋^2 Δx₊*Δy₋ Δx₊*Δy₋^2 Δx₊^2*Δy₋; + 1 Δx₋ Δx₋^2 Δy₀ Δy₀^2 Δx₋*Δy₀ Δx₋*Δy₀^2 Δx₋^2*Δy₀; + 1 Δx₀ Δx₀^2 Δy₀ Δy₀^2 Δx₀*Δy₀ Δx₀*Δy₀^2 Δx₀^2*Δy₀; + 1 Δx₊ Δx₊^2 Δy₀ Δy₀^2 Δx₊*Δy₀ Δx₊*Δy₀^2 Δx₊^2*Δy₀; + 1 Δx₋ Δx₋^2 Δy₊ Δy₊^2 Δx₋*Δy₊ Δx₋*Δy₊^2 Δx₋^2*Δy₊; + 1 Δx₀ Δx₀^2 Δy₊ Δy₊^2 Δx₀*Δy₊ Δx₀*Δy₊^2 Δx₀^2*Δy₊; + 1 Δx₊ Δx₊^2 Δy₊ Δy₊^2 Δx₊*Δy₊ Δx₊*Δy₊^2 Δx₊^2*Δy₊; + ] + V[9I-8:9I] .= 
(inv(A'*A)*A')[1,:] + + end + + # a surprisingly large fraction of the computation for large Nside, so memoize it: + @memoize getK(Nside) = Int32.((9:9Nside^2+8) .÷ 9) + + # CPU + function compute_sparse_repr(is_gpu_backed::Val{false}) + K = Vector{Int32}(getK(Nside)) + M = similar(K) + V = similar(K,T) + for I in 1:length(ĩs) + compute_row!(I, ĩs[I], j̃s[I], M, V) + end + sparse(K,M,V,Nside^2,Nside^2) + end + + # GPU + function compute_sparse_repr(is_gpu_backed::Val{true}) + K = CuVector{Cint}(getK(Nside)) + M = similar(K) + V = similar(K,T) + cuda(ĩs, j̃s, M, V; threads=256) do ĩs, j̃s, M, V + index = threadIdx().x + stride = blockDim().x + for I in index:stride:length(ĩs) + compute_row!(I, ĩs[I], j̃s[I], M, V) + end + end + # remove once CuSparseMatrixCOO makes it into official CUDA.jl: + if !Base.isdefined(CUSPARSE,:CuSparseMatrixCOO) + error("To use BiquadLens on GPU, run `using Pkg; pkg\"add https://github.com/marius311/CUDA.jl#coo\"` and restart Julia.") + end + switch2csr(CUSPARSE.CuSparseMatrixCOO{T}(K,M,V,(Nside^2,Nside^2))) + end + + + BiquadLens(ϕ, compute_sparse_repr(Val(is_gpu_backed(ϕ))), nothing) + +end + + +# lazily computing the sparse representation for anti-lensing + +function get_anti_lensing_sparse_repr!(Lϕ::BiquadLens) + if Lϕ.anti_lensing_sparse_repr == nothing + Lϕ.anti_lensing_sparse_repr = BiquadLens(-Lϕ.ϕ).sparse_repr + end + Lϕ.anti_lensing_sparse_repr +end + + +getϕ(Lϕ::BiquadLens) = Lϕ.ϕ +(Lϕ::BiquadLens)(ϕ) = BiquadLens(ϕ) + +# applying various forms of the operator + +function *(Lϕ::BiquadLens, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}} + Lϕ.sparse_repr===I && return f + Łf = Ł(f) + f̃ = similar(Łf) + ds = (D == 1 ? 
((),) : tuple.(1:D)) + for d in ds + mul!(@views(f̃.Ix[:,:,d...][:]), Lϕ.sparse_repr, @views(Łf.Ix[:,:,d...][:])) + end + f̃ +end + +function *(Lϕ::Adjoint{<:Any,<:BiquadLens}, f::FlatS0{P}) where {N,D,P<:Flat{N,<:Any,<:Any,D}} + parent(Lϕ).sparse_repr===I && return f + Łf = Ł(f) + f̃ = similar(Łf) + ds = (D == 1 ? ((),) : tuple.(1:D)) + for d in ds + mul!(@views(f̃.Ix[:,:,d...][:]), parent(Lϕ).sparse_repr', @views(Łf.Ix[:,:,d...][:])) + end + f̃ +end + +function \(Lϕ::BiquadLens, f::FlatS0{P}) where {N,P<:Flat{N}} + FlatMap{P}(reshape(gmres( + Lϕ.sparse_repr, view(f[:Ix],:), + Pl = get_anti_lensing_sparse_repr!(Lϕ), maxiter = 5 + ), N, N)) +end + +function \(Lϕ::Adjoint{<:Any,<:BiquadLens}, f::FlatS0{P}) where {N,P<:Flat{N}} + FlatMap{P}(reshape(gmres( + parent(Lϕ).sparse_repr', view(f[:Ix],:), + Pl = get_anti_lensing_sparse_repr!(parent(Lϕ))', maxiter = 5 + ), N, N)) +end + +# special cases for BiquadLens(0ϕ), which don't work with gmres, +# see https://github.com/JuliaMath/IterativeSolvers.jl/issues/271 +function \(Lϕ::BiquadLens{<:Any,<:UniformScaling}, f::FlatS0{P}) where {N,P<:Flat{N}} + Lϕ.sparse_repr \ f +end +function \(Lϕ::Adjoint{<:Any,<:BiquadLens{<:Any,<:UniformScaling}}, f::FlatS0{P}) where {N,P<:Flat{N}} + parent(Lϕ).sparse_repr \ f +end + +for op in (:*, :\) + @eval function ($op)(Lϕ::Union{BiquadLens, Adjoint{<:Any,<:BiquadLens}}, f::FieldTuple) + Łf = Ł(f) + F = typeof(Łf) + F(map(f->($op)(Lϕ,f), Łf.fs)) + end +end + + +# gradients + +@adjoint BiquadLens(ϕ) = BiquadLens(ϕ), Δ -> (Δ,) + +@adjoint function *(Lϕ::BiquadLens, f::Field{B}) where {B} + f̃ = Lϕ * f + function back(Δ) + (∇' * (Ref(tuple_adjoint(Ł(Δ))) .* Ł(∇*f̃))), B(Lϕ*Δ) + end + f̃, back +end + + +# gpu + +adapt_structure(storage, Lϕ::BiquadLens) = BiquadLens(adapt(storage, fieldvalues(Lϕ))...) 
diff --git a/src/dataset.jl b/src/dataset.jl index 2feddc1a..0a4da0a0 100644 --- a/src/dataset.jl +++ b/src/dataset.jl @@ -29,15 +29,19 @@ Zygote.grad_mut(ds::DataSet) = Ref{Any}((;(propertynames(ds) .=> nothing)...)) Cf # unlensed field covariance Cf̃ = nothing # lensed field covariance (not always needed) Cn # noise covariance - Cn̂ = Cn # approximate noise covariance, diagonal in same basis as Cf M = 1 # user mask - M̂ = M # approximate user mask, diagonal in same basis as Cf B = 1 # beam and instrumental transfer functions - B̂ = B # approximate beam and instrumental transfer functions, diagonal in same basis as Cf D = 1 # mixing matrix for mixed parametrization G = 1 # reparametrization for ϕ - P = 1 # pixelization operator (if estimating field on higher res than data) L = LenseFlow # lensing operator, possibly cached for memory reuse + P = 1 # pixelization operator (if estimating field on higher res than data) + + # preconditioners (denoted by \hat). these should all approximate the + # non-\hat version, but be fast to calculate: + Cn̂ = Cn + M̂ = M + B̂ = B + L̂ = 1 end function subblock(ds::DS, block) where {DS<:DataSet} diff --git a/src/maximization.jl b/src/maximization.jl index 8e031b90..07efe4c4 100644 --- a/src/maximization.jl +++ b/src/maximization.jl @@ -218,7 +218,7 @@ function MAP_marg( ϕstart = nothing, Nϕ = :qe, nsteps = 10, - nsteps_with_meanfield_update = 4, + nsteps_with_meanfield_update = nsteps, conjgrad_kwargs = (nsteps=500, tol=1e-1), α = 0.2, weights = :unlensed,