Skip to content

Commit f9a46cf

Browse files
Merge pull request #1411 from AayushSabharwal/as/arr-linear_expansion
feat: support array variables in `linear_expansion`
2 parents 372783a + d294b0e commit f9a46cf

File tree

4 files changed

+67
-0
lines changed

4 files changed

+67
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
2525
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2626
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2727
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
28+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2829
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2930
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
3031
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
@@ -79,6 +80,7 @@ Lux = "1"
7980
MacroTools = "0.5"
8081
NaNMath = "1"
8182
Nemo = "0.46, 0.47, 0.48"
83+
OffsetArrays = "1.15.0"
8284
PreallocationTools = "0.4"
8385
PrecompileTools = "1"
8486
Primes = "0.5"

src/Symbolics.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ import SymbolicLimits
4747

4848
using ADTypes: ADTypes
4949

50+
import OffsetArrays
51+
5052
@reexport using SymbolicUtils
5153
RuntimeGeneratedFunctions.init(@__MODULE__)
5254

src/linear_algebra.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,13 @@ function _linear_expansion(t, x)
276276
op, args = operation(t), arguments(t)
277277
expansion_check(op)
278278

279+
if iscall(x) && operation(x) == getindex
280+
arrx, idxsx... = arguments(x)
281+
else
282+
arrx = nothing
283+
idxsx = nothing
284+
end
285+
279286
if op === (+)
280287
a₁ = b₁ = 0
281288
islinear = true
@@ -318,15 +325,48 @@ function _linear_expansion(t, x)
318325
a₁, b₁, islinear = linear_expansion(args[1], x)
319326
# (a₁ x + b₁)/b₂
320327
return islinear ? (a₁ / b₂, b₁ / b₂, islinear) : (0, 0, false)
328+
elseif op === getindex
329+
arrt, idxst... = arguments(t)
330+
isequal(arrt, arrx) && return (0, t, true)
331+
shape(arrt) == Unknown() && return (0, t, true)
332+
333+
indexed_t = OffsetArrays.Origin(map(first, axes(arrt)))(Symbolics.scalarize(arrt))[idxst...]
334+
# when indexing a registered function/callable symbolic
335+
# scalarizing and indexing leads to the same symbolic variable
336+
# which causes a StackOverflowError without this
337+
isequal(t, indexed_t) && return (0, t, true)
338+
return linear_expansion(Symbolics.scalarize(arrt)[idxst...], x)
321339
else
322340
for (i, arg) in enumerate(args)
341+
isequal(arg, arrx) && return (0, 0, false)
342+
if symbolic_type(arg) == NotSymbolic()
343+
arg isa AbstractArray || continue
344+
_occursin_array(x, arrx, arg) && return (0, 0, false)
345+
continue
346+
end
323347
a, b, islinear = linear_expansion(arg, x)
324348
(_iszero(a) && islinear) || return (0, 0, false)
325349
end
326350
return (0, t, true)
327351
end
328352
end
329353

354+
"""
355+
_occursin_array(sym, arrsym, arr)
356+
357+
Check if `sym` (or, if `sym` is an element of an array symbolic, the array symbolic
358+
`arrsym`) occursin in the non-symbolic array `arr`.
359+
"""
360+
function _occursin_array(sym, arrsym, arr)
361+
for el in arr
362+
if symbolic_type(el) == NotSymbolic()
363+
return el isa AbstractArray && _occursin_array(sym, arrsym, el)
364+
else
365+
return occursin(sym, el) || occursin(arrsym, el)
366+
end
367+
end
368+
end
369+
330370
###
331371
### Utilities
332372
###

test/linear_solver.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,26 @@ a, b, islinear = Symbolics.linear_expansion(D(x) - x, x)
5959
@test islinear
6060
@test isequal(a, -1)
6161
@test isequal(b, D(x))
62+
63+
@testset "linear_expansion with array variables" begin
64+
@variables x[1:2] y[1:2] z(..)
65+
@test !Symbolics.linear_expansion(z(x) + x[1], x[1])[3]
66+
@test !Symbolics.linear_expansion(z(x[1]) + x[1], x[1])[3]
67+
a, b, islin = Symbolics.linear_expansion(z(x[2]) + x[1], x[1])
68+
@test islin && isequal(a, 1) && isequal(b, z(x[2]))
69+
a, b, islin = Symbolics.linear_expansion((x + x)[1], x[1])
70+
@test islin && isequal(a, 2) && isequal(b, 0)
71+
a, b, islin = Symbolics.linear_expansion(y[1], x[1])
72+
@test islin && isequal(a, 0) && isequal(b, y[1])
73+
@test !Symbolics.linear_expansion(z([x...]), x[1])[3]
74+
@test !Symbolics.linear_expansion(z(collect(Symbolics.unwrap(x))), x[1])[3]
75+
@test !Symbolics.linear_expansion(z([x, 2x]), x[1])[3]
76+
77+
@variables x[0:2]
78+
a, b, islin = Symbolics.linear_expansion(x[0] - z(x[1]), z(x[1]))
79+
@test islin && isequal(a, -1) && isequal(b, x[0])
80+
81+
@variables x::Vector{Real}
82+
a, b, islin = Symbolics.linear_expansion(x[0] - z(x[1]), z(x[1]))
83+
@test islin && isequal(a, -1) && isequal(b, x[0])
84+
end

0 commit comments

Comments
 (0)