Skip to content

Commit 8e2db5e

Browse files
committed
write up gradients for Gaussian PSF
1 parent eb7de41 commit 8e2db5e

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

src/PSFModels.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,12 @@ plot(model, axes(other)) # use axes from other array
111111
"""
112112
module PSFModels
113113

114+
using ChainRulesCore
115+
import ChainRulesCore: frule, rrule
114116
using CoordinateTransformations
115117
using Distances
116118
using KeywordCalls
119+
using LinearAlgebra
117120
using SpecialFunctions
118121
using StaticArrays
119122

src/gaussian.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,51 @@ function (g::Gaussian{T,<:Union{Tuple,AbstractVector}})(point::AbstractVector) w
6161
val = g.amp * exp(GAUSS_PRE * Δ)
6262
return convert(T, val)
6363
end
64+
65+
## gradients
66+
67+
# isotropic
68+
function fgrad(g::Gaussian, point::AbstractVector)
69+
f = g(point)
70+
71+
xdiff = first(point) - first(g.pos)
72+
ydiff = last(point) - last(g.pos)
73+
dfdpos = -2 * GAUSS_PRE * f / g.fwhm^2 .* SA[xdiff, ydiff]
74+
dfdfwhm = -2 * GAUSS_PRE * f * (xdiff^2 + ydiff^2) / g.fwhm^3
75+
dfdamp = f / g.amp
76+
return f, dfdpos, dfdfwhm, dfdamp
77+
end
78+
79+
# short printing
80+
Base.show(io::IO, g::Gaussian{T}) where {T} = print(io, "Gaussian{$T}(pos=$(g.pos), fwhm=$(g.fwhm), amp=$(g.amp))")
81+
82+
# diagonal
83+
function fgrad(g::Gaussian{T,<:Union{Tuple,AbstractVector}}, point::AbstractVector) where T
84+
f = g(point)
85+
86+
xdiff = first(point) - first(g.pos)
87+
ydiff = last(point) - last(g.pos)
88+
dfdpos = -2 * GAUSS_PRE * f .* SA[xdiff / first(g.fwhm)^2, ydiff / last(g.fwhm)^2]
89+
dfdfwhm = -2 * GAUSS_PRE * f .* SA[xdiff^2 / first(g.fwhm)^3, ydiff^2 / last(g.fwhm)^3]
90+
dfda = f / g.amp
91+
return f, dfdpos, dfdfwhm, dfda
92+
end
93+
94+
function frule((Δpsf, Δp), g::Gaussian, point::AbstractVector)
95+
f, dfdpos, dfdfwhm, dfda = fgrad(g, point)
96+
Δf = dot(dfdpos, Δpsf.pos) + dot(dfdfwhm, Δpsf.fwhm) + dfda * Δpsf.amp
97+
Δf -= dot(dfdpos, Δp)
98+
return f, Δf
99+
end
100+
101+
function rrule(g::G, point::AbstractVector) where {G<:Gaussian}
102+
f, dfdpos, dfdfwhm, dfda = fgrad(g, point)
103+
function Gaussian_pullback(Δf)
104+
∂pos = dfdpos .* Δf
105+
∂fwhm = dfdfwhm .* Δf
106+
∂g = Tangent{G}(pos=∂pos, fwhm=∂fwhm, amp=dfda * Δf, indices=ZeroTangent())
107+
∂pos = dfdpos .* -Δf
108+
return ∂g, ∂pos
109+
end
110+
return f, Gaussian_pullback
111+
end

0 commit comments

Comments
 (0)