Skip to content

Commit 236a0fc

Browse files
fix function wrapping
1 parent 004d95f commit 236a0fc

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/functionwrapper.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,22 @@ end
88
(g::SDEDiffusionTermWrapper{true})(du, u, p, t) = g.g(du, u, g.h, p, t)
99
(g::SDEDiffusionTermWrapper{false})(u, p, t) = g.g(u, g.h, p, t)
1010

11-
struct SDEFunctionWrapper{iip,F,G,H,TMM,Ta,Tt,TJ,JP,TW,TWt,TPJ,S,TCV} <: DiffEqBase.AbstractRODEFunction{iip}
11+
struct SDEFunctionWrapper{iip,F,G,H,TMM,Ta,Tt,TJ,JVP,VJP,JP,SP,TW,TWt,TPJ,GG,S,TCV} <: DiffEqBase.AbstractRODEFunction{iip}
1212
f::F
1313
g::G
1414
h::H
1515
mass_matrix::TMM
1616
analytic::Ta
1717
tgrad::Tt
1818
jac::TJ
19+
jvp::JVP
20+
vjp::VJP
1921
jac_prototype::JP
22+
sparsity::SP
2023
Wfact::TW
2124
Wfact_t::TWt
2225
paramjac::TPJ
26+
ggprime::GG
2327
syms::S
2428
colorvec::TCV
2529
end
@@ -30,7 +34,7 @@ end
3034

3135
function wrap_functions_and_history(f::SDDEFunction, g, h)
3236
gwh = SDEDiffusionTermWrapper{isinplace(g,5),typeof(g),typeof(h)}(g,h)
33-
37+
3438
if f.jac === nothing
3539
jac = nothing
3640
else
@@ -46,10 +50,12 @@ function wrap_functions_and_history(f::SDDEFunction, g, h)
4650
end
4751

4852
SDEFunctionWrapper{isinplace(f),typeof(f.f),typeof(gwh),typeof(h),typeof(f.mass_matrix),
49-
typeof(f.analytic),typeof(f.tgrad),typeof(jac),
50-
typeof(f.jac_prototype),typeof(f.Wfact),typeof(f.Wfact_t),
51-
typeof(f.paramjac),typeof(f.syms),typeof(f.colorvec)}(
53+
typeof(f.analytic),typeof(f.tgrad),typeof(jac),typeof(f.jvp),
54+
typeof(f.vjp),typeof(f.jac_prototype),typeof(f.sparsity),
55+
typeof(f.Wfact),typeof(f.Wfact_t),typeof(f.paramjac),
56+
typeof(f.ggprime),typeof(f.syms),typeof(f.colorvec)}(
5257
f.f, gwh, h, f.mass_matrix, f.analytic, f.tgrad, jac,
53-
f.jac_prototype, f.Wfact, f.Wfact_t, f.paramjac, f.syms,
54-
f.colorvec), gwh
55-
end
58+
f.jvp, f.vjp,
59+
f.jac_prototype, f.sparsity, f.Wfact, f.Wfact_t,
60+
f.paramjac, f.ggprime, f.syms, f.colorvec), gwh
61+
end

0 commit comments

Comments
 (0)