Skip to content

Commit 341ba5e

Browse files
Handle conditionals and abs
It's time we finally handle #531, https://github.com/SciML/ModelingToolkit.jl/issues/532, and https://discourse.julialang.org/t/differentiation-method-for-element-wise-abs-function-applied-on-operation-types-in-modelingtoolkit-jl/45867/4 . This PR creates a `ModelingToolkit.ifelse` with derivative fixes for `abs` to make it so standard conditional code can work. While tracing cannot correctly make these operations, this would at least allow directly written symbolic code to handle conditions, so we're at least as good (or bad) as TensorFlow.
1 parent dba32a9 commit 341ba5e

File tree

4 files changed

+18
-9
lines changed

4 files changed

+18
-9
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ include("systems/reduction.jl")
123123

124124
include("latexify_recipes.jl")
125125
include("build_function.jl")
126+
include("extra_functions.jl")
126127

127128
export ODESystem, ODEFunction, ODEFunctionExpr, ODEProblemExpr
128129
export SDESystem, SDEFunction, SDEFunctionExpr, SDESystemExpr

src/differentials.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ derivative(O::Constant, ::Any) = Constant(0)
180180
# Pre-defined derivatives
181181
import DiffRules
182182
for (modu, fun, arity) DiffRules.diffrules()
183-
fun in [:*, :+] && continue # special
183+
fun in [:*, :+, :abs] && continue # special
184184
for i 1:arity
185185

186186
expr = if arity == 1

src/extra_functions.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
function ifelse end
2+
@register Base.conj(x)
3+
@register Base.getindex(x,i)
4+
@register Base.binomial(n,k)
5+
@register Base.copysign(x,y)
6+
7+
@register Base.signbit(x)
8+
ModelingToolkit.derivative(::typeof(signbit), args::NTuple{1,Any}, ::Val{1}) = 0
9+
10+
@register Base.abs(x)
11+
ModelingToolkit.derivative(::typeof(abs), args::NTuple{1,Any}, ::Val{1}) = ModelingToolkit.ifelse(signbit(args[1]),-one(args[1]),one(args[1]))
12+
13+
@register ModelingToolkit.ifelse(x,y,z)
14+
ModelingToolkit.derivative(::typeof(ModelingToolkit.ifelse), args::NTuple{3,Any}, ::Val{1}) = 0
15+
ModelingToolkit.derivative(::typeof(ModelingToolkit.ifelse), args::NTuple{3,Any}, ::Val{2}) = ModelingToolkit.ifelse(args[1],1,0)
16+
ModelingToolkit.derivative(::typeof(ModelingToolkit.ifelse), args::NTuple{3,Any}, ::Val{3}) = ModelingToolkit.ifelse(args[1],0,1)

src/function_registration.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,18 +99,10 @@ for fun ∈ [:<, :>, :(==), :&, :|, :div]
9999
@eval @register $sig
100100
end
101101

102-
# ifelse
103-
#@register Base.ifelse(cond,t,f)
104-
105102
# special cases
106103
Base.:^(x::Expression,y::T) where T <: Integer = Operation(Base.:^, Expression[x, y])
107104
Base.:^(x::Expression,y::T) where T <: Rational = Operation(Base.:^, Expression[x, y])
108105

109-
@register Base.conj(x)
110-
@register Base.getindex(x,i)
111-
@register Base.binomial(n,k)
112-
@register Base.copysign(x,y)
113-
114106
Base.getindex(x::Operation,i::Int64) = Operation(getindex,[x,i])
115107
Base.one(::Operation) = 1
116108

0 commit comments

Comments
 (0)