Skip to content

Commit b1d0e49

Browse files
committed
Implement @incidence_str macro
1 parent eb4021c commit b1d0e49

File tree

3 files changed

+147
-17
lines changed

3 files changed

+147
-17
lines changed

src/analysis/lattice.jl

Lines changed: 123 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,12 @@ struct Incidence
139139
if is_non_incidence_type(type)
140140
throw(DomainError(type, "Invalid type for Incidence"))
141141
end
142-
row = convert(IncidenceVector, row)
142+
if !isa(row, IncidenceVector)
143+
vec, row = row, _zero_row()
144+
for (i, val) in enumerate(vec)
145+
row[i] = val
146+
end
147+
end
143148
time = row[1]
144149
if in(time, (linear_time_dependent, linear_time_and_state_dependent))
145150
throw(ArgumentError("Time incidence cannot be both linear and time-dependent, otherwise it would be nonlinear"))
@@ -199,8 +204,13 @@ function Base.show(io::IO, inc::Incidence)
199204
end
200205
end
201206
time = inc.row[1]
202-
time_linear = time !== nonlinear
203207
is_grouped(v, i) = isa(v, Linearity) && (v.state_dependent || (v.time_dependent || i == 1) && in(time, (linear_state_dependent, nonlinear)))
208+
function propto(linearity::Linearity)
209+
str = ""
210+
linearity.time_dependent && (str *= '')
211+
linearity.state_dependent && (str *= '')
212+
return str
213+
end
204214
for (i, v) in zip(rowvals(inc.row), nonzeros(inc.row))
205215
is_grouped(v, i) && continue
206216
if isa(v, Float64)
@@ -211,13 +221,11 @@ function Base.show(io::IO, inc::Incidence)
211221
else
212222
!first && print(io, " + ")
213223
first = false
214-
if !is_grouped(inc.row[1], 1)
215-
= i > 1 ? subscript(i - 1) : ''
216-
if v.time_dependent
217-
print(io, time_linear ? "∝t" : "f$ᵢ(t)", " * ")
218-
else # unknown constant coefficient
219-
print(io, "c$ᵢ * ")
220-
end
224+
if is_grouped(inc.row[1], 1) && v.time_dependent
225+
= subscript(i - 1)
226+
print(io, "$(propto(inc.row[1]::Linearity))t", " * ")
227+
else # unknown constant coefficient
228+
print(io, propto(v))
221229
end
222230
end
223231
print(io, subscript_state(i))
@@ -233,7 +241,7 @@ function Base.show(io::IO, inc::Incidence)
233241
else
234242
print(io, ", ")
235243
end
236-
!v.nonlinear && print(io, '')
244+
!v.nonlinear && print(io, propto(v))
237245
print(io, subscript_state(i))
238246
end
239247
if !first_grouped
@@ -243,6 +251,111 @@ function Base.show(io::IO, inc::Incidence)
243251
print(io, ")")
244252
end
245253

254+
macro incidence_str(str) generate_incidence(str) end
255+
256+
function generate_incidence(str::String)
257+
if startswith(str, "Incidence(") && endswith(str, ')')
258+
# Support `incidence"Incidence(...)"` so the user doesn't have to
259+
# manually remove the `Incidence` call when copy-pasting.
260+
str = str[11:(end - 1)]
261+
end
262+
str = replace(str, '' => '~')
263+
ex = Meta.parse(str)
264+
generate_incidence(ex)
265+
end
266+
267+
function generate_incidence(ex)
268+
T = nothing
269+
if isexpr(ex, :tuple, 2)
270+
T, ex = ex.args[1], ex.args[2]
271+
end
272+
generate_incidence(T, ex)
273+
end
274+
275+
function generate_incidence(T, ex)
276+
if isexpr(ex, :call) && ex.args[1] === :+
277+
terms = ex.args[2:end]
278+
else
279+
terms = Any[ex]
280+
end
281+
pairs = Dict{Int,Any}()
282+
for term in terms
283+
if term === :a
284+
T === nothing || throw(ArgumentError("The incidence type must not be provided if a constant `Float64` term is already present"))
285+
T = Float64
286+
continue
287+
elseif isa(term, Float64)
288+
T === nothing || throw(ArgumentError("The incidence type must not be provided if a literal `Float64` term is already present"))
289+
T = Const(term)
290+
continue
291+
end
292+
293+
@assert isa(term, Symbol) || isexpr(term, :call)
294+
295+
ispropto(x) = isexpr(x, :call, 2) && startswith(string(x.args[1]), '~')
296+
297+
if isa(term, Symbol)
298+
i = parse_variable(string(term))
299+
pairs[i] = 1.0
300+
elseif isexpr(term, :call, 3) && term.args[1] === :*
301+
factor = parse(Float64, string(term.args[2]))
302+
i = parse_variable(string(term.args[3]))
303+
pairs[i] = factor
304+
elseif ispropto(term)
305+
coefficient, variable = separate_coefficient_and_variable(term)
306+
coefficient = parse_coefficient(coefficient)
307+
i = parse_variable(variable)
308+
pairs[i] = coefficient
309+
elseif isexpr(term, :call) && term.args[1] === :f
310+
for argument in @view term.args[2:end]
311+
if ispropto(argument)
312+
coefficient, variable = separate_coefficient_and_variable(argument)
313+
coefficient = parse_coefficient(coefficient)
314+
i = parse_variable(variable)
315+
else
316+
i = parse_variable(string(argument))
317+
coefficient = nonlinear
318+
end
319+
pairs[i] = coefficient
320+
end
321+
else
322+
throw(ArgumentError("Unrecognized call to function '$(term.args[1])'"))
323+
end
324+
end
325+
values = IncidenceValue[]
326+
for i in 1:maximum(keys(pairs); init = 0)
327+
val = get(pairs, i, 0.0)
328+
isa(val, Pair) && (val = val.second)
329+
push!(values, val)
330+
end
331+
T = something(T, Const(0.0))
332+
:(Incidence($T, IncidenceValue[$(values...)]))
333+
end
334+
335+
function separate_coefficient_and_variable(term::Expr)
336+
str = string(term)
337+
i = findfirst(in(('t', 'u')), str)::Int
338+
@view(str[1:prevind(str, i)]), @view(str[i:end])
339+
end
340+
341+
function parse_coefficient(coefficient::AbstractString)
342+
matched = match(r"^~(ₜ)?(ₛ)?$", coefficient)
343+
@assert matched !== nothing
344+
time_dependent = matched.captures[1] !== nothing
345+
state_dependent = matched.captures[2] !== nothing
346+
return Linearity(; time_dependent, state_dependent, nonlinear = false)
347+
end
348+
349+
function parse_variable(term)
350+
term == "t" && return 1
351+
matched = match(r"^u([₀₁₂₃₄₅₆₇₈₉]+)$", term)
352+
@assert matched !== nothing
353+
capture = matched[1]
354+
return 1 + parse(Int, map(subscript_to_number, capture))
355+
end
356+
357+
subscript_to_number(char) = Char(48 + (UInt32(char) - 8320))
358+
246359
_zero_row() = IncidenceVector(MAX_EQS, Int[], IncidenceValue[])
247360
const _ZERO_ROW = _zero_row()
248361
const _ZERO_CONST = Const(0.0)

src/reflection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export code_structure_by_type, code_structure, @code_structure,
2-
code_ad_by_type, code_ad, @code_ad
2+
code_ad_by_type, code_ad, @code_ad, @incidence_str
33

44
function code_structure(@nospecialize(f), @nospecialize(types = Base.default_tt(f)); kwargs...)
55
tt = Base.signature_type(f, types)

test/incidence.jl

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,55 +75,72 @@ dependencies(row) = sort(rowvals(row) .=> nonzeros(row), by = first)
7575
@test incidence.typ === Const(1.0)
7676
@test dependencies(incidence.row) == []
7777
@test repr(incidence) == "Incidence(1.0)"
78+
@test incidence == incidence"1.0"
7879

7980
incidence = Incidence(Float64)
8081
@test repr(incidence) == "Incidence(a)"
82+
@test incidence == incidence"a"
8183

8284
incidence = Incidence(Float64, IncidenceValue[1.0])
8385
@test dependencies(incidence.row) == [1 => 1]
8486
@test repr(incidence) == "Incidence(a + t)"
87+
@test incidence == incidence"a + t"
8588

8689
incidence = Incidence(String, IncidenceValue[1.0])
8790
@test repr(incidence) == "Incidence(String, t)"
91+
@test incidence == incidence"String, t"
8892

8993
incidence = Incidence(1)
9094
@test incidence.typ === Const(0.0)
9195
@test dependencies(incidence.row) == [2 => 1]
9296
@test repr(incidence) == "Incidence(u₁)"
97+
@test incidence == incidence"u₁"
9398

9499
incidence = Incidence(3)
95100
@test dependencies(incidence.row) == [4 => 1]
96101
@test repr(incidence) == "Incidence(u₃)"
102+
@test incidence == incidence"u₃"
97103

98104
incidence = Incidence(Const(3.0), IncidenceValue[0.0, 0.0, 2.0, 1.0])
99105
@test repr(incidence) == "Incidence(3.0 + 2.0u₂ + u₃)"
106+
@test incidence == incidence"3.0 + 2.0u₂ + u₃"
100107

101108
incidence = Incidence(Const(0.0), IncidenceValue[4.0, 0.0, 2.0])
102109
@test repr(incidence) == "Incidence(4.0t + 2.0u₂)"
110+
@test incidence == incidence"4.0t + 2.0u₂"
103111

104112
incidence = Incidence(Const(0.0), IncidenceValue[nonlinear])
105113
@test repr(incidence) == "Incidence(f(t))"
114+
@test incidence == incidence"f(t)"
106115

107116
incidence = Incidence(Const(0.0), IncidenceValue[linear])
108-
@test repr(incidence) == "Incidence(cₜ * t)"
117+
@test repr(incidence) == "Incidence(∝t)"
118+
@test incidence == incidence"∝t"
109119

110120
incidence = Incidence(Const(0.0), IncidenceValue[1.0, nonlinear])
111121
@test repr(incidence) == "Incidence(t + f(u₁))"
122+
@test incidence == incidence"t + f(u₁)"
112123

113124
incidence = Incidence(Const(0.0), IncidenceValue[1.0, linear])
114-
@test repr(incidence) == "Incidence(t + c₁ * u₁)"
125+
@test repr(incidence) == "Incidence(t + ∝u₁)"
126+
@test incidence == incidence"t + ∝u₁"
115127

116128
incidence = Incidence(Const(0.0), IncidenceValue[linear, linear, linear])
117-
@test repr(incidence) == "Incidence(cₜ * t + c₁ * u₁ + c₂ * u₂)"
129+
@test repr(incidence) == "Incidence(∝t + ∝u₁ + ∝u₂)"
130+
@test incidence == incidence"∝t + ∝u₁ + ∝u₂"
118131

119132
incidence = Incidence(Const(0.0), IncidenceValue[linear_state_dependent, linear_time_dependent, linear])
120-
@test repr(incidence) == "Incidence(u₂ + f(∝t, ∝u₁))"
133+
@test repr(incidence) == "Incidence(∝u₂ + f(∝ₛt, ∝ₜu₁))"
134+
@test incidence == incidence"∝u₂ + f(∝ₛt, ∝ₜu₁)"
135+
@test incidence == incidence"Incidence(∝u₂ + f(∝ₛt, ∝ₜu₁))"
121136

122137
incidence = Incidence(Const(0.0), IncidenceValue[linear_state_dependent, linear_time_dependent, nonlinear])
123-
@test repr(incidence) == "Incidence(f(∝t, ∝u₁, u₂))"
138+
@test repr(incidence) == "Incidence(f(∝ₛt, ∝ₜu₁, u₂))"
139+
@test incidence == incidence"f(∝ₛt, ∝ₜu₁, u₂)"
124140

125141
incidence = Incidence(Const(0.0), IncidenceValue[nonlinear, linear_time_dependent, nonlinear])
126-
@test repr(incidence) == "Incidence(f(t, ∝u₁, u₂))"
142+
@test repr(incidence) == "Incidence(f(t, ∝ₜu₁, u₂))"
143+
@test incidence == incidence"f(t, ∝ₜu₁, u₂)"
127144

128145
@test_throws "inconsistent with an absence of time incidence" Incidence(Const(0.0), IncidenceValue[0.0, linear_time_dependent])
129146
@test_throws "inconsistent with an absence of state incidence" Incidence(Const(0.0), IncidenceValue[linear_state_dependent])

0 commit comments

Comments
 (0)