Skip to content

Commit 6d2aed8

Browse files
fixup! feat: support callable parameters
1 parent 0136759 commit 6d2aed8

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables,
6565
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
6666
initial_state, transition, activeState, entry, hasnode,
6767
ticksInState, timeInState, fixpoint_sub, fast_substitute,
68-
CallWithMetadata
68+
CallWithMetadata, CallWithParent
6969
const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR)
7070
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
7171
jacobian_sparsity, isaffine, islinear, _iszero, _isone,

src/parameters.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ function isparameter(x)
2626
end
2727
end
2828

29+
function iscalledparameter(x)
30+
x = unwrap(x)
31+
return isparameter(getmetadata(x, CallWithParent, nothing))
32+
end
33+
34+
function getcalledparameter(x)
35+
x = unwrap(x)
36+
return getmetadata(x, CallWithParent)
37+
end
38+
2939
"""
3040
toparam(s)
3141

src/utils.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,11 @@ end
371371
vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op)
372372
vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op)
373373
function vars(exprs; op = Differential)
374-
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
374+
if hasmethod(iterate, Tuple{typeof(exprs)})
375+
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
376+
else
377+
vars!(Set(), unwrap(exprs); op)
378+
end
375379
end
376380
vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op)
377381
function vars!(vars, eq::Equation; op = Differential)
@@ -479,7 +483,11 @@ end
479483

480484
function collect_var!(unknowns, parameters, var, iv)
481485
isequal(var, iv) && return nothing
482-
if isparameter(var) || (iscall(var) && isparameter(operation(var)))
486+
if iscalledparameter(var)
487+
callable = getcalledparameter(var)
488+
push!(parameters, callable)
489+
collect_vars!(unknowns, parameters, arguments(var), iv)
490+
elseif isparameter(var) || (iscall(var) && isparameter(operation(var)))
483491
push!(parameters, var)
484492
elseif !isconstant(var)
485493
push!(unknowns, var)

0 commit comments

Comments
 (0)