diff --git a/Project.toml b/Project.toml index b7fb79a23b..b54e3c5e3d 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -28,7 +27,6 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Adapt = "3" @@ -39,7 +37,6 @@ DiffEqBase = "6.41" Distributions = "0.23, 0.24, 0.25" DistributionsAD = "0.6" Flux = "0.12, 0.13" -ForwardDiff = "0.10" Functors = "0.4" LoggingExtras = "0.4, 1" Lux = "0.4" @@ -50,5 +47,4 @@ SciMLBase = "1" SciMLSensitivity = "7" TerminalLoggers = "0.1" Zygote = "0.5, 0.6" -ZygoteRules = "0.2" julia = "1.5" diff --git a/src/DiffEqFlux.jl b/src/DiffEqFlux.jl index 03ae08f191..9daab235ad 100644 --- a/src/DiffEqFlux.jl +++ b/src/DiffEqFlux.jl @@ -2,9 +2,9 @@ module DiffEqFlux using Adapt, Base.Iterators, ConsoleProgressMonitor, DataInterpolations, DiffEqBase, Distributions, DistributionsAD, - ForwardDiff, LinearAlgebra, Lux, + LinearAlgebra, Lux, Logging, LoggingExtras, Printf, ProgressLogging, Random, RecursiveArrayTools, - Reexport, SciMLBase, TerminalLoggers, Zygote, ZygoteRules + Reexport, SciMLBase, TerminalLoggers, Zygote @reexport using SciMLSensitivity @reexport using Flux @@ -14,24 +14,6 @@ import ChainRulesCore gpu_or_cpu(x) = Array -# ForwardDiff integration - -ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where {T} - @assert length(ẋ) == 1 - ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,)) -end - -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where {T} = - d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),) - -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where {T} = - d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),) - -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:dl}) = A.dl, y -> Tridiagonal(dl, zeros(length(d)), zeros(length(du)),) -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:d}) = A.d, y -> Tridiagonal(zeros(length(dl)), d, zeros(length(du)),) -ZygoteRules.@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:du}) = A.dl, y -> Tridiagonal(zeros(length(dl)), zeros(length(d), du),) -ZygoteRules.@adjoint Tridiagonal(dl, d, du) = Tridiagonal(dl, d, du), p̄ -> (diag(p̄[2:end, 1:end-1]), diag(p̄), diag(p̄[1:end-1, 2:end])) - include("ffjord.jl") include("neural_de.jl") include("spline_layer.jl")