Skip to content

Commit e152bce

Browse files
authored
Merge pull request #31 from ziyiyin97/ad
wrap a function as a jet operator
2 parents 4ee37fc + 8a53751 commit e152bce

File tree

7 files changed

+70
-1
lines changed

7 files changed

+70
-1
lines changed

Project.toml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
name = "JetPack"
22
uuid = "24ef3835-3876-54c3-8a7a-956cf69ca0b2"
3-
version = "1.1.2"
3+
version = "1.2.0"
44

55
[deps]
66
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
7+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
78
Jets = "2a57b368-ab28-5ba9-84aa-637fe7991822"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1011
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1112

13+
[weakdeps]
14+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
15+
16+
[extensions]
17+
JetPackFluxExt = "Flux"
18+
1219
[compat]
1320
FFTW = "1"
21+
Flux = "0.12, 0.13"
1422
Jets = "^1.2"
1523
SpecialFunctions = "1, 2"
1624
julia = "1"
25+
26+
[extras]
27+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ These operators include the following functionality:
1313
* Blending and projection operators
1414
* Operators implementing transcendental functions like logarithm and exponentiation, with linearization
1515
* Special functions as used in some full waveform inversion approaches, including *Normalized Integral Method* and linear non-analytic operators including taking real and imaginary parts
16+
* Wrapping an auto-differentiable Julia function as a Jet operator
1617

1718
[docs-dev-img]: https://img.shields.io/badge/docs-dev-blue.svg
1819
[docs-dev-url]: https://chevronetc.github.io/JetPack.jl/dev/

ext/JetPackFluxExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module JetPackFluxExt
2+
3+
using JetPack, Jets, Flux
4+
"""
5+
F = JopAD(f, dom, rng)
6+
7+
where `F` wraps a Julia native function `f` as a Jet operator with domain
8+
and range given by `dom` and `rng`, respectively. Notice that it is the
9+
responsibility of the user to ensure that function `f` is auto-differentiable
10+
in both forward-mode and reverse-mode via the function `pushforward` and
11+
`pullback` in Flux. If the user chooses to use customized AD rules for this
12+
function `f`, it is also the responsibility of the user to ensure that these
13+
rules are tested (e.g. linearity, dot product, and linearization tests).
14+
"""
15+
JetPack.JopAD(f::Function; dom, rng) = JopNl(dom = dom, rng = rng, f! = JopAD_f!, df! = JopAD_df!, df′! = JopAD_df′!, s = (f=f,))
16+
17+
JopAD_f!(d::AbstractArray, m::AbstractArray; f::Function) = d .= f(m)
18+
JopAD_df!(δd::AbstractArray, δm::AbstractArray; mₒ, f::Function) = δd .= Flux.pushforward(f, mₒ)(δm)
19+
JopAD_df′!(δm::AbstractArray, δd::AbstractArray; mₒ, f::Function) = δm .= Flux.pullback(f, mₒ)[2](δd)[1]
20+
21+
end

src/JetPack.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,14 @@ include("jop_taper.jl")
4242
include("jop_tanh.jl")
4343
include("jop_translation.jl")
4444

45+
46+
###### JetPack with Flux extension
47+
48+
function JopAD end;
49+
export JopAD
50+
51+
if !isdefined(Base, :get_extension)
52+
include("../ext/JetPackFluxExt.jl")
53+
end
54+
4555
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
3+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
34
Jets = "2a57b368-ab28-5ba9-84aa-637fe7991822"
45
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
56
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

test/jop_ad.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using LinearAlgebra, Jets, JetPack, Test, Flux
2+
3+
spc = JetSpace(Float64, 10, 10)
4+
5+
### set up a nonlinear operator according to the function `f`
6+
f(x) = x ^ 2
7+
F = JopAD(f; dom=spc, rng=spc)
8+
9+
@testset "JopAD: linearity, dot product, and linearization tests" begin
10+
11+
x = rand(domain(F))
12+
J = jacobian(F, x)
13+
lhs, rhs = dot_product_test(J, rand(domain(J)), rand(range(J)))
14+
@test lhs rhs
15+
lhs,rhs = linearity_test(J)
16+
@test lhs rhs
17+
m0 = rand(domain(F))
18+
μ = sqrt.([1/1,1/2,1/4,1/8,1/16,1/32,1/64,1/128,1/256,1/512,1/1024,1/2048,1/4096,1/8192])
19+
δm = rand(domain(F))
20+
observed, expected = linearization_test(F, m0, μ = μ, δm = δm)
21+
δ = minimum(abs, observed - expected)
22+
@test δ < 1e-6
23+
24+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Random.seed!(101)
44

55
for filename in (
66
"jop_atan.jl",
7+
"jop_ad.jl",
78
"jop_blend.jl",
89
"jop_circshift.jl",
910
"jop_derivative.jl",

0 commit comments

Comments
 (0)