Skip to content

Commit 68b24de

Browse files
author
ziyiyin97
committed
switch from Requires.jl to an extension
1 parent e1f2536 commit 68b24de

File tree

3 files changed

+20
-10
lines changed

3 files changed

+20
-10
lines changed

Project.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,24 @@ version = "1.1.3"
44

55
[deps]
66
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
7+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
78
Jets = "2a57b368-ab28-5ba9-84aa-637fe7991822"
89
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9-
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1111
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1212

13+
[weakdeps]
14+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
15+
16+
[extensions]
17+
JetPackFluxExt = "Flux"
18+
1319
[compat]
1420
FFTW = "1"
21+
Flux = "0.12, 0.13"
1522
Jets = "^1.2"
16-
Requires = "1"
1723
SpecialFunctions = "1, 2"
1824
julia = "1"
25+
26+
[extras]
27+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"

src/jop_ad.jl renamed to ext/JetPackFluxExt.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
module JetPackFluxExt
2+
3+
using JetPack, Jets, Flux
14
"""
25
F = JopAD(f, dom, rng)
36
@@ -9,9 +12,10 @@ in both forward-mode and reverse-mode via the function `pushforward` and
912
function `f`, it is also the responsibility of the user to ensure that these
1013
rules are tested (e.g. linearity, dot product, and linearization tests).
1114
"""
12-
JopAD(f::Function; dom, rng) = JopNl(dom = dom, rng = rng, f! = JopAD_f!, df! = JopAD_df!, df′! = JopAD_df′!, s = (f=f,))
13-
export JopAD
15+
JetPack.JopAD(f::Function; dom, rng) = JopNl(dom = dom, rng = rng, f! = JopAD_f!, df! = JopAD_df!, df′! = JopAD_df′!, s = (f=f,))
1416

1517
JopAD_f!(d::AbstractArray, m::AbstractArray; f::Function) = d .= f(m)
1618
JopAD_df!(δd::AbstractArray, δm::AbstractArray; mₒ, f::Function) = δd .= Flux.pushforward(f, mₒ)(δm)
17-
JopAD_df′!(δm::AbstractArray, δd::AbstractArray; mₒ, f::Function) = δm .= Flux.pullback(f, mₒ)[2](δd)[1]
19+
JopAD_df′!(δm::AbstractArray, δd::AbstractArray; mₒ, f::Function) = δm .= Flux.pullback(f, mₒ)[2](δd)[1]
20+
21+
end

src/JetPack.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using Jets
55
using LinearAlgebra
66
using SpecialFunctions
77
using Statistics
8-
using Requires
98

109
include("jop_atan.jl")
1110
include("jop_blend.jl")
@@ -43,9 +42,7 @@ include("jop_taper.jl")
4342
include("jop_tanh.jl")
4443
include("jop_translation.jl")
4544

46-
### If Flux is loaded, users can wrap an auto-differentiable function as a Jet operator
47-
function __init__()
48-
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" @eval include("jop_ad.jl")
49-
end
45+
function JopAD end;
46+
export JopAD
5047

5148
end

0 commit comments

Comments
 (0)