Skip to content

Commit 0ea30bf

Browse files
author
Karl Wessel
committed
add flag for activating robust calculation of expand_derivatives
1 parent 8c518c2 commit 0ea30bf

File tree

2 files changed

+49
-19
lines changed

2 files changed

+49
-19
lines changed

src/diff.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,12 @@ julia> dfx=expand_derivatives(Dx(f))
180180
(k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y
181181
```
182182
"""
183-
function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
183+
function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrences=nothing)
184184
if iscall(O) && isa(operation(O), Differential)
185185
arg = only(arguments(O))
186-
arg = expand_derivatives(arg, false)
186+
arg = expand_derivatives(arg, false; robust)
187187

188-
if occurrences == nothing
188+
if robust || occurrences == nothing
189189
occurrences = occursin_info(operation(O).x, arg)
190190
end
191191

@@ -202,14 +202,14 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
202202
return D(arg) # base case if any argument is directly equal to the i.v.
203203
else
204204
return sum(inner_args, init=0) do a
205-
return expand_derivatives(Differential(a)(arg)) *
206-
expand_derivatives(D(a))
205+
return expand_derivatives(Differential(a)(arg); robust) *
206+
expand_derivatives(D(a); robust)
207207
end
208208
end
209209
elseif op === (IfElse.ifelse)
210210
args = arguments(arg)
211211
O = op(args[1], D(args[2]), D(args[3]))
212-
return expand_derivatives(O, simplify; occurrences)
212+
return expand_derivatives(O, simplify; robust, occurrences)
213213
elseif isa(op, Differential)
214214
# The recursive expand_derivatives was not able to remove
215215
# a nested Differential. We can attempt to differentiate the
@@ -218,20 +218,20 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
218218
if isequal(op.x, D.x)
219219
return D(arg)
220220
else
221-
inner = expand_derivatives(D(arguments(arg)[1]), false)
221+
inner = expand_derivatives(D(arguments(arg)[1]), false; robust)
222222
# if the inner expression is not expandable either, return
223223
if iscall(inner) && operation(inner) isa Differential
224224
return D(arg)
225225
else
226-
return expand_derivatives(op(inner), simplify)
226+
return expand_derivatives(op(inner), simplify; robust)
227227
end
228228
end
229229
elseif isa(op, Integral)
230230
if isa(op.domain.domain, AbstractInterval)
231231
domain = op.domain.domain
232232
a, b = DomainSets.endpoints(domain)
233233
c = 0
234-
inner_function = expand_derivatives(arguments(arg)[1])
234+
inner_function = expand_derivatives(arguments(arg)[1]; robust)
235235
if iscall(value(a))
236236
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
237237
t2 = D(a)
@@ -242,7 +242,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
242242
t2 = D(b)
243243
c += t1*t2
244244
end
245-
inner = expand_derivatives(D(arguments(arg)[1]))
245+
inner = expand_derivatives(D(arguments(arg)[1]); robust)
246246
c += op(inner)
247247
return value(c)
248248
end
@@ -254,7 +254,7 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
254254
c = 0
255255

256256
for i in 1:l
257-
t2 = expand_derivatives(D(inner_args[i]),false, occurrences=arguments(occurrences)[i])
257+
t2 = expand_derivatives(D(inner_args[i]),false; robust, occurrences=arguments(occurrences)[i])
258258

259259
x = if _iszero(t2)
260260
t2
@@ -286,23 +286,23 @@ function expand_derivatives(O::Symbolic, simplify=false; occurrences=nothing)
286286
return simplify ? SymbolicUtils.simplify(x) : x
287287
end
288288
elseif iscall(O) && isa(operation(O), Integral)
289-
return operation(O)(expand_derivatives(arguments(O)[1]))
289+
return operation(O)(expand_derivatives(arguments(O)[1]; robust))
290290
elseif !hasderiv(O)
291291
return O
292292
else
293-
args = map(a->expand_derivatives(a, false), arguments(O))
293+
args = map(a->expand_derivatives(a, false; robust), arguments(O))
294294
O1 = operation(O)(args...)
295295
return simplify ? SymbolicUtils.simplify(O1) : O1
296296
end
297297
end
298-
function expand_derivatives(n::Num, simplify=false; occurrences=nothing)
299-
wrap(expand_derivatives(value(n), simplify; occurrences=occurrences))
298+
function expand_derivatives(n::Num, simplify=false; robust=false, occurrences=nothing)
299+
wrap(expand_derivatives(value(n), simplify; robust, occurrences))
300300
end
301-
function expand_derivatives(n::Complex{Num}, simplify=false; occurrences=nothing)
302-
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; occurrences=occurrences),
303-
expand_derivatives(imag(n), simplify; occurrences=occurrences)))
301+
function expand_derivatives(n::Complex{Num}, simplify=false; robust=false, occurrences=nothing)
302+
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; robust, occurrences),
303+
expand_derivatives(imag(n), simplify; robust, occurrences)))
304304
end
305-
expand_derivatives(x, simplify=false; occurrences=nothing) = x
305+
expand_derivatives(x, simplify=false; robust=false, occurrences=nothing) = x
306306

307307
_iszero(x) = false
308308
_isone(x) = false

test/diff.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,36 @@ let
349349
@test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im)
350350
end
351351

352+
# 1262
353+
#
354+
let
355+
@variables t b(t)
356+
D = Differential(t)
357+
expr = b - ((D(b))^2) * D(D(b))
358+
expr2 = D(expr)
359+
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
360+
@test_throws BoundsError expand_derivatives(expr2)
361+
@test isequal(expand_derivatives(expr2; robust=true), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2))
362+
end
363+
364+
# 1126
365+
#
366+
let
367+
@syms y f(y) g(y) h(y)
368+
D = Differential(y)
369+
370+
expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y))))
371+
372+
expr = expr_gen(g(y))
373+
@test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
374+
expr = expr_gen(h(y))
375+
@test_broken isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
376+
377+
expected = substitute(expand_derivatives(expr; robust=true), h(y) => f(y))
378+
expr = expr_gen(f(y))
379+
@test_throws BoundsError expand_derivatives(expr)
380+
@test isequal(expand(expand_derivatives(expr; robust=true)), expected)
381+
end
352382

353383
# Check `is_derivative` function
354384
let

0 commit comments

Comments
 (0)