Skip to content

Commit 89ac04e

Browse files
author
Karl Wessel
committed
make sure to expand differentials in subtrees only once
1 parent 0ea30bf commit 89ac04e

File tree

2 files changed

+109
-106
lines changed

2 files changed

+109
-106
lines changed

src/diff.jl

Lines changed: 105 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,109 @@ function recursive_hasoperator(op, O)
150150
end
151151
end
152152

153+
function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
154+
if robust || occurrences == nothing
155+
occurrences = occursin_info(D.x, arg)
156+
end
157+
158+
_isfalse(occurrences) && return 0
159+
occurrences isa Bool && return 1 # means it's a `true`
160+
161+
if !iscall(arg)
162+
return D(arg) # Cannot expand
163+
elseif (op = operation(arg); issym(op))
164+
inner_args = arguments(arg)
165+
if any(isequal(D.x), inner_args)
166+
return D(arg) # base case if any argument is directly equal to the i.v.
167+
else
168+
return sum(inner_args, init=0) do a
169+
return executediff(Differential(a), arg; robust) *
170+
executediff(D, a; robust)
171+
end
172+
end
173+
elseif op === (IfElse.ifelse)
174+
args = arguments(arg)
175+
O = op(args[1],
176+
executediff(D, args[2], simplify; robust, occurrences=arguments(occurrences)[2]),
177+
executediff(D, args[3], simplify; robust, occurrences=arguments(occurrences)[3]))
178+
return O
179+
elseif isa(op, Differential)
180+
# The recursive expand_derivatives was not able to remove
181+
# a nested Differential. We can attempt to differentiate the
182+
# inner expression wrt to the outer iv. And leave the
183+
# unexpandable Differential outside.
184+
if isequal(op.x, D.x)
185+
return D(arg)
186+
else
187+
inner = executediff(D, arguments(arg)[1], false; robust)
188+
# if the inner expression is not expandable either, return
189+
if iscall(inner) && operation(inner) isa Differential
190+
return D(arg)
191+
else
192+
return expand_derivatives(op(inner), simplify; robust) # TODO
193+
end
194+
end
195+
elseif isa(op, Integral)
196+
if isa(op.domain.domain, AbstractInterval)
197+
domain = op.domain.domain
198+
a, b = DomainSets.endpoints(domain)
199+
c = 0
200+
inner_function = expand_derivatives(arguments(arg)[1]; robust) # TODO
201+
if iscall(value(a))
202+
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
203+
t2 = D(a)
204+
c -= t1*t2
205+
end
206+
if iscall(value(b))
207+
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b)))
208+
t2 = D(b)
209+
c += t1*t2
210+
end
211+
inner = executediff(D, arguments(arg)[1]; robust)
212+
c += op(inner)
213+
return value(c)
214+
end
215+
end
216+
217+
inner_args = arguments(arg)
218+
l = length(inner_args)
219+
exprs = []
220+
c = 0
221+
222+
for i in 1:l
223+
t2 = executediff(D, inner_args[i],false; robust, occurrences=arguments(occurrences)[i])
224+
225+
x = if _iszero(t2)
226+
t2
227+
elseif _isone(t2)
228+
d = derivative_idx(arg, i)
229+
d isa NoDeriv ? D(arg) : d
230+
else
231+
t1 = derivative_idx(arg, i)
232+
t1 = t1 isa NoDeriv ? D(arg) : t1
233+
t1 * t2
234+
end
235+
236+
if _iszero(x)
237+
continue
238+
elseif x isa Symbolic
239+
push!(exprs, x)
240+
else
241+
c += x
242+
end
243+
end
244+
245+
if isempty(exprs)
246+
return c
247+
elseif length(exprs) == 1
248+
term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1])
249+
return _iszero(c) ? term : c + term
250+
else
251+
x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...)
252+
return simplify ? SymbolicUtils.simplify(x) : x
253+
end
254+
end
255+
153256
"""
154257
$(SIGNATURES)
155258
@@ -184,107 +287,6 @@ function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrenc
184287
if iscall(O) && isa(operation(O), Differential)
185288
arg = only(arguments(O))
186289
arg = expand_derivatives(arg, false; robust)
187-
188-
if robust || occurrences == nothing
189-
occurrences = occursin_info(operation(O).x, arg)
190-
end
191-
192-
_isfalse(occurrences) && return 0
193-
occurrences isa Bool && return 1 # means it's a `true`
194-
195-
D = operation(O)
196-
197-
if !iscall(arg)
198-
return D(arg) # Cannot expand
199-
elseif (op = operation(arg); issym(op))
200-
inner_args = arguments(arg)
201-
if any(isequal(D.x), inner_args)
202-
return D(arg) # base case if any argument is directly equal to the i.v.
203-
else
204-
return sum(inner_args, init=0) do a
205-
return expand_derivatives(Differential(a)(arg); robust) *
206-
expand_derivatives(D(a); robust)
207-
end
208-
end
209-
elseif op === (IfElse.ifelse)
210-
args = arguments(arg)
211-
O = op(args[1], D(args[2]), D(args[3]))
212-
return expand_derivatives(O, simplify; robust, occurrences)
213-
elseif isa(op, Differential)
214-
# The recursive expand_derivatives was not able to remove
215-
# a nested Differential. We can attempt to differentiate the
216-
# inner expression wrt to the outer iv. And leave the
217-
# unexpandable Differential outside.
218-
if isequal(op.x, D.x)
219-
return D(arg)
220-
else
221-
inner = expand_derivatives(D(arguments(arg)[1]), false; robust)
222-
# if the inner expression is not expandable either, return
223-
if iscall(inner) && operation(inner) isa Differential
224-
return D(arg)
225-
else
226-
return expand_derivatives(op(inner), simplify; robust)
227-
end
228-
end
229-
elseif isa(op, Integral)
230-
if isa(op.domain.domain, AbstractInterval)
231-
domain = op.domain.domain
232-
a, b = DomainSets.endpoints(domain)
233-
c = 0
234-
inner_function = expand_derivatives(arguments(arg)[1]; robust)
235-
if iscall(value(a))
236-
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
237-
t2 = D(a)
238-
c -= t1*t2
239-
end
240-
if iscall(value(b))
241-
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(b)))
242-
t2 = D(b)
243-
c += t1*t2
244-
end
245-
inner = expand_derivatives(D(arguments(arg)[1]); robust)
246-
c += op(inner)
247-
return value(c)
248-
end
249-
end
250-
251-
inner_args = arguments(arg)
252-
l = length(inner_args)
253-
exprs = []
254-
c = 0
255-
256-
for i in 1:l
257-
t2 = expand_derivatives(D(inner_args[i]),false; robust, occurrences=arguments(occurrences)[i])
258-
259-
x = if _iszero(t2)
260-
t2
261-
elseif _isone(t2)
262-
d = derivative_idx(arg, i)
263-
d isa NoDeriv ? D(arg) : d
264-
else
265-
t1 = derivative_idx(arg, i)
266-
t1 = t1 isa NoDeriv ? D(arg) : t1
267-
t1 * t2
268-
end
269-
270-
if _iszero(x)
271-
continue
272-
elseif x isa Symbolic
273-
push!(exprs, x)
274-
else
275-
c += x
276-
end
277-
end
278-
279-
if isempty(exprs)
280-
return c
281-
elseif length(exprs) == 1
282-
term = (simplify ? SymbolicUtils.simplify(exprs[1]) : exprs[1])
283-
return _iszero(c) ? term : c + term
284-
else
285-
x = +((!_iszero(c) ? vcat(c, exprs) : exprs)...)
286-
return simplify ? SymbolicUtils.simplify(x) : x
287-
end
288290
elseif iscall(O) && isa(operation(O), Integral)
289291
return operation(O)(expand_derivatives(arguments(O)[1]; robust))
290292
elseif !hasderiv(O)
@@ -294,6 +296,8 @@ function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrenc
294296
O1 = operation(O)(args...)
295297
return simplify ? SymbolicUtils.simplify(O1) : O1
296298
end
299+
300+
executediff(operation(O), arg, simplify; robust, occurrences)
297301
end
298302
function expand_derivatives(n::Num, simplify=false; robust=false, occurrences=nothing)
299303
wrap(expand_derivatives(value(n), simplify; robust, occurrences))

test/diff.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,8 @@ let
357357
expr = b - ((D(b))^2) * D(D(b))
358358
expr2 = D(expr)
359359
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
360-
@test_throws BoundsError expand_derivatives(expr2)
361360
@test isequal(expand_derivatives(expr2; robust=true), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2))
361+
@test isequal(expand_derivatives(expr2; robust=true), expand_derivatives(expr2))
362362
end
363363

364364
# 1126
@@ -370,14 +370,13 @@ let
370370
expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y))))
371371

372372
expr = expr_gen(g(y))
373-
@test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
373+
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
374374
expr = expr_gen(h(y))
375-
@test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
375+
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
376376

377377
expected = substitute(expand_derivatives(expr; robust=true), h(y) => f(y))
378378
expr = expr_gen(f(y))
379-
@test_throws BoundsError expand_derivatives(expr)
380-
@test isequal(expand(expand_derivatives(expr; robust=true)), expected)
379+
@test isequal(expand(expand_derivatives(expr)), expand(expand_derivatives(expr; robust=true)))
381380
end
382381

383382
# Check `is_derivative` function

0 commit comments

Comments
 (0)