Skip to content

Commit 6061f03

Browse files
Merge pull request #77 from JuliaDiffEq/hg/fix/convert
Replace Expr and parse overloads with convert
2 parents 829e2e4 + b80f051 commit 6061f03

File tree

9 files changed

+45
-54
lines changed

9 files changed

+45
-54
lines changed

src/differentials.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ end
55
Differential(x) = Differential(x,1)
66

77
Base.show(io::IO, D::Differential) = print(io,"($(D.x),$(D.order))")
8-
Base.Expr(D::Differential) = D
8+
Base.convert(::Type{Expr}, D::Differential) = D
99

1010
function Derivative end
1111
(D::Differential)(x::Operation) = Operation(D, Expression[x])
@@ -48,7 +48,7 @@ for (modu, fun, arity) ∈ DiffRules.diffrules()
4848
M, f = $(modu, fun)
4949
partials = DiffRules.diffrule(M, f, args...)
5050
dx = @static $arity == 1 ? partials : partials[$i]
51-
parse(Operation,dx)
51+
convert(Expression, dx)
5252
end
5353
end
5454
end

src/function_registration.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,3 @@
1-
# Literals treated as constants
2-
function Base.convert(::Type{Expression}, n::Number)
3-
if !(typeof(n) <: Expression)
4-
return Constant(n)
5-
else
6-
return n
7-
end
8-
end
9-
101
# Register functions and handle literals
112
macro register(sig)
123
splitsig = splitdef(:($sig = nothing))

src/operations.jl

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,10 @@ Base.:(==)(x::Nothing,y::Operation) = false
1515
Base.:(==)(x::Variable,y::Operation) = false
1616
Base.:(==)(x::Operation,y::Variable) = false
1717

18-
function Base.Expr(O::Operation)
19-
Expr(:call, Symbol(O.op), Expr.(O.args)...)
20-
end
21-
22-
Base.show(io::IO,O::Operation) = print(io,string(Expr(O)))
18+
Base.convert(::Type{Expr}, O::Operation) =
19+
build_expr(:call, Any[Symbol(O.op); convert.(Expr, O.args)])
20+
Base.show(io::IO, O::Operation) = print(io, convert(Expr, O))
2321

24-
# Bigger printing
25-
# Is there a way to just not have this as the default?
26-
function Base.parse(::Type{Operation},ex::Expr)
27-
f = ex.args[1]
28-
operands = ex.args[2:end]
29-
if ex.head == :call && any(x -> x isa Expr, ex.args)
30-
args = Expression[parse(Operation,o) for o in operands]
31-
parse(Operation, f, args)
32-
else
33-
parse(Operation, f, Expression[parse(Operation,o) for o in operands])
34-
end
35-
end
36-
Base.parse(::Type{Operation},x::Expression) = x
37-
Base.parse(::Type{Operation},sym::Symbol,args) = Operation(eval(sym), args)
38-
Base.parse(::Type{Operation},x::Union{Symbol, Number}) = x
3922

4023
"""
4124
find_replace(O::Operation,x::Variable,y::Expression)

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ function generate_ode_jacobian(sys::DiffEqSystem,simplify=true)
9595
diff_exprs = sys.eqs[diff_idxs]
9696
jac = calculate_jacobian(sys,simplify)
9797
sys.jac = jac
98-
jac_exprs = [:(J[$i,$j] = $(Expr(jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
98+
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
9999
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
100100
block = expr_arr_to_block(exprs)
101101
:((J,u,p,t)->$(block))
@@ -125,11 +125,11 @@ function generate_ode_iW(sys::DiffEqSystem,simplify=true)
125125
iW_t = simplify_constants.(iW_t)
126126
end
127127

128-
iW_exprs = [:(iW[$i,$j] = $(Expr(iW[i,j]))) for i in 1:size(iW,1), j in 1:size(iW,2)]
128+
iW_exprs = [:(iW[$i,$j] = $(convert(Expr, iW[i,j]))) for i in 1:size(iW,1), j in 1:size(iW,2)]
129129
exprs = vcat(var_exprs,param_exprs,vec(iW_exprs))
130130
block = expr_arr_to_block(exprs)
131131

132-
iW_t_exprs = [:(iW[$i,$j] = $(Expr(iW_t[i,j]))) for i in 1:size(iW_t,1), j in 1:size(iW_t,2)]
132+
iW_t_exprs = [:(iW[$i,$j] = $(convert(Expr, iW_t[i,j]))) for i in 1:size(iW_t,1), j in 1:size(iW_t,2)]
133133
exprs = vcat(var_exprs,param_exprs,vec(iW_t_exprs))
134134
block2 = expr_arr_to_block(exprs)
135135
:((iW,u,p,gam,t)->$(block)),:((iW,u,p,gam,t)->$(block2))

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ function generate_nlsys_jacobian(sys::NonlinearSystem,simplify=true)
5959
var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)]
6060
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
6161
jac = calculate_jacobian(sys,simplify)
62-
jac_exprs = [:(J[$i,$j] = $(Expr(jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
62+
jac_exprs = [:(J[$i,$j] = $(convert(Expr, jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
6363
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
6464
block = expr_arr_to_block(exprs)
6565
:((J,u,p,t)->$(block))

src/utils.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
11
using MacroTools
2-
function expr_arr_to_block(exprs)
3-
block = :(begin end)
4-
foreach(expr -> push!(block.args, expr), exprs)
5-
block
2+
3+
4+
function Base.convert(::Type{Expression}, ex::Expr)
5+
ex.head === :call || throw(ArgumentError("internal representation does not support non-call Expr"))
6+
7+
op = eval(ex.args[1]) # HACK
8+
args = convert.(Expression, ex.args[2:end])
9+
10+
return Operation(op, args)
11+
end
12+
Base.convert(::Type{Expression}, x::Expression) = x
13+
Base.convert(::Type{Expression}, x::Number) = Constant(x)
14+
15+
function build_expr(head::Symbol, args)
16+
ex = Expr(head)
17+
append!(ex.args, args)
18+
ex
619
end
20+
expr_arr_to_block(exprs) = build_expr(:block, exprs)
721

822
# used in parsing
923
isblock(x) = length(x) == 1 && x[1] isa Expr && x[1].head == :block
@@ -17,7 +31,7 @@ function flatten_expr!(x)
1731
x
1832
end
1933

20-
toexpr(ex) = MacroTools.postwalk(x->x isa Union{Expression,Operation} ? Expr(x) : x, ex)
34+
toexpr(ex) = MacroTools.postwalk(x -> isa(x, Expression) ? convert(Expr, x) : x, ex)
2135

2236
is_constant(x::Variable) = x.subtype === :Constant
2337
is_constant(::Any) = false

src/variables.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ function Base.:(==)(x::Number,y::Variable)
8686
Constant(x) == y
8787
end
8888

89-
function Base.Expr(x::Variable)
89+
function Base.convert(::Type{Expr}, x::Variable)
9090
if x.subtype == :Constant
9191
return x.value
9292
elseif x.diff == nothing

test/internal.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@ using Test
66
@Param b
77
@DVar x(t)
88
@Var y
9-
@test parse(Operation, 2) == 2
9+
@test convert(Expression, 2) == 2
1010
expr = :(-inv(2sqrt(+($a, $b))))
1111
op = Operation(-, [Operation(inv,
1212
[Operation(*, [2, Operation(sqrt,
1313
[Operation(+, [a, b])])])])])
14-
@test parse(Operation, expr) == op
14+
@test convert(Expression, expr) == op
1515
expr1 = :($x^($y-1))
1616
op1 = Operation(^, [x, Operation(-, [y, 1])])
17-
@test parse(Operation, expr1) == op1
17+
@test convert(Expression, expr1) == op1
18+
19+
@test_throws ArgumentError convert(Expression, :([a, b]))
20+
@test_throws ArgumentError convert(Expression, :(a ? b : c))

test/variable_parsing.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Test
44
@Var a=1.0 b
55
a1 = Variable(:a,1.0)
66
@test a1 == a
7-
@test Expr(a) == :a
7+
@test convert(Expr, a) == :a
88

99
@Var begin
1010
a = 1.0
@@ -21,9 +21,9 @@ z1 = DependentVariable(:z,dependents = [t])
2121
@test x1 == x
2222
@test y1 == y
2323
@test z1 == z
24-
@test Expr(x) == :x
25-
@test Expr(y) == :y
26-
@test Expr(z) == :z
24+
@test convert(Expr, x) == :x
25+
@test convert(Expr, y) == :y
26+
@test convert(Expr, z) == :z
2727

2828
@IVar begin
2929
t
@@ -33,19 +33,19 @@ t1 = IndependentVariable(:t)
3333
s1 = IndependentVariable(:s, cos(2.5))
3434
@test t1 == t
3535
@test s1 == s
36-
@test Expr(t) == :t
37-
@test Expr(s) == :s
38-
@test Expr(cos(t + sin(s))) == :(cos(t + sin(s)))
36+
@test convert(Expr, t) == :t
37+
@test convert(Expr, s) == :s
38+
@test convert(Expr, cos(t + sin(s))) == :(cos(t + sin(s)))
3939

4040
@Deriv D''~t
4141
D1 = Differential(t, 2)
4242
@test D1 == D
43-
@test Expr(D) == D
43+
@test convert(Expr, D) == D
4444

4545
@Const c=0 v=2
4646
c1 = Constant(0)
4747
v1 = Constant(2)
4848
@test c1 == c
4949
@test v1 == v
50-
@test Expr(c) == 0
51-
@test Expr(v) == 2
50+
@test convert(Expr, c) == 0
51+
@test convert(Expr, v) == 2

0 commit comments

Comments
 (0)