Skip to content

Commit 4e957ad

Browse files
committed
refactor: move ForwardDiff to an ext
1 parent 4cea75c commit 4e957ad

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1919
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2020
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
2121
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
22+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2223
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2324

2425
[extensions]
@@ -27,6 +28,7 @@ NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
2728
NNlibCUDAExt = "CUDA"
2829
NNlibEnzymeCoreExt = "EnzymeCore"
2930
NNlibFFTWExt = "FFTW"
31+
NNlibForwardDiffExt = "ForwardDiff"
3032

3133
[compat]
3234
AMDGPU = "0.9.4, 1"
@@ -36,6 +38,7 @@ CUDA = "4, 5"
3638
ChainRulesCore = "1.13"
3739
EnzymeCore = "0.5, 0.6, 0.7"
3840
FFTW = "1.8.0"
41+
ForwardDiff = "0.10.36"
3942
GPUArraysCore = "0.1"
4043
KernelAbstractions = "0.9.2"
4144
LinearAlgebra = "<0.0.1, 1"

ext/NNlibForwardDiffExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module NNlibForwardDiffExt
2+
3+
using ForwardDiff: ForwardDiff
4+
using NNlib: NNlib
5+
6+
NNlib.within_gradient(x::ForwardDiff.Dual) = true
7+
NNlib.within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
8+
9+
end

src/NNlib.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,6 @@ export upsample_nearest, ∇upsample_nearest,
8282
include("gather.jl")
8383
include("scatter.jl")
8484
include("utils.jl")
85-
@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
86-
using .ForwardDiff
87-
within_gradient(x::ForwardDiff.Dual) = true
88-
within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
89-
end
9085

9186
include("sampling.jl")
9287
include("functions.jl")

0 commit comments

Comments
 (0)