Skip to content

Commit d2ddb17

Browse files
committed
More scope adjustment
1 parent 1714a6b commit d2ddb17

File tree

5 files changed

+32
-35
lines changed

5 files changed

+32
-35
lines changed

ext/DAECompilerModelingToolkitExt.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ function declare_parameters(model, struct_name)
7575
backing::B
7676
end
7777
)
78-
79-
78+
79+
8080
constructor_expr =:(
8181
@generated function _check_parameter_names(::Type{$struct_name}, param_kwargs::NamedTuple)
8282
unexpected_parameters = setdiff(fieldnames(param_kwargs), $param_names_tuple_expr)
@@ -108,7 +108,7 @@ function declare_parameters(model, struct_name)
108108
if name === $param_name
109109
return if hasfield(B, $param_name)
110110
getfield(getfield(this, :backing), $param_name)
111-
else
111+
else
112112
$param_value
113113
end
114114
end
@@ -118,7 +118,7 @@ function declare_parameters(model, struct_name)
118118
return getfield(getfield(this, :backing), name)
119119
))
120120
getproperty_expr.args[end].args[end] = Expr(:block, getproperty_body...)
121-
121+
122122
return Expr(:block, struct_expr, constructor_expr, propertynames_expr, getproperty_expr)
123123
end
124124

@@ -206,7 +206,7 @@ end
206206

207207
macro DAECompiler.declare_MTKConnector(mtk_component, ports...)
208208
# We do need to do run time eval, because we can't decide what to construct with just lexical information.
209-
# we need the values of the
209+
# we need the values of the
210210
:(Base.eval(@__MODULE__, $MTKConnector_AST($(esc(mtk_component)), $(esc.(ports)...))))
211211
end
212212

@@ -219,7 +219,7 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
219219
end
220220

221221
while !isnothing(MTK.get_parent(model))
222-
# Undo any call to structural_simplify
222+
# Undo any call to structural_simplify
223223
# (Should we give a warning here? They did waste CPU cycles simplfying it in first place)
224224
model = MTK.get_parent(model)
225225
end
@@ -239,11 +239,11 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
239239

240240

241241
struct_name = gensym(nameof(model))
242-
242+
243243
return quote
244244
$(declare_parameters(model, struct_name))
245245

246-
function (this::$struct_name)($(port_names...); dscope=$(_c(Scope))())
246+
function (this::$struct_name)($(map(port->:($(port)::Float64), port_names)...); dscope=$(_c(Scope))())
247247
$(declare_vars(model, :dscope))
248248
$(declare_derivatives(state))
249249
$(declare_equations(state, model, :dscope, ports))
@@ -258,4 +258,4 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
258258
end
259259

260260

261-
end # module
261+
end # module

src/analysis/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ has_any_genscope(sc::PartialStruct) = false # TODO
319319

320320
function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
321321
if isa(argt, Const)
322-
@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
322+
#@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
323323
return argt
324324
elseif isa(argt, Type) && argt <: Intrinsics.AbstractScope
325325
return PartialScope(add_scope!(which))

src/analysis/lattice.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,9 @@ function CC._getfield_tfunc(🥬::DAELattice, @nospecialize(s00), @nospecialize(
462462
return Union{}
463463
end
464464
rt = CC._getfield_tfunc(CC.widenlattice(🥬), s00.typ, name, setfield)
465+
if rt == Union{}
466+
return Union{}
467+
end
465468
if isempty(s00)
466469
return Incidence(rt)
467470
end

src/state_mapping.jl

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ using SciMLBase, SymbolicIndexingInterface
33
struct ScopeRef{T, ST}
44
sys::T
55
scope::Scope{ST}
6+
7+
# (Optional) opaque data structure to facilitate faster `getproperty`.
8+
cursor
69
end
710
Base.Broadcast.broadcastable(ref::ScopeRef) = Ref(ref) # broadcast as scalar
811

@@ -135,33 +138,24 @@ function SciMLBase.sym_to_index(sr::ScopeRef, A::SciMLBase.DEIntegrator)
135138
end
136139

137140
function Base.getproperty(sys::IRODESystem, name::Symbol)
138-
haskey(StructuralAnalysisResult(sys).names, name) || throw(Base.UndefRefError())
139-
return ScopeRef(sys, Scope(Scope(), name))
141+
names = StructuralAnalysisResult(sys).names
142+
cursor = get(names, name, nothing)
143+
cursor === nothing && throw(Base.UndefRefError())
144+
return ScopeRef(sys, Scope(Scope(), name), cursor)
140145
end
141146

142147
function Base.propertynames(sr::ScopeRef)
143-
scope = getfield(sr, :scope)
144-
stack = sym_stack(scope)
145-
strct = NameLevel(StructuralAnalysisResult(IRODESystem(sr)).names)
146-
for s in reverse(stack)
147-
strct = strct.children[s]
148-
strct.children === nothing && return keys(Dict{Symbol, Any}())
149-
end
150-
return keys(strct.children)
148+
cursor = getfield(sr, :cursor)
149+
cursor.children === nothing && return keys(Dict{Symbol, Any}())
150+
return keys(cursor.children)
151151
end
152152

153153
function Base.getproperty(sr::ScopeRef{IRODESystem}, name::Symbol)
154-
scope = getfield(sr, :scope)
155-
stack = sym_stack(scope)
156-
strct = NameLevel(StructuralAnalysisResult(IRODESystem(sr)).names)
157-
for s in reverse(stack)
158-
strct = strct.children[s]
159-
strct.children === nothing && throw(Base.UndefRefError())
160-
end
161-
if !haskey(strct.children, name)
162-
throw(Base.UndefRefError())
163-
end
164-
ScopeRef(IRODESystem(sr), Scope(getfield(sr, :scope), name))
154+
cursor = getfield(sr, :cursor)
155+
cursor.children === nothing && return throw(Base.UndefRefError())
156+
new_cursor = get(cursor.children, name, nothing)
157+
new_cursor === nothing && return throw(Base.UndefRefError())
158+
return ScopeRef(IRODESystem(sr), Scope(getfield(sr, :scope), name), new_cursor)
165159
end
166160

167161
function Base.show(io::IO, scope::Scope)

test/MSL/modeling_toolkit_helper.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,10 @@ function Base.getproperty(sys::IRODESystem, name::Symbol)
261261
namespaces = split_namespaces_var(name)
262262
if haskey(names, namespaces[1])
263263
# Normal DAECompiler way
264-
return return get_scope_ref(sys, namespaces)
264+
return return get_scope_ref(sys, namespaces, names[namespaces[1]])
265265
elseif length(namespaces) > 1 && haskey(names, namespaces[2])
266266
# Ignore first namespace it's cos we are not fully consistent with if we include the system name or not
267-
return return get_scope_ref(sys, namespaces; start_idx=2)
267+
return return get_scope_ref(sys, namespaces, names[namespaces[2]]; start_idx=2)
268268
else # It could be from the mtksys
269269
mtksys = sys_map[sys_map_key(sys)]
270270
if hasproperty(mtksys, name) # if it is actually from the MTK system (which allows unflattened names)
@@ -273,8 +273,8 @@ function Base.getproperty(sys::IRODESystem, name::Symbol)
273273
end
274274
throw(Base.KeyError(name)) # should be a UndefRef but key error useful for findout what broke it.
275275
end
276-
function get_scope_ref(sys, names; start_idx=1)
277-
ref = DAECompiler.ScopeRef(sys, DAECompiler.Scope(DAECompiler.Scope(), names[start_idx]))
276+
function get_scope_ref(sys, names, cursor; start_idx=1)
277+
ref = DAECompiler.ScopeRef(sys, DAECompiler.Scope(DAECompiler.Scope(), names[start_idx]), cursor)
278278
for name in @view names[(start_idx+1):end]
279279
ref = getproperty(ref, name)
280280
end

0 commit comments

Comments
 (0)