Skip to content

Commit ec4cd21

Browse files
fix issues with no tgrads
1 parent 23edc27 commit ec4cd21

File tree

4 files changed

+33
-10
lines changed

4 files changed

+33
-10
lines changed

src/ParameterizedFunctions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ module ParameterizedFunctions
1616
delete!(ENV,"symengine_jl_safe_failure")
1717
end
1818

19-
using DataStructures, DiffEqBase, SimpleTraits, LinearAlgebra
19+
using DataStructures, DiffEqBase, SimpleTraits
20+
21+
import LinearAlgebra
2022

2123
import Base: getindex
2224

src/macros.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ macro ode_def(name,ex,params...)
33
:build_tgrad => true,
44
:build_jac => true,
55
:build_expjac => false,
6-
:build_invjac => true,
6+
:build_invjac => false,
77
:build_invW => true,
88
:build_hes => false,
99
:build_invhes => false,
@@ -31,7 +31,7 @@ macro ode_def_nohes(name,ex,params...)
3131
:build_tgrad => true,
3232
:build_jac => true,
3333
:build_expjac => false,
34-
:build_invjac => true,
34+
:build_invjac => false,
3535
:build_invW => true,
3636
:build_hes => false,
3737
:build_invhes => false,
@@ -45,7 +45,7 @@ macro ode_def_noinvhes(name,ex,params...)
4545
:build_tgrad => true,
4646
:build_jac => true,
4747
:build_expjac => false,
48-
:build_invjac => true,
48+
:build_invjac => false,
4949
:build_invW => true,
5050
:build_hes => false,
5151
:build_invhes => false,

src/ode_def_opts.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M=I,depvar=:t)
1+
function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;depvar=:t)
22
# depvar is the dependent variable. Defaults to t
33
# M is the mass matrix in RosW, must be a constant!
4+
45
origex = copy(ex) # Save the original expression
56

67
if !(eltype(params) <: Symbol)
@@ -42,7 +43,7 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
4243

4344
numsyms = length(indvar_dict)
4445
numparams = length(params)
45-
M = Matrix(1*_M,numsyms,numsyms)
46+
M = Matrix(1*LinearAlgebra.I,numsyms,numsyms)
4647
# Parameter Functions
4748
paramfuncs = Vector{Vector{Expr}}(undef, numparams)
4849
for i in 1:numparams
@@ -52,8 +53,10 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
5253
end
5354
paramfuncs[i] = tmp_pfunc
5455
end
56+
5557
pfuncs = build_p_funcs(paramfuncs,indvar_dict,params)
5658

59+
5760
# Symbolic Setup
5861
symfuncs = Vector{SymEngine.Basic}(undef, 0)
5962
symtgrad = Vector{SymEngine.Basic}(undef, 0)
@@ -88,6 +91,7 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
8891
param_symjac = Matrix{SymEngine.Basic}(undef,numsyms,numparams)
8992
pderiv_exists = false
9093

94+
9195
if opts[:build_tgrad] || opts[:build_jac] || opts[:build_dpfuncs]
9296
try #do symbolic calculations
9397

@@ -105,6 +109,7 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
105109
tgradex = build_tgrad_func(symtgrad,indvar_dict,params)
106110
catch err
107111
@warn("Time Derivative Gradient could not be built")
112+
symtgrad = Vector{SymEngine.Basic}(undef, 0)
108113
end
109114
end
110115

@@ -118,11 +123,15 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
118123
end
119124
end
120125

126+
121127
# Build the Julia function
122128
Jex = build_jac_func(symjac,indvar_dict,params)
123129
bad_derivative(Jex)
124130
jac_exists = true
125131

132+
133+
134+
126135
if opts[:build_expjac]
127136
try
128137
expjac = exp*symjac) # This does not work, which is why disabled
@@ -145,11 +154,12 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
145154
@warn("Jacobian could not invert")
146155
end
147156
end
157+
148158
if opts[:build_invW]
149159
try # Rosenbrock-W Inverse
150-
L = Base.convert(SymEngine.CDenseMatrix,M - γ*symjac)
160+
L = Base.convert(SymEngine.CDenseMatrix,1LinearAlgebra.I - γ*symjac)
151161
syminvW = inv(L)
152-
L = Base.convert(SymEngine.CDenseMatrix,M/γ - symjac)
162+
L = Base.convert(SymEngine.CDenseMatrix,1LinearAlgebra.I/γ - symjac)
153163
syminvW_t = inv(L)
154164
invWex = build_jac_func(syminvW,indvar_dict,params)
155165
bad_derivative(invWex)
@@ -158,6 +168,7 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
158168
bad_derivative(invWex_t)
159169
invW_t_exists = true
160170
catch err
171+
throw(err)
161172
@warn("Rosenbrock-W could not invert")
162173
end
163174
end
@@ -167,6 +178,7 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
167178
for i in eachindex(funcs), j in eachindex(syms)
168179
symhes[i,j] = diff(symjac[i,j],syms[j])
169180
end
181+
170182
# Build the Julia function
171183
Hex = build_jac_func(symhes,indvar_dict,params)
172184
bad_derivative(Hex)
@@ -187,7 +199,9 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
187199
end
188200
end
189201
catch err
202+
throw(err)
190203
@warn("Failed to build the Jacobian. This means the Hessian is not built as well.")
204+
symjac = Matrix{SymEngine.Basic}(undef, 0,0)
191205
end
192206
end # End Jacobian tree
193207

@@ -275,7 +289,6 @@ function ode_def_opts(name::Symbol,opts::Dict{Symbol,Bool},ex::Expr,params...;_M
275289
else
276290
param_jac_expr = :(nothing)
277291
end
278-
279292
# Build the type
280293
exprs = Vector{Expr}(undef, 0)
281294

test/runtests.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ f_t3 = @ode_def_noinvjac ExprCheck begin # Checks for error due to symbol on 1
2323
dy = -c*y + d*x*y
2424
end a b c d # Change to π after unicode fix
2525

26-
f = @ode_def_noinvhes LotkaVolterra begin
26+
f = @ode_def LotkaVolterra begin
2727
dx = a*x - b*x*y
2828
dy = -c*y + d*x*y
2929
end a b c d
@@ -120,6 +120,14 @@ f = @ode_def begin
120120
end a b c d
121121
@test_nowarn f([0.1,0.2], [1,2], [1,2,3,4], 1)
122122

123+
# Test failures of derivatives should not have #undef
124+
sir_ode = @ode_def SIRModel begin
125+
dS = -b*S*I
126+
dI = b*S*I - g*I
127+
dR = g*I
128+
end b g
129+
130+
123131
println("Make the problems in the problem library build")
124132

125133
using DiffEqProblemLibrary

0 commit comments

Comments
 (0)