@@ -102,32 +102,32 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
102
102
u0 = nothing ;
103
103
version = nothing , tgrad= false , sparse = false ,
104
104
jac = false , Wfact = false , kwargs... ) where {iip}
105
- f_oop,f_iip = generate_function (sys, dvs, ps; expression= Val{false }, kwargs... )
106
- g_oop,g_iip = generate_diffusion_function (sys, dvs, ps; expression= Val{false }, kwargs... )
105
+ f_oop,f_iip = ModelingToolkit . eval .( generate_function (sys, dvs, ps; expression= Val{true }, kwargs... ) )
106
+ g_oop,g_iip = ModelingToolkit . eval .( generate_diffusion_function (sys, dvs, ps; expression= Val{true }, kwargs... ) )
107
107
108
108
f (u,p,t) = f_oop (u,p,t)
109
109
f (du,u,p,t) = f_iip (du,u,p,t)
110
110
g (u,p,t) = g_oop (u,p,t)
111
111
g (du,u,p,t) = g_iip (du,u,p,t)
112
112
113
113
if tgrad
114
- tgrad_oop,tgrad_iip = generate_tgrad (sys, dvs, ps; expression= Val{false }, kwargs... )
114
+ tgrad_oop,tgrad_iip = ModelingToolkit . eval .( generate_tgrad (sys, dvs, ps; expression= Val{true }, kwargs... ) )
115
115
_tgrad (u,p,t) = tgrad_oop (u,p,t)
116
116
_tgrad (J,u,p,t) = tgrad_iip (J,u,p,t)
117
117
else
118
118
_tgrad = nothing
119
119
end
120
120
121
121
if jac
122
- jac_oop,jac_iip = generate_jacobian (sys, dvs, ps; expression= Val{false }, sparse= sparse, kwargs... )
122
+ jac_oop,jac_iip = ModelingToolkit . eval .( generate_jacobian (sys, dvs, ps; expression= Val{true }, sparse= sparse, kwargs... ) )
123
123
_jac (u,p,t) = jac_oop (u,p,t)
124
124
_jac (J,u,p,t) = jac_iip (J,u,p,t)
125
125
else
126
126
_jac = nothing
127
127
end
128
128
129
129
if Wfact
130
- tmp_Wfact,tmp_Wfact_t = generate_factorized_W (sys, dvs, ps, true ; expression= Val{false }, kwargs... )
130
+ tmp_Wfact,tmp_Wfact_t = ModelingToolkit . eval .( generate_factorized_W (sys, dvs, ps, true ; expression= Val{true }, kwargs... ) )
131
131
Wfact_oop, Wfact_iip = tmp_Wfact
132
132
Wfact_oop_t, Wfact_iip_t = tmp_Wfact_t
133
133
_Wfact (u,p,dtgamma,t) = Wfact_oop (u,p,dtgamma,t)
@@ -141,10 +141,11 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.
141
141
M = calculate_massmatrix (sys)
142
142
_M = (u0 === nothing || M == I) ? M : ArrayInterface. restructure (u0 .* u0' ,M)
143
143
144
- SDEFunction {iip} (f,g,jac= _jac,
145
- tgrad = _tgrad,
146
- Wfact = _Wfact,
147
- Wfact_t = _Wfact_t,
144
+ SDEFunction {iip} (DiffEqBase. EvalFunc (f),DiffEqBase. EvalFunc (g),
145
+ jac = _jac === nothing ? nothing : DiffEqBase. EvalFunc (_jac),
146
+ tgrad = _tgrad === nothing ? nothing : DiffEqBase. EvalFunc (_tgrad),
147
+ Wfact = _Wfact === nothing ? nothing : DiffEqBase. EvalFunc (_Wfact),
148
+ Wfact_t = _Wfact_t === nothing ? nothing : DiffEqBase. EvalFunc (_Wfact_t),
148
149
mass_matrix = _M,
149
150
syms = Symbol .(sys. states))
150
151
end
0 commit comments