Skip to content

Commit acdc96a

Browse files
committed
Map subtypes in extract_elements
Specify mapping of original subtypes to final subtypes. Allows a default mapping for subtypes not included in `targetmap`. If `default == nothing`, subtypes not included in `targetmap` are ignored.
1 parent 06b24a5 commit acdc96a

File tree

3 files changed

+30
-15
lines changed

3 files changed

+30
-15
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@ function DiffEqSystem(eqs; iv_name = :IndependentVariable,
2020
dv_name = :DependentVariable,
2121
v_name = :Variable,
2222
p_name = :Parameter)
23-
ivs, dvs, vs, ps = extract_elements(eqs, (iv_name, dv_name, v_name, p_name))
23+
targetmap = Dict(iv_name => iv_name, dv_name => dv_name, v_name => v_name,
24+
p_name => p_name)
25+
ivs, dvs, vs, ps = extract_elements(eqs, targetmap)
2426
DiffEqSystem(eqs, ivs, dvs, vs, ps, iv_name, dv_name, p_name)
2527
end
2628

2729
function DiffEqSystem(eqs, ivs;
2830
dv_name = :DependentVariable,
2931
v_name = :Variable,
3032
p_name = :Parameter)
31-
dvs, vs, ps = extract_elements(eqs, (dv_name, v_name, p_name))
33+
targetmap = Dict(dv_name => dv_name, v_name => v_name, p_name => p_name)
34+
dvs, vs, ps = extract_elements(eqs, targetmap)
3235
DiffEqSystem(eqs, ivs, dvs, vs, ps, ivs[1].subtype, dv_name, p_name)
3336
end
3437

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ function NonlinearSystem(eqs;
1818
dv_name = :DependentVariable,
1919
p_name = :Parameter)
2020
# Allow the use of :DependentVariable to make it seamless with DE use
21-
dvs, vs, ps = extract_elements(eqs, (dv_name, v_name, p_name))
22-
vs = [dvs;vs]
21+
targetmap = Dict(v_name => v_name, dv_name => v_name, p_name => p_name)
22+
vs, ps = extract_elements(eqs, targetmap)
2323
NonlinearSystem(eqs, vs, ps, [v_name,dv_name], p_name)
2424
end
2525

src/variables.jl

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,26 +100,38 @@ end
100100

101101
extract_idv(eq) = eq.args[1].diff.x
102102

103-
function extract_elements(ops, eltypes)
103+
function extract_elements(ops, targetmap, default = nothing)
104104
elems = Dict{Symbol, Vector{Variable}}()
105105
names = Dict{Symbol, Set{Symbol}}()
106-
for el in eltypes
107-
elems[el] = Vector{Variable}()
108-
names[el] = Set{Symbol}()
106+
if default == nothing
107+
targets = unique(collect(values(targetmap)))
108+
else
109+
targets = [unique(collect(values(targetmap))), default]
110+
end
111+
for target in targets
112+
elems[target] = Vector{Variable}()
113+
names[target] = Set{Symbol}()
109114
end
110115
for op in ops
111-
extract_elements!(op, elems, names)
116+
extract_elements!(op, elems, names, targetmap, default)
112117
end
113-
Tuple(elems[el] for el in eltypes)
118+
Tuple(elems[target] for target in targets)
114119
end
115120
# Walk the tree recursively and push variables into the right set
116-
function extract_elements!(op::AbstractOperation, elems, names)
121+
function extract_elements!(op::AbstractOperation, elems, names, targetmap, default)
117122
for arg in op.args
118123
if arg isa Operation
119-
extract_elements!(arg, elems, names)
120-
elseif arg isa Variable && haskey(elems, arg.subtype) && !in(arg.name, names[arg.subtype])
121-
push!(names[arg.subtype], arg.name)
122-
push!(elems[arg.subtype], arg)
124+
extract_elements!(arg, elems, names, targetmap, default)
125+
elseif arg isa Variable
126+
if default == nothing
127+
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : continue
128+
else
129+
target = haskey(targetmap, arg.subtype) ? targetmap[arg.subtype] : default
130+
end
131+
if !in(arg.name, names[target])
132+
push!(names[target], arg.name)
133+
push!(elems[target], arg)
134+
end
123135
end
124136
end
125137
end

0 commit comments

Comments
 (0)