Skip to content

Commit 89e47eb

Browse files
Merge pull request #531 from vpuri3/ad
overload _alg_autodiff
2 parents 5dc34a5 + e7ffcc5 commit 89e47eb

File tree

5 files changed

+11
-9
lines changed

5 files changed

+11
-9
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
2525
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2626
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2727
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
28+
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
2829
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2930
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
3031
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -44,7 +45,7 @@ JumpProcesses = "9"
4445
LevyArea = "1.0.0"
4546
MuladdMacro = "0.2.1"
4647
NLsolve = "4"
47-
OrdinaryDiffEq = "6.4"
48+
OrdinaryDiffEq = "6.52"
4849
RandomNumbers = "1.5.3"
4950
RecursiveArrayTools = "2"
5051
Reexport = "0.2, 1.0"

src/StochasticDiffEq.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ using DocStringExtensions
2424
import DiffEqBase: ODE_DEFAULT_NORM, ODE_DEFAULT_ISOUTOFDOMAIN,
2525
ODE_DEFAULT_PROG_MESSAGE, ODE_DEFAULT_UNSTABLE_CHECK
2626

27-
using DiffEqBase: DiffEqArrayOperator
27+
using SciMLOperators: MatrixOperator
2828

2929
using DiffEqBase: TimeGradientWrapper, UJacobianWrapper, TimeDerivativeWrapper, UDerivativeWrapper
3030

src/alg_utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,10 @@ alg_needs_extra_process(alg::PL1WM) = true
260260
alg_needs_extra_process(alg::NON) = true
261261
alg_needs_extra_process(alg::NON2) = true
262262

263-
OrdinaryDiffEq.alg_autodiff(alg::StochasticDiffEqNewtonAlgorithm{CS,AD,FDT,ST,CJ,Controller}) where {CS,AD,FDT,ST,CJ,Controller} = AD
264-
OrdinaryDiffEq.alg_autodiff(alg::StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,CJ,Controller}) where {CS,AD,FDT,ST,CJ,Controller} = AD
265-
OrdinaryDiffEq.alg_autodiff(alg::StochasticDiffEqJumpNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,CJ,Controller}) where {CS,AD,FDT,ST,CJ,Controller} = AD
266-
OrdinaryDiffEq.alg_autodiff(alg::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{CS,AD,FDT,ST,CJ,Controller}) where {CS,AD,FDT,ST,CJ,Controller} = AD
263+
OrdinaryDiffEq._alg_autodiff(alg::StochasticDiffEqNewtonAlgorithm{CS,AD,FDT,ST,CJ,Controller}) where {CS,AD,FDT,ST,CJ,Controller} = Val{AD}()
264+
OrdinaryDiffEq._alg_autodiff(alg::StochasticDiffEqNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,CJ,Controller}) where {CS,AD,FDT,ST,CJ,Controller} = Val{AD}()
265+
OrdinaryDiffEq._alg_autodiff(alg::StochasticDiffEqJumpNewtonAdaptiveAlgorithm{CS,AD,FDT,ST,CJ,Controller}) where {CS,AD,FDT,ST,CJ,Controller} = Val{AD}()
266+
OrdinaryDiffEq._alg_autodiff(alg::StochasticDiffEqJumpNewtonDiffusionAdaptiveAlgorithm{CS,AD,FDT,ST,CJ,Controller}) where {CS,AD,FDT,ST,CJ,Controller} = Val{AD}()
267267

268268
OrdinaryDiffEq.get_current_alg_autodiff(alg::StochasticDiffEqCompositeAlgorithm, cache) = OrdinaryDiffEq.alg_autodiff(alg.algs[cache.current])
269269

src/integrators/integrator_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ function oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
418418
# get the operator
419419
J = islin ? nf.f : f.jac(uprev, p, t)
420420
if !isa(J, DiffEqBase.AbstractDiffEqLinearOperator)
421-
J = DiffEqArrayOperator(J)
421+
J = MatrixOperator(J)
422422
end
423423
W = WOperator{false}(f.mass_matrix, dt, J, u)
424424
else

test/utility_tests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using StochasticDiffEq, LinearAlgebra, SparseArrays, Random, LinearSolve, Test
22
using StochasticDiffEq.OrdinaryDiffEq: WOperator, set_gamma!, calc_W!
3+
using StochasticDiffEq.SciMLOperators: MatrixOperator
34

45
@testset "Derivative Utilities" begin
56
@testset "calc_W!" begin
@@ -25,7 +26,7 @@ using StochasticDiffEq.OrdinaryDiffEq: WOperator, set_gamma!, calc_W!
2526
_f = (du,u,p,t) -> mul!(du,A,u); _g = (du,u,p,t) -> mul!(du,σ,u)
2627
fun = SDEFunction(_f, _g;
2728
mass_matrix=mm,
28-
jac_prototype=DiffEqArrayOperator(A))
29+
jac_prototype=MatrixOperator(A))
2930
prob = SDEProblem(fun, _g, u0, tspan)
3031
integrator = init(prob, ImplicitEM(theta=1); adaptive=false, dt=dt)
3132
W = integrator.cache.nlsolver.cache.W
@@ -52,7 +53,7 @@ using StochasticDiffEq.OrdinaryDiffEq: WOperator, set_gamma!, calc_W!
5253
prob1 = SDEProblem(SDEFunction(_f, _g; mass_matrix=mm), _g, u0, tspan)
5354
prob2 = SDEProblem(SDEFunction(_f, _g; mass_matrix=mm, jac=(u,p,t) -> t*A), _g, u0, tspan)
5455
prob1_ip = SDEProblem(SDEFunction(_f_ip, _g_ip; mass_matrix=mm), _g_ip, u0, tspan)
55-
jac_prototype=DiffEqArrayOperator(similar(A); update_func=(J,u,p,t) -> (J .= t .* A; J))
56+
jac_prototype=MatrixOperator(similar(A); update_func! = (J,u,p,t) -> (J .= t .* A; J))
5657
prob2_ip = SDEProblem(SDEFunction(_f_ip, _g_ip; mass_matrix=mm, jac_prototype=jac_prototype), _g_ip, u0, tspan)
5758

5859
for Alg in [ImplicitEM, ISSEM]

0 commit comments

Comments
 (0)