@@ -36,14 +36,12 @@ Base.getindex(o::OneBox, i::Integer...) = o.val
36
36
#= ========= gradient hooks ==========#
37
37
# Macros like @adjoint need to be hidden behind include(), it seems:
38
38
39
- using Requires
40
-
41
39
# @init @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include("grad/zygote.jl")
42
40
43
- @init @require Tracker = " 9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include (" grad/tracker.jl" )
44
-
45
41
# @init @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("grad/reverse.jl")
46
42
43
+ # Provided via en extension on 1.9+
44
+ if ! isdefined (Base, :get_extension )
47
45
import ChainRulesCore
48
46
49
47
function ChainRulesCore. rrule (ev:: Eval , args... )
@@ -56,11 +54,18 @@ function ChainRulesCore.rrule(ev::Eval, args...)
56
54
tuple (ChainRulesCore. ZeroTangent (), dxs... )
57
55
end
58
56
end
57
+ end
59
58
60
- @init @require FillArrays = " 1a297f60-69ca-5386-bcde-b61e274b549b" begin
61
- using . FillArrays: Fill # used by Zygote
62
- Tullio. promote_storage (:: Type{T} , :: Type{F} ) where {T, F<: Fill } = T
63
- Tullio. promote_storage (:: Type{F} , :: Type{T} ) where {T, F<: Fill } = T
59
+ if ! isdefined (Base, :get_extension )
60
+ using Requires
61
+ end
62
+
63
+ @static if ! isdefined (Base, :get_extension )
64
+ function __init__ ()
65
+ @require Tracker = " 9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include (" ../ext/TullioTrackerExt.jl" )
66
+ @require FillArrays = " 1a297f60-69ca-5386-bcde-b61e274b549b" include (" ../ext/TullioFillArraysExt.jl" )
67
+ @require CUDA = " 052768ef-5323-5732-b1bb-66c8b64840ba" include (" ../ext/TullioCUDAExt.jl" )
68
+ end
64
69
end
65
70
66
71
#= ========= vectorised gradients ==========#
92
97
93
98
=#
94
99
95
- #= ========= CuArrays ==========#
96
-
97
- using Requires
98
-
99
- @init @require CUDA = " 052768ef-5323-5732-b1bb-66c8b64840ba" begin
100
- using . CUDA: CuArray, GPUArrays
101
-
102
- Tullio. threader (fun!:: F , :: Type{T} ,
103
- Z:: AbstractArray , As:: Tuple , Is:: Tuple , Js:: Tuple ,
104
- redfun, block= 0 , keep= nothing ) where {F<: Function , T<: CuArray } =
105
- fun! (T, Z, As... , Is... , Js... , keep)
106
-
107
- Tullio.∇threader (fun!:: F , :: Type{T} ,
108
- As:: Tuple , Is:: Tuple , Js:: Tuple , block= 0 ) where {F<: Function , T<: CuArray } =
109
- fun! (T, As... , Is... , Js... ,)
110
-
111
- # Tullio.thread_scalar ... ought to work? Was never fast.
112
- end
113
-
114
100
#= ========= storage unwrapper ==========#
115
101
116
102
"""
0 commit comments