Skip to content

Commit 0af2fe9

Browse files
shashiYingboMa
andcommitted
redo register macro
Co-authored-by: "Yingbo Ma" <[email protected]>
1 parent 3bb919e commit 0af2fe9

File tree

8 files changed

+63
-11
lines changed

8 files changed

+63
-11
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ include("linearity.jl")
175175
include("solve.jl")
176176
include("direct.jl")
177177
include("domains.jl")
178+
include("register_function.jl")
178179

179180
include("systems/abstractsystem.jl")
180181

src/differentials.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,13 @@ for (modu, fun, arity) ∈ DiffRules.diffrules()
191191
else
192192
DiffRules.diffrule(modu, fun, ntuple(k->:(args[$k]), arity)...)[i]
193193
end
194-
@eval derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i}) = (x = $expr; !(x isa Expression || x isa Constant) ? Constant(x) : x)
194+
@eval derivative(::typeof($modu.$fun), args::NTuple{$arity,Any}, ::Val{$i}) = $expr
195195
end
196196
end
197197

198-
derivative(::typeof(+), args::NTuple{N,Any}, ::Val) where {N} = Constant(1)
198+
derivative(::typeof(+), args::NTuple{N,Any}, ::Val) where {N} = 1
199199
derivative(::typeof(*), args::NTuple{N,Any}, ::Val{i}) where {N,i} = make_operation(*, deleteat!(collect(args), i))
200-
derivative(::typeof(one), args::Tuple{<:Any}, ::Val) = Constant(0)
200+
derivative(::typeof(one), args::Tuple{<:Any}, ::Val) = 0
201201

202202
function count_order(x)
203203
@assert !(x isa Symbol) "The variable $x must have an order of differentiation that is greater or equal to 1!"

src/direct.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ function toexpr(O::Term)
215215
return Expr(:call, :^, Expr(:call, :inv, toexpr(O.args[1])), -(O.args[2].value))
216216
end
217217
end
218-
return Expr(:call, Symbol(O.op), toexpr.(O.args)...)
218+
return Expr(:call, O.op, toexpr.(O.args)...)
219219
end
220220
toexpr(s::Sym) = nameof(s)
221221
toexpr(s) = s

src/extra_functions.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
@register Base.conj(x)
1+
@register Base.getindex(x,i::Integer)
22
@register Base.getindex(x,i)
33
@register Base.binomial(n,k)
4-
@register Base.copysign(x,y)
54

65
@register Base.signbit(x)
76
ModelingToolkit.derivative(::typeof(signbit), args::NTuple{1,Any}, ::Val{1}) = 0
87

98
ModelingToolkit.derivative(::typeof(abs), args::NTuple{1,Any}, ::Val{1}) = IfElse.ifelse(signbit(args[1]),-one(args[1]),one(args[1]))
109

11-
@register IfElse.ifelse(x,y,z)
10+
@register IfElse.ifelse(x,y,z::Any)
1211
ModelingToolkit.derivative(::typeof(IfElse.ifelse), args::NTuple{3,Any}, ::Val{1}) = 0
1312
ModelingToolkit.derivative(::typeof(IfElse.ifelse), args::NTuple{3,Any}, ::Val{2}) = IfElse.ifelse(args[1],1,0)
1413
ModelingToolkit.derivative(::typeof(IfElse.ifelse), args::NTuple{3,Any}, ::Val{3}) = IfElse.ifelse(args[1],0,1)

src/register_function.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
macro register(expr, Ts = [Num, Symbolic, Number])
2+
@assert expr.head == :call
3+
4+
f = expr.args[1]
5+
args = expr.args[2:end]
6+
7+
symbolic_args = findall(x->x isa Symbol, args)
8+
9+
types = vec(collect(Iterators.product(ntuple(_->Ts, length(symbolic_args))...)))
10+
11+
# remove Number Number Number methods
12+
filter!(Ts->!all(T->T == Number, Ts), types)
13+
14+
annotype(name,T) = :($name :: $T)
15+
setinds(xs, idx, vs) = (xs=copy(xs); xs[idx] .= map(annotype, xs[idx], vs); xs)
16+
name(x::Symbol) = :($value($x))
17+
name(x::Expr) = ((@assert x.head == :(::)); :($value($(x.args[1]))))
18+
19+
Expr(:block,
20+
[:($f($(setinds(args, symbolic_args, ts)...)) = Term{Number}($f, [$(map(name, args)...)]))
21+
for ts in types]...) |> esc
22+
end
23+
24+
# Ensure that Operations that get @registered from outside the ModelingToolkit
25+
# module can work without having to bring in the associated function into the
26+
# ModelingToolkit namespace. We basically store information about functions
27+
# registered at runtime in a ModelingToolkit variable,
28+
# `registered_external_functions`. It's not pretty, but we are limited by the
29+
# way GeneralizedGenerated builds a function (adding "ModelingToolkit" to every
30+
# function call).
31+
# ---
32+
const registered_external_functions = Dict{Symbol,Module}()
33+
function inject_registered_module_functions(expr)
34+
MacroTools.postwalk(expr) do x
35+
# Find all function calls in the expression and extract the function
36+
# name and calling module.
37+
MacroTools.@capture(x, f_module_.f_name_(xs__))
38+
if isnothing(f_module)
39+
MacroTools.@capture(x, f_name_(xs__))
40+
end
41+
42+
if !isnothing(f_name)
43+
# Set the calling module to the module that registered it.
44+
mod = get(registered_external_functions, f_name, f_module)
45+
if !isnothing(mod) && mod != Base
46+
x.args[1] = :($mod.$f_name)
47+
end
48+
end
49+
50+
return x
51+
end
52+
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function (f::ODEToExpr)(O::Term)
4444
any(isequal(O), f.states) && return O.op.name # dependent variables
4545
return build_expr(:call, Any[O.op.name; f.(O.args)])
4646
end
47-
return build_expr(:call, Any[Symbol(O.op); f.(O.args)])
47+
return build_expr(:call, Any[O.op; f.(O.args)])
4848
end
4949
(f::ODEToExpr)(x) = toexpr(x)
5050

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ function states_to_sym(states)
134134
any(isequal(O), states) && return O.op.name # dependent variables
135135
return build_expr(:call, Any[O.op.name; _states_to_sym.(O.args)])
136136
else
137-
return build_expr(:call, Any[Symbol(O.op); _states_to_sym.(O.args)])
137+
return build_expr(:call, Any[O.op; _states_to_sym.(O.args)])
138138
end
139139
elseif O isa Num
140140
return _states_to_sym(value(O))

test/function_registration.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ foo(x, y) = sin(x) * cos(y)
7979
using ModelingToolkit: value
8080
expr = value(foo(x, y))
8181
@test expr.op === foo
82-
@test expr.args[1] === x
83-
@test expr.args[2] === y
82+
@test expr.args[1] === value(x)
83+
@test expr.args[2] === value(y)
8484
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{1}) = cos(x) * cos(y) # derivative w.r.t. the first argument
8585
ModelingToolkit.derivative(::typeof(foo), (x, y), ::Val{2}) = -sin(x) * sin(y) # derivative w.r.t. the second argument
8686
@test isequal(expand_derivatives(D(foo(x, y))), expand_derivatives(D(sin(x) * cos(y))))

0 commit comments

Comments
 (0)