diff --git a/Project.toml b/Project.toml index cfa2b44..2a82138 100644 --- a/Project.toml +++ b/Project.toml @@ -4,14 +4,17 @@ authors = ["Miles Lucas and contributors"] version = "0.3.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CoordinateTransformations = "150eb455-5306-5404-9cee-2592286d6298" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] +ChainRulesCore = "1" CoordinateTransformations = "0.6" Distances = "0.10" KeywordCalls = "0.2" diff --git a/src/PSFModels.jl b/src/PSFModels.jl index a9f5bbf..c0cda0a 100644 --- a/src/PSFModels.jl +++ b/src/PSFModels.jl @@ -111,9 +111,12 @@ plot(model, axes(other)) # use axes from other array """ module PSFModels +using ChainRulesCore +import ChainRulesCore: frule, rrule using CoordinateTransformations using Distances using KeywordCalls +using LinearAlgebra using SpecialFunctions using StaticArrays diff --git a/src/airy.jl b/src/airy.jl index cb371f2..451dc73 100644 --- a/src/airy.jl +++ b/src/airy.jl @@ -43,6 +43,9 @@ end Base.size(a::AiryDisk) = map(length, a.indices) Base.axes(a::AiryDisk) = a.indices +# short printing +Base.show(io::IO, a::AiryDisk{T}) where {T} = print(io, "AiryDisk{$T}(pos=$(a.pos), fwhm=$(a.fwhm), amp=$(a.amp))") + const rz = 3.8317059702075125 / π function (a::AiryDisk{T})(point::AbstractVector) where T diff --git a/src/gaussian.jl b/src/gaussian.jl index 2835e32..c1522f0 100644 --- a/src/gaussian.jl +++ b/src/gaussian.jl @@ -45,6 +45,9 @@ end Base.size(g::Gaussian) = map(length, g.indices) Base.axes(g::Gaussian) = g.indices +# short printing +Base.show(io::IO, g::Gaussian{T}) where {T} = print(io, "Gaussian{$T}(pos=$(g.pos), fwhm=$(g.fwhm), amp=$(g.amp))") + # Gaussian pre-factor for normalizing the exponential const GAUSS_PRE = -4 * log(2) @@ -61,3 +64,48 @@ function (g::Gaussian{T,<:Union{Tuple,AbstractVector}})(point::AbstractVector) w val = g.amp * exp(GAUSS_PRE * Δ) return convert(T, val) end + +## gradients + +# isotropic +function fgrad(g::Gaussian, point::AbstractVector) + f = g(point) + + xdiff = first(point) - first(g.pos) + ydiff = last(point) - last(g.pos) + dfdpos = -2 * GAUSS_PRE * f / g.fwhm^2 .* SA[xdiff, ydiff] + dfdfwhm = -2 * GAUSS_PRE * f * (xdiff^2 + ydiff^2) / g.fwhm^3 + dfdamp = f / g.amp + return f, dfdpos, dfdfwhm, dfdamp +end + +# diagonal +function fgrad(g::Gaussian{T,<:Union{Tuple,AbstractVector}}, point::AbstractVector) where T + f = g(point) + + xdiff = first(point) - first(g.pos) + ydiff = last(point) - last(g.pos) + dfdpos = -2 * GAUSS_PRE * f .* SA[xdiff / first(g.fwhm)^2, ydiff / last(g.fwhm)^2] + dfdfwhm = -2 * GAUSS_PRE * f .* SA[xdiff^2 / first(g.fwhm)^3, ydiff^2 / last(g.fwhm)^3] + dfda = f / g.amp + return f, dfdpos, dfdfwhm, dfda +end + +function frule((Δpsf, Δp), g::Gaussian, point::AbstractVector) + f, dfdpos, dfdfwhm, dfda = fgrad(g, point) + Δf = dot(dfdpos, Δpsf.pos) + dot(dfdfwhm, Δpsf.fwhm) + dfda * Δpsf.amp + Δf -= dot(dfdpos, Δp) + return f, Δf +end + +function rrule(g::G, point::AbstractVector) where {G<:Gaussian} + f, dfdpos, dfdfwhm, dfda = fgrad(g, point) + function Gaussian_pullback(Δf) + ∂pos = dfdpos .* Δf + ∂fwhm = dfdfwhm .* Δf + ∂g = Tangent{G}(pos=∂pos, fwhm=∂fwhm, amp=dfda * Δf, indices=NoTangent()) + ∂pos = dfdpos .* -Δf + return ∂g, ∂pos + end + return f, Gaussian_pullback +end diff --git a/src/moffat.jl b/src/moffat.jl index e4a0563..c780e4a 100644 --- a/src/moffat.jl +++ b/src/moffat.jl @@ -39,6 +39,10 @@ end Base.size(m::Moffat) = map(length, m.indices) Base.axes(m::Moffat) = m.indices +# short printing +Base.show(io::IO, m::Moffat{T}) where {T} = print(io, "Moffat{$T}(pos=$(m.pos), fwhm=$(m.fwhm), amp=$(m.amp), alpha=$(m.alpha))") + + # scalar case function (m::Moffat{T})(point::AbstractVector) where T hwhm = m.fwhm / 2 diff --git a/test/Project.toml b/test/Project.toml index bbdcf2a..6eca7ce 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,16 @@ [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +ChainRulesCore = "1" +ChainRulesTestUtils = "1" +FiniteDifferences = "0.12" RecipesBase = "1" +StableRNGs = "1" StaticArrays = "0.12, 1" diff --git a/test/runtests.jl b/test/runtests.jl index 3d860f9..a02f5c3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,16 @@ +using ChainRulesCore +using ChainRulesTestUtils +using FiniteDifferences using PSFModels using PSFModels: Gaussian, Normal, AiryDisk, Moffat +using StableRNGs using StaticArrays using Test +ChainRulesCore.debug_mode() = true + +rng = StableRNG(122) + function test_model_interface(K) # test defaults m = @inferred K(fwhm=10) @@ -68,58 +76,95 @@ function test_model_interface(K) @test m(m.pos) ≈ BigFloat(1) end -@testset "Model Interface - $K" for K in (Gaussian, AiryDisk, Moffat) - test_model_interface(K) -end - @testset "Gaussian" begin - m = Gaussian(fwhm=10) - expected = exp(-4 * log(2) * sum(abs2, SA[1, 2]) / 100) - @test m[2, 1] ≈ m(1, 2) ≈ expected - - m = Gaussian(fwhm=(10, 9)) - wdist = (1/10)^2 + (2/9)^2 - expected = exp(-4 * log(2) * wdist) - @test m[2, 1] ≈ m(1, 2) ≈ expected + test_model_interface(Gaussian) + + @testset "isotropic" begin + m = Gaussian(fwhm=10) + expected = exp(-4 * log(2) * sum(abs2, SA[1, 2]) / 100) + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=10, amp=1.0)" + end + + @testset "diagonal" begin + m = Gaussian(fwhm=(10, 9)) + wdist = (1/10)^2 + (2/9)^2 + expected = exp(-4 * log(2) * wdist) + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Gaussian{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)" + end # test Normal alias @test Normal(fwhm=10) === Gaussian(fwhm=10) + + @testset "gradients" begin + FiniteDifferences.to_vec(x::Integer) = Bool[], _ -> x + # have to make sure PSFs are all floating point so tangents don't have type issues + psf_iso = Gaussian(fwhm=10.0, pos=zeros(2)) + psf_tang = Tangent{Gaussian}(fwhm=rand(rng), pos=rand(rng, 2), amp=rand(rng), indices=NoTangent()) + point = Float64[1, 2] + test_frule(psf_iso ⊢ psf_tang, point) + test_rrule(psf_iso ⊢ psf_tang, point) + + psf_diag = Gaussian(fwhm=Float64[10, 8], pos=zeros(2)) + psf_tang = Tangent{Gaussian}(fwhm=rand(rng, 2), pos=rand(rng, 2), amp=rand(rng), indices=NoTangent()) + test_frule(psf_diag ⊢ psf_tang, point) + test_rrule(psf_diag ⊢ psf_tang, point) + end end @testset "AiryDisk" begin - m = AiryDisk(fwhm=10) - radius = m.fwhm * 1.18677 - # first radius is 0 - @test m(radius, 0) ≈ 0 atol=eps(Float64) - @test m(-radius, 0) ≈ 0 atol=eps(Float64) - @test m(0, radius) ≈ 0 atol=eps(Float64) - @test m(0, -radius) ≈ 0 atol=eps(Float64) - - m = AiryDisk(fwhm=(10, 9)) - r1 = m.fwhm[1] * 1.18677 - r2 = m.fwhm[2] * 1.18677 - # first radius is 0 - @test m(r1, 0) ≈ 0 atol=eps(Float64) - @test m(-r1, 0) ≈ 0 atol=eps(Float64) - @test m(0, r2) ≈ 0 atol=eps(Float64) - @test m(0, -r2) ≈ 0 atol=eps(Float64) + test_model_interface(AiryDisk) + + @testset "isotropic" begin + m = AiryDisk(fwhm=10) + radius = m.fwhm * 1.18677 + # first radius is 0 + @test m(radius, 0) ≈ 0 atol=eps(Float64) + @test m(-radius, 0) ≈ 0 atol=eps(Float64) + @test m(0, radius) ≈ 0 atol=eps(Float64) + @test m(0, -radius) ≈ 0 atol=eps(Float64) + @test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=10, amp=1.0)" + end + + @testset "diagonal" begin + m = AiryDisk(fwhm=(10, 9)) + r1 = m.fwhm[1] * 1.18677 + r2 = m.fwhm[2] * 1.18677 + # first radius is 0 + @test m(r1, 0) ≈ 0 atol=eps(Float64) + @test m(-r1, 0) ≈ 0 atol=eps(Float64) + @test m(0, r2) ≈ 0 atol=eps(Float64) + @test m(0, -r2) ≈ 0 atol=eps(Float64) + @test repr(m) == "AiryDisk{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0)" + end end @testset "Moffat" begin - m = Moffat(fwhm=10) - expected = inv(1 + sum(abs2, SA[1, 2]) / 25) - @test m[2, 1] ≈ m(1, 2) ≈ expected - - m = Moffat(fwhm=(10, 9)) - wdist = (1/5)^2 + (2/4.5)^2 - expected = inv(1 + wdist) - @test m[2, 1] ≈ m(1, 2) ≈ expected - - # different alpha - m = Moffat(fwhm=10, alpha=2) - expected = inv(1 + sum(abs2, SA[1, 2]) / 25)^2 - @test m[2, 1] ≈ m(1, 2) ≈ expected + test_model_interface(Moffat) + + @testset "isotropic" begin + m = Moffat(fwhm=10) + expected = inv(1 + sum(abs2, SA[1, 2]) / 25) + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=1)" + end + + @testset "diagonal" begin + m = Moffat(fwhm=(10, 9)) + wdist = (1/5)^2 + (2/4.5)^2 + expected = inv(1 + wdist) + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=(10, 9), amp=1.0, alpha=1)" + end + + @testset "alpha" begin + m = Moffat(fwhm=10, alpha=2) + expected = inv(1 + sum(abs2, SA[1, 2]) / 25)^2 + @test m[2, 1] ≈ m(1, 2) ≈ expected + @test repr(m) == "Moffat{Float64}(pos=[0, 0], fwhm=10, amp=1.0, alpha=2)" + end end include("plotting.jl")