Skip to content

Commit f033b0c

Browse files
committed
Better register and derivative tests
1 parent ca40b8e commit f033b0c

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

src/ModelingToolkit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
2323
using RecursiveArrayTools
2424

2525
import SymbolicUtils
26-
import SymbolicUtils: Term, Add, Mul, Pow, Sym, to_symbolic, FnType, @rule, Rewriters, substitute, similarterm
26+
import SymbolicUtils: Term, Add, Mul, Pow, Sym, to_symbolic, FnType,
27+
@rule, Rewriters, substitute, similarterm,
28+
promote_symtype
2729

2830
import SymbolicUtils.Rewriters: Chain, Postwalk, Prewalk, Fixpoint
2931

src/register_function.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,18 @@ registered function.
88
```julia
99
@register foo(x, y)
1010
@register goo(x, y::Int) # `y` is not overloaded to take symbolic objects
11+
@register hoo(x, y)::Int # `hoo` returns `Int`
1112
```
1213
"""
1314
macro register(expr, Ts = [Num, Symbolic, Real])
14-
@assert expr.head == :call
15+
if expr.head === :(::)
16+
ret_type = expr.args[2]
17+
expr = expr.args[1]
18+
else
19+
ret_type = Real
20+
end
21+
22+
@assert expr.head === :call
1523

1624
f = expr.args[1]
1725
args = expr.args[2:end]
@@ -28,14 +36,17 @@ macro register(expr, Ts = [Num, Symbolic, Real])
2836
name(x::Symbol) = :($value($x))
2937
name(x::Expr) = ((@assert x.head == :(::)); :($value($(x.args[1]))))
3038

31-
Expr(:block,
32-
[quote
33-
function $f($(setinds(args, symbolic_args, ts)...))
34-
wrap = any(x->typeof(x) <: Num, tuple($(setinds(args, symbolic_args, ts)...),)) ? Num : identity
35-
wrap(Term{Real}($f, [$(map(name, args)...)]))
36-
end
37-
end
38-
for ts in types]...) |> esc
39+
ex = Expr(:block)
40+
for ts in types
41+
push!(ex.args, quote
42+
function $f($(setinds(args, symbolic_args, ts)...))
43+
wrap = any(x->typeof(x) <: Num, tuple($(setinds(args, symbolic_args, ts)...),)) ? Num : identity
44+
wrap(Term{$ret_type}($f, [$(map(name, args)...)]))
45+
end
46+
end)
47+
end
48+
push!(ex.args, :((::$typeof($promote_symtype))(::$typeof($f), args...) = $ret_type))
49+
esc(ex)
3950
end
4051

4152
# Ensure that Num that get @registered from outside the ModelingToolkit

test/direct.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,21 @@ canonequal(a, b) = isequal(simplify(a), simplify(b))
1313
)
1414

1515
@test canonequal(
16-
ModelingToolkit.derivative(sin(cos(x)), x),
17-
-sin(x) * cos(cos(x))
18-
)
16+
ModelingToolkit.derivative(sin(cos(x)), x),
17+
-sin(x) * cos(cos(x))
18+
)
19+
20+
@register no_der(x)
21+
@test canonequal(
22+
ModelingToolkit.derivative([sin(cos(x)), hypot(x, no_der(x))], x),
23+
[
24+
-sin(x) * cos(cos(x)),
25+
x/hypot(x, no_der(x)) + no_der(x)*Differential(x)(no_der(x))/hypot(x, no_der(x))
26+
]
27+
)
28+
29+
@register intfun(x)::Int
30+
@test ModelingToolkit.symtype(intfun(x)) === Int
1931

2032
eqs =*(y-x),
2133
x*-z)-y,

0 commit comments

Comments
 (0)