Skip to content

Commit 9143f02

Browse files
authored
use an extension instead of Requires on 1.9+ (#170)
* move required code to `__init__` function * move require code to separate files * use an extension instead of Requires on 1.9+ * make ChainRulesCore usage into an extension on 1.9+
1 parent b3726a9 commit 9143f02

File tree

7 files changed

+96
-41
lines changed

7 files changed

+96
-41
lines changed

Project.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1111

12+
[weakdeps]
13+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
14+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
15+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
16+
17+
[extensions]
18+
TullioCUDAExt = "CUDA"
19+
TullioFillArraysExt = "FillArrays"
20+
TullioTrackerExt = "Tracker"
21+
1222
[compat]
1323
CUDA = "3.6"
1424
CUDAKernels = "0.3.3, 0.4"

ext/TullioCUDAExt.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
module TullioCUDAExt
2+
3+
if !isdefined(Base, :get_extension)
4+
using ..Tullio, ..CUDA
5+
else
6+
using Tullio, CUDA
7+
end
8+
9+
Tullio.threader(fun!::F, ::Type{T},
10+
Z::AbstractArray, As::Tuple, Is::Tuple, Js::Tuple,
11+
redfun, block=0, keep=nothing) where {F<:Function, T<:CUDA.CuArray} =
12+
fun!(T, Z, As..., Is..., Js..., keep)
13+
14+
Tullio.∇threader(fun!::F, ::Type{T},
15+
As::Tuple, Is::Tuple, Js::Tuple, block=0) where {F<:Function, T<:CUDA.CuArray} =
16+
fun!(T, As..., Is..., Js...,)
17+
18+
# Tullio.thread_scalar ... ought to work? Was never fast.
19+
20+
end

ext/TullioChainRulesCoreExt.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
module TullioChainRulesCoreExt
2+
3+
if !isdefined(Base, :get_extension)
4+
using ..Tullio, ..ChainRulesCore
5+
else
6+
using Tullio, ChainRulesCore
7+
end
8+
9+
function ChainRulesCore.rrule(ev::Tullio.Eval, args...)
10+
Z = ev.fwd(args...)
11+
Z, function tullio_back(Δ)
12+
isnothing(ev.rev) && error("no gradient definition here!")
13+
dxs = map(ev.rev(Δ, Z, args...)) do dx
14+
dx === nothing ? ChainRulesCore.ZeroTangent() : dx
15+
end
16+
tuple(ChainRulesCore.ZeroTangent(), dxs...)
17+
end
18+
end
19+
20+
end

ext/TullioFillArraysExt.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module TullioFillArraysExt
2+
3+
if !isdefined(Base, :get_extension)
4+
using ..Tullio, ..FillArrays
5+
else
6+
using Tullio, FillArrays
7+
end
8+
9+
Tullio.promote_storage(::Type{T}, ::Type{F}) where {T, F<:FillArrays.Fill} = T
10+
Tullio.promote_storage(::Type{F}, ::Type{T}) where {T, F<:FillArrays.Fill} = T
11+
12+
end

ext/TullioTrackerExt.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
module TullioTrackerExt
2+
3+
if !isdefined(Base, :get_extension)
4+
using ..Tullio, ..Tracker
5+
else
6+
using Tullio, Tracker
7+
end
8+
9+
(ev::Tullio.Eval)(A::Tracker.TrackedArray, args...) = Tracker.track(ev, A, args...)
10+
(ev::Tullio.Eval)(A, B::Tracker.TrackedArray, args...) = Tracker.track(ev, A, B, args...)
11+
(ev::Tullio.Eval)(A::Tracker.TrackedArray, B::Tracker.TrackedArray, args...) = Tracker.track(ev, A, B, args...)
12+
13+
Tracker.@grad function (ev::Tullio.Eval)(args...)
14+
Z = ev.fwd(Tracker.data.(args)...)
15+
Z, Δ -> begin
16+
isnothing(ev.rev) && error("no gradient definition here!")
17+
tuple(ev.rev(Tracker.data(Δ), Z, Tracker.data.(args)...)...)
18+
end
19+
end
20+
21+
end

src/eval.jl

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,12 @@ Base.getindex(o::OneBox, i::Integer...) = o.val
3636
#========== gradient hooks ==========#
3737
# Macros like @adjoint need to be hidden behind include(), it seems:
3838

39-
using Requires
40-
4139
# @init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("grad/zygote.jl")
4240

43-
@init @require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("grad/tracker.jl")
44-
4541
# @init @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("grad/reverse.jl")
4642

43+
# Provided via en extension on 1.9+
44+
if !isdefined(Base, :get_extension)
4745
import ChainRulesCore
4846

4947
function ChainRulesCore.rrule(ev::Eval, args...)
@@ -56,11 +54,18 @@ function ChainRulesCore.rrule(ev::Eval, args...)
5654
tuple(ChainRulesCore.ZeroTangent(), dxs...)
5755
end
5856
end
57+
end
5958

60-
@init @require FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" begin
61-
using .FillArrays: Fill # used by Zygote
62-
Tullio.promote_storage(::Type{T}, ::Type{F}) where {T, F<:Fill} = T
63-
Tullio.promote_storage(::Type{F}, ::Type{T}) where {T, F<:Fill} = T
59+
if !isdefined(Base, :get_extension)
60+
using Requires
61+
end
62+
63+
@static if !isdefined(Base, :get_extension)
64+
function __init__()
65+
@require Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/TullioTrackerExt.jl")
66+
@require FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" include("../ext/TullioFillArraysExt.jl")
67+
@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" include("../ext/TullioCUDAExt.jl")
68+
end
6469
end
6570

6671
#========== vectorised gradients ==========#
@@ -92,25 +97,6 @@ end
9297
9398
=#
9499

95-
#========== CuArrays ==========#
96-
97-
using Requires
98-
99-
@init @require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin
100-
using .CUDA: CuArray, GPUArrays
101-
102-
Tullio.threader(fun!::F, ::Type{T},
103-
Z::AbstractArray, As::Tuple, Is::Tuple, Js::Tuple,
104-
redfun, block=0, keep=nothing) where {F<:Function, T<:CuArray} =
105-
fun!(T, Z, As..., Is..., Js..., keep)
106-
107-
Tullio.∇threader(fun!::F, ::Type{T},
108-
As::Tuple, Is::Tuple, Js::Tuple, block=0) where {F<:Function, T<:CuArray} =
109-
fun!(T, As..., Is..., Js...,)
110-
111-
# Tullio.thread_scalar ... ought to work? Was never fast.
112-
end
113-
114100
#========== storage unwrapper ==========#
115101

116102
"""

src/grad/tracker.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

0 commit comments

Comments
 (0)