Skip to content

Commit 3614610

Browse files
committed
Improve registered module injection so we don't have to overwrite GeneralizedGenerated.solve().
1 parent 025f225 commit 3614610

File tree

2 files changed

+40
-29
lines changed

2 files changed

+40
-29
lines changed

src/build_function.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,26 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
127127
if expression == Val{true}
128128
return oop_ex
129129
else
130-
return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex)
130+
_build_and_inject_function(@__MODULE__, oop_ex)
131131
end
132132
end
133133

134+
function _build_and_inject_function(mod::Module, ex)
135+
# Generate the function, which will process the expression
136+
runtimefn = GeneralizedGenerated.mk_function(@__MODULE__, ex)
137+
138+
# Extract the processed expression of the function body
139+
params = typeof(runtimefn).parameters
140+
fn_expr = GeneralizedGenerated.NGG.from_type(params[3])
141+
142+
# Inject our externally registered module functions
143+
new_expr = ModelingToolkit.inject_registered_module_functions(fn_expr)
144+
145+
# Reconstruct the RuntimeFn's Body
146+
new_body = GeneralizedGenerated.NGG.to_type(new_expr)
147+
return GeneralizedGenerated.RuntimeFn{params[1:2]..., new_body, params[4]}()
148+
end
149+
134150
# Detect heterogeneous element types of "arrays of matrices/sparce matrices"
135151
function is_array_matrix(F)
136152
return isa(F, AbstractVector) && all(x->isa(x, AbstractArray), F)
@@ -288,7 +304,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
288304
if expression == Val{true}
289305
return oop_ex, iip_ex
290306
else
291-
return GeneralizedGenerated.mk_function(@__MODULE__,oop_ex), GeneralizedGenerated.mk_function(@__MODULE__,iip_ex)
307+
return _build_and_inject_function(@__MODULE__, oop_ex), _build_and_inject_function(@__MODULE__, iip_ex)
292308
end
293309
end
294310

src/function_registration.jl

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,32 @@ any generated functions will use your registered method. See
2020
macro register(sig)
2121
splitsig = splitdef(:($sig = nothing))
2222
name = splitsig[:name]
23+
24+
# Extract the module and function name from the signature
25+
if name isa Symbol
26+
mod = __module__ # Calling module
27+
funcname = name
28+
else
29+
mod = name.args[1]
30+
funcname = name.args[2].value
31+
end
32+
2333
args = splitsig[:args]
2434
typargs = typed_args(args)
2535
defs = :()
2636
for typarg in typargs
2737
splitsig[:args] = typarg
28-
splitsig[:body] = :(Operation($name, Expression[$(args...)]))
38+
if mod == (@__MODULE__) # If the calling module is ModelingToolkit itself...
39+
splitsig[:body] = :(Operation($name, Expression[$(args...)]))
40+
else
41+
# Register the function's associated model so we can inject it in later.
42+
splitsig[:body] = quote
43+
get!(ModelingToolkit.registered_external_functions, Symbol($("$funcname")), $mod)
44+
Operation($name, Expression[$(args...)])
45+
end
46+
end
2947
defs = :($defs; $(combinedef(splitsig)))
3048
end
31-
if (@__MODULE__) != ModelingToolkit
32-
# Register external module registrations so we can rewrite any generated
33-
# functions through JuliaVariables.solve().
34-
get!(ModelingToolkit.registered_external_functions, name, @__MODULE__)
35-
end
3649
esc(defs)
3750
end
3851
# Create all valid combinations of Expression,Number for function signature
@@ -95,28 +108,10 @@ function inject_registered_module_functions(expr)
95108
MacroTools.@capture(x, f_(xs__)) # We need to find all function calls in the expression.
96109
# If the function call has been converted to a JuliaVariables.Var and matches
97110
# one of the functions we've registered...
98-
if !isnothing(f) && x.args[1] isa JuliaVariables.Var && x.args[1].name in keys(registered_external_functions)
99-
# Rewrite it from a Var to a regular function call.
100-
x.args[1] = getproperty(registered_external_functions[x.args[1].name], x.args[1].name)
111+
if !isnothing(f) && f isa Expr && f.head == :. && f.args[2] isa QuoteNode
112+
f_name = f.args[2].value
113+
f.args[1] = get(registered_external_functions, f_name, f.args[1])
101114
end
102115
return x # Make sure we rebuild the expression as is.
103116
end
104-
end
105-
106-
# TODO: Overwriting this function works, but is quite ugly. Is there a nicer way to inject the module names?
107-
function GeneralizedGenerated.mk_function(mod::Module, ex)
108-
ex = macroexpand(mod, ex)
109-
ex = GeneralizedGenerated.simplify_ex(ex)
110-
ex = GeneralizedGenerated.solve(ex)
111-
112-
# We need to modify the expression built by the JuliaVariables.solve(ex)
113-
# method, before GeneralizedGenerated.closure_conv(mod, ex) converts it to a
114-
# RuntimeFn (as done in GeneralizedGenerated.mk_function(mod, ex)).
115-
ex = inject_registered_module_functions(ex)
116-
117-
fn = GeneralizedGenerated.closure_conv(mod, ex)
118-
if !(fn isa GeneralizedGenerated.RuntimeFn)
119-
error("Expect an unnamed function expression. ")
120-
end
121-
fn
122117
end

0 commit comments

Comments
 (0)