diff --git a/Project.toml b/Project.toml index f12aa4703..167300dc7 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" @@ -79,6 +80,7 @@ Lux = "1" MacroTools = "0.5" NaNMath = "1" Nemo = "0.46, 0.47, 0.48" +OffsetArrays = "1.15.0" PreallocationTools = "0.4" PrecompileTools = "1" Primes = "0.5" diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 55b98be20..5f28010f0 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -47,6 +47,8 @@ import SymbolicLimits using ADTypes: ADTypes +import OffsetArrays + @reexport using SymbolicUtils RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 821a77322..8f30fb095 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -276,6 +276,13 @@ function _linear_expansion(t, x) op, args = operation(t), arguments(t) expansion_check(op) + if iscall(x) && operation(x) == getindex + arrx, idxsx... = arguments(x) + else + arrx = nothing + idxsx = nothing + end + if op === (+) a₁ = b₁ = 0 islinear = true @@ -318,8 +325,25 @@ function _linear_expansion(t, x) a₁, b₁, islinear = linear_expansion(args[1], x) # (a₁ x + b₁)/b₂ return islinear ? (a₁ / b₂, b₁ / b₂, islinear) : (0, 0, false) + elseif op === getindex + arrt, idxst... = arguments(t) + isequal(arrt, arrx) && return (0, t, true) + shape(arrt) == Unknown() && return (0, t, true) + + indexed_t = OffsetArrays.Origin(map(first, axes(arrt)))(Symbolics.scalarize(arrt))[idxst...] + # when indexing a registered function/callable symbolic + # scalarizing and indexing leads to the same symbolic variable + # which causes a StackOverflowError without this + isequal(t, indexed_t) && return (0, t, true) + return linear_expansion(Symbolics.scalarize(arrt)[idxst...], x) else for (i, arg) in enumerate(args) + isequal(arg, arrx) && return (0, 0, false) + if symbolic_type(arg) == NotSymbolic() + arg isa AbstractArray || continue + _occursin_array(x, arrx, arg) && return (0, 0, false) + continue + end a, b, islinear = linear_expansion(arg, x) (_iszero(a) && islinear) || return (0, 0, false) end @@ -327,6 +351,22 @@ function _linear_expansion(t, x) end end +""" + _occursin_array(sym, arrsym, arr) + +Check if `sym` (or, if `sym` is an element of an array symbolic, the array symbolic +`arrsym`) occursin in the non-symbolic array `arr`. +""" +function _occursin_array(sym, arrsym, arr) + for el in arr + if symbolic_type(el) == NotSymbolic() + return el isa AbstractArray && _occursin_array(sym, arrsym, el) + else + return occursin(sym, el) || occursin(arrsym, el) + end + end +end + ### ### Utilities ### diff --git a/test/linear_solver.jl b/test/linear_solver.jl index 64be2db7b..e3f038af8 100644 --- a/test/linear_solver.jl +++ b/test/linear_solver.jl @@ -59,3 +59,26 @@ a, b, islinear = Symbolics.linear_expansion(D(x) - x, x) @test islinear @test isequal(a, -1) @test isequal(b, D(x)) + +@testset "linear_expansion with array variables" begin + @variables x[1:2] y[1:2] z(..) + @test !Symbolics.linear_expansion(z(x) + x[1], x[1])[3] + @test !Symbolics.linear_expansion(z(x[1]) + x[1], x[1])[3] + a, b, islin = Symbolics.linear_expansion(z(x[2]) + x[1], x[1]) + @test islin && isequal(a, 1) && isequal(b, z(x[2])) + a, b, islin = Symbolics.linear_expansion((x + x)[1], x[1]) + @test islin && isequal(a, 2) && isequal(b, 0) + a, b, islin = Symbolics.linear_expansion(y[1], x[1]) + @test islin && isequal(a, 0) && isequal(b, y[1]) + @test !Symbolics.linear_expansion(z([x...]), x[1])[3] + @test !Symbolics.linear_expansion(z(collect(Symbolics.unwrap(x))), x[1])[3] + @test !Symbolics.linear_expansion(z([x, 2x]), x[1])[3] + + @variables x[0:2] + a, b, islin = Symbolics.linear_expansion(x[0] - z(x[1]), z(x[1])) + @test islin && isequal(a, -1) && isequal(b, x[0]) + + @variables x::Vector{Real} + a, b, islin = Symbolics.linear_expansion(x[0] - z(x[1]), z(x[1])) + @test islin && isequal(a, -1) && isequal(b, x[0]) +end