Skip to content

Commit 9b7c580

Browse files
committed
fix diff tests
1 parent f8279d8 commit 9b7c580

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

src/differentials.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x)
3737
_isfalse(occ::Bool) = occ === false
3838
_isfalse(occ::Term) = _isfalse(operation(occ))
3939

40-
function occursin_info(x, expr::Term)
40+
function occursin_info(x, expr)
41+
!istree(expr) && return false
4142
if isequal(x, expr)
4243
true
4344
else
@@ -52,16 +53,15 @@ function occursin_info(x, expr::Sym)
5253
isequal(x, expr)
5354
end
5455

55-
hasderiv(O::Term) = operation(O) isa Differential || any(hasderiv, arguments(O))
56-
hasderiv(O) = false
57-
58-
occursin_info(x, y) = false
56+
function hasderiv(O)
57+
istree(O) ? operation(O) isa Differential || any(hasderiv, arguments(O)) : false
58+
end
5959
"""
6060
$(SIGNATURES)
6161
6262
TODO
6363
"""
64-
function expand_derivatives(O::Symbolic, simplify=true; occurances=nothing)
64+
function expand_derivatives(O::Symbolic, simplify=false; occurances=nothing)
6565
if istree(O) && isa(operation(O), Differential)
6666
@assert length(arguments(O)) == 1
6767
arg = expand_derivatives(arguments(O)[1], false)
@@ -73,45 +73,45 @@ function expand_derivatives(O::Symbolic, simplify=true; occurances=nothing)
7373
_isfalse(occurances) && return 0
7474
occurances isa Bool && return 1 # means it's a `true`
7575

76-
(D, o) = (operation(O), arg)
76+
D = operation(O)
7777

78-
if !istree(o)
79-
return O # Cannot expand
80-
elseif isa(operation(o), Sym)
81-
return O # Cannot expand
82-
elseif isa(operation(o), Differential)
78+
if !istree(arg)
79+
return D(arg) # Cannot expand
80+
elseif isa(operation(arg), Sym)
81+
return D(arg) # Cannot expand
82+
elseif isa(operation(arg), Differential)
8383
# The recursive expand_derivatives was not able to remove
8484
# a nested Differential. We can attempt to differentiate the
8585
# inner expression wrt to the outer iv. And leave the
8686
# unexpandable Differential outside.
87-
if isequal(operation(o).x, D.x)
88-
return O
87+
if isequal(operation(arg).x, D.x)
88+
return D(arg)
8989
else
90-
inner = expand_derivatives(D(arguments(o)[1]), false)
90+
inner = expand_derivatives(D(arguments(arg)[1]), false)
9191
# if the inner expression is not expandable either, return
9292
if istree(inner) && operation(inner) isa Differential
93-
return O
93+
return D(arg)
9494
else
95-
return expand_derivatives(operation(o)(inner), simplify)
95+
return expand_derivatives(operation(arg)(inner), simplify)
9696
end
9797
end
9898
end
9999

100-
l = length(arguments(o))
100+
l = length(arguments(arg))
101101
exprs = []
102102
c = 0
103103

104104
for i in 1:l
105-
t2 = expand_derivatives(D(arguments(o)[i]),false, occurances=arguments(occurances)[i])
105+
t2 = expand_derivatives(D(arguments(arg)[i]),false, occurances=arguments(occurances)[i])
106106

107107
x = if _iszero(t2)
108108
t2
109109
elseif _isone(t2)
110-
d = derivative_idx(o, i)
111-
d isa NoDeriv ? D(o) : d
110+
d = derivative_idx(arg, i)
111+
d isa NoDeriv ? D(arg) : d
112112
else
113-
t1 = derivative_idx(o, i)
114-
t1 = t1 isa NoDeriv ? D(o) : t1
113+
t1 = derivative_idx(arg, i)
114+
t1 = t1 isa NoDeriv ? D(arg) : t1
115115
make_operation(*, [t1, t2])
116116
end
117117

0 commit comments

Comments
 (0)