Skip to content

Commit 5b88e12

Browse files
robustness
1 parent 262d19a commit 5b88e12

File tree

4 files changed

+29
-7
lines changed

4 files changed

+29
-7
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ SymEngine 0.1.2
33
DataStructures 0.4.6
44
DiffEqBase 0.14.0
55
SimpleTraits 0.1.1
6+
Iterators

src/ParameterizedFunctions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ module ParameterizedFunctions
1414
delete!(ENV,"symengine_jl_safe_failure")
1515
end
1616

17-
using DataStructures, DiffEqBase, SimpleTraits
17+
using DataStructures, DiffEqBase, SimpleTraits, Iterators
1818
import Base: getindex
1919

2020
const FEM_SYMBOL_DICT = Dict{Symbol,Expr}(:x=>:(x[:,1]),:y=>:(x[:,2]),:z=>:(x[:,3]))

src/dict_build.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
function build_indvar_dict(ex)
22
indvar_dict = OrderedDict{Symbol,Int}()
3+
cur_sym = 0
34
for i in 2:2:length(ex.args) #Every odd line is line number
45
arg = ex.args[i].args[1] #Get the first thing, should be dsomething
5-
nodarg = Symbol(string(arg)[2:end]) #Take off the d
6-
if !haskey(indvar_dict,nodarg)
7-
s = string(arg)
8-
indvar_dict[Symbol(string(arg)[2:end])] = i/2 # and label it the next int if not seen before
6+
firstarg = Symbol(first(string(arg))) # Check for d
7+
if firstarg == :d
8+
nodarg = Symbol(join(drop(string(arg), 1))) # join(drop(s, 1)) is 2:end
9+
if !haskey(indvar_dict,nodarg)
10+
cur_sym += 1
11+
indvar_dict[nodarg] = cur_sym
12+
else
13+
error("The derivative term for $nodarg is repeated. This is not allowed.")
14+
end
915
end
1016
end
1117
syms = indvar_dict.keys

test/runtests.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ f_t2 = @ode_def_noinvjac SymCheck2 begin # Checks for error due to symbol on 1
1818
dy = -c*y + d*x*y*t^2
1919
end a=>1.5 b=>1 c=3 d=1
2020

21+
f_t3 = @ode_def_noinvjac ExprCheck begin # Checks for error due to symbol on 1
22+
dx = a*x - b*x*y
23+
dy = -c*y + d*x*y
24+
end a=>1.5 b=>2.0 c=t*x d=2pi # Change to π after unicode fix
25+
2126
f = @ode_def_noinvhes LotkaVolterra begin
2227
dx = a*x - b*x*y
2328
dy = -c*y + d*x*y
@@ -123,16 +128,26 @@ println("Test booleans")
123128

124129
@code_llvm has_paramjac(f)
125130

131+
println("Test difficult differentiable")
132+
NJ = @ode_def_nohes DiffDiff begin
133+
dx = a*x - b*x*y
134+
dy = -c*y + erf(x*y/d)
135+
end a=>1.5 b=>1 c=3 d=4
136+
NJ(t,u,du)
137+
@test du == [-3.0;-3*3.0 + erf(2.0*3.0/4)]
138+
@test du == NJ(t,u)
139+
# NJ(Val{:jac},t,u,J) # Currently gives E not defined, will be fixed by the next SymEgine
140+
126141
test_fail(x,y,d) = erf(x*y/d)
127142
println("Test non-differentiable")
128143
NJ = @ode_def NoJacTest begin
129144
dx = a*x - b*x*y
130-
dy = -c*y + erf(x*y/d)
145+
dy = -c*y + test_fail(x,y,d)
131146
end a=>1.5 b=>1 c=3 d=4
132147
NJ(t,u,du)
133148
@test du == [-3.0;-3*3.0 + erf(2.0*3.0/4)]
134149
@test du == NJ(t,u)
135-
150+
# NJ(Val{:jac},t,u,J) # Currently gives E not defined, will be fixed by the next SymEgine
136151
### FEM Macros
137152

138153
println("Test FEM")

0 commit comments

Comments
 (0)