Skip to content

Commit 3bf685a

Browse files
author
Karl Wessel
committed
remove robust flag
1 parent 7fba695 commit 3bf685a

File tree

2 files changed

+46
-33
lines changed

2 files changed

+46
-33
lines changed

src/diff.jl

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,25 @@ function recursive_hasoperator(op, O)
150150
end
151151
end
152152

153-
function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
154-
if robust || occurrences == nothing
153+
"""
154+
executediff(D, arg, simplify=false; occurrences=nothing)
155+
156+
Apply the passed Differential D on the passed argument.
157+
158+
This function differs to `expand_derivatives` in that in only expands the
159+
passed differential and not any other Differentials it encounters.
160+
161+
# Arguments
162+
- `D::Differential`: The differential to apply
163+
- `arg::Symbolic`: The symbolic expression to apply the differential on.
164+
- `simplify::Bool=false`: Whether to simplify the resulting expression using
165+
[`SymbolicUtils.simplify`](@ref).
166+
- `occurrences=nothing`: Information about the occurrences of the independent
167+
variable in the argument of the derivative. This is used internally for
168+
optimization purposes.
169+
"""
170+
function executediff(D, arg, simplify=false; occurrences=nothing)
171+
if occurrences == nothing
155172
occurrences = occursin_info(D.x, arg)
156173
end
157174

@@ -166,15 +183,15 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
166183
return D(arg) # base case if any argument is directly equal to the i.v.
167184
else
168185
return sum(inner_args, init=0) do a
169-
return executediff(Differential(a), arg; robust) *
170-
executediff(D, a; robust)
186+
return executediff(Differential(a), arg) *
187+
executediff(D, a)
171188
end
172189
end
173190
elseif op === (IfElse.ifelse)
174191
args = arguments(arg)
175192
O = op(args[1],
176-
executediff(D, args[2], simplify; robust, occurrences=arguments(occurrences)[2]),
177-
executediff(D, args[3], simplify; robust, occurrences=arguments(occurrences)[3]))
193+
executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2]),
194+
executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3]))
178195
return O
179196
elseif isa(op, Differential)
180197
# The recursive expand_derivatives was not able to remove
@@ -184,20 +201,21 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
184201
if isequal(op.x, D.x)
185202
return D(arg)
186203
else
187-
inner = executediff(D, arguments(arg)[1], false; robust)
204+
inner = executediff(D, arguments(arg)[1], false)
188205
# if the inner expression is not expandable either, return
189206
if iscall(inner) && operation(inner) isa Differential
190207
return D(arg)
191208
else
192-
return expand_derivatives(op(inner), simplify; robust) # TODO
209+
# otherwise give the nested Differential another try
210+
return executediff(op, inner, simplify)
193211
end
194212
end
195213
elseif isa(op, Integral)
196214
if isa(op.domain.domain, AbstractInterval)
197215
domain = op.domain.domain
198216
a, b = DomainSets.endpoints(domain)
199217
c = 0
200-
inner_function = expand_derivatives(arguments(arg)[1]; robust) # TODO
218+
inner_function = arguments(arg)[1]
201219
if iscall(value(a))
202220
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
203221
t2 = D(a)
@@ -208,7 +226,7 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
208226
t2 = D(b)
209227
c += t1*t2
210228
end
211-
inner = executediff(D, arguments(arg)[1]; robust)
229+
inner = executediff(D, arguments(arg)[1])
212230
c += op(inner)
213231
return value(c)
214232
end
@@ -220,7 +238,7 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
220238
c = 0
221239

222240
for i in 1:l
223-
t2 = executediff(D, inner_args[i],false; robust, occurrences=arguments(occurrences)[i])
241+
t2 = executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i])
224242

225243
x = if _iszero(t2)
226244
t2
@@ -265,9 +283,6 @@ and other derivative rules to expand any derivatives it encounters.
265283
- `O::Symbolic`: The symbolic expression to expand.
266284
- `simplify::Bool=false`: Whether to simplify the resulting expression using
267285
[`SymbolicUtils.simplify`](@ref).
268-
- `occurrences=nothing`: Information about the occurrences of the independent
269-
variable in the argument of the derivative. This is used internally for
270-
optimization purposes.
271286
272287
# Examples
273288
```jldoctest
@@ -283,30 +298,29 @@ julia> dfx=expand_derivatives(Dx(f))
283298
(k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y
284299
```
285300
"""
286-
function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrences=nothing)
301+
function expand_derivatives(O::Symbolic, simplify=false)
287302
if iscall(O) && isa(operation(O), Differential)
288303
arg = only(arguments(O))
289-
arg = expand_derivatives(arg, false; robust)
304+
arg = expand_derivatives(arg, false)
305+
return executediff(operation(O), arg, simplify)
290306
elseif iscall(O) && isa(operation(O), Integral)
291-
return operation(O)(expand_derivatives(arguments(O)[1]; robust))
307+
return operation(O)(expand_derivatives(arguments(O)[1]))
292308
elseif !hasderiv(O)
293309
return O
294310
else
295-
args = map(a->expand_derivatives(a, false; robust), arguments(O))
311+
args = map(a->expand_derivatives(a, false), arguments(O))
296312
O1 = operation(O)(args...)
297313
return simplify ? SymbolicUtils.simplify(O1) : O1
298314
end
299-
300-
executediff(operation(O), arg, simplify; robust, occurrences)
301315
end
302-
function expand_derivatives(n::Num, simplify=false; robust=false, occurrences=nothing)
303-
wrap(expand_derivatives(value(n), simplify; robust, occurrences))
316+
function expand_derivatives(n::Num, simplify=false)
317+
wrap(expand_derivatives(value(n), simplify))
304318
end
305-
function expand_derivatives(n::Complex{Num}, simplify=false; robust=false, occurrences=nothing)
306-
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; robust, occurrences),
307-
expand_derivatives(imag(n), simplify; robust, occurrences)))
319+
function expand_derivatives(n::Complex{Num}, simplify=false)
320+
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify),
321+
expand_derivatives(imag(n), simplify)))
308322
end
309-
expand_derivatives(x, simplify=false; robust=false, occurrences=nothing) = x
323+
expand_derivatives(x, simplify=false) = x
310324

311325
_iszero(x) = false
312326
_isone(x) = false

test/diff.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,8 @@ let
356356
D = Differential(t)
357357
expr = b - ((D(b))^2) * D(D(b))
358358
expr2 = D(expr)
359-
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
360-
@test isequal(expand_derivatives(expr2; robust=true), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2))
361-
@test isequal(expand_derivatives(expr2; robust=true), expand_derivatives(expr2))
359+
@test isequal(expand_derivatives(expr), expr)
360+
@test isequal(expand_derivatives(expr2), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2))
362361
end
363362

364363
# 1126
@@ -370,13 +369,13 @@ let
370369
expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y))))
371370

372371
expr = expr_gen(g(y))
373-
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
372+
# just make sure that no errors are thrown in the following, the results are to complicated to compare
373+
expand_derivatives(expr)
374374
expr = expr_gen(h(y))
375-
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
375+
expand_derivatives(expr)
376376

377-
expected = substitute(expand_derivatives(expr; robust=true), h(y) => f(y))
378377
expr = expr_gen(f(y))
379-
@test isequal(expand(expand_derivatives(expr)), expand(expand_derivatives(expr; robust=true)))
378+
expand_derivatives(expr)
380379
end
381380

382381
# Check `is_derivative` function

0 commit comments

Comments
 (0)