Skip to content

Commit d67cf8d

Browse files
fix: handle non-standard indices in linear_expansion
1 parent 23b5604 commit d67cf8d

File tree

4 files changed

+9
-1
lines changed

4 files changed

+9
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
2525
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2626
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2727
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
28+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2829
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
2930
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
3031
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
@@ -79,6 +80,7 @@ Lux = "1"
7980
MacroTools = "0.5"
8081
NaNMath = "1"
8182
Nemo = "0.46, 0.47, 0.48"
83+
OffsetArrays = "1.15.0"
8284
PreallocationTools = "0.4"
8385
PrecompileTools = "1"
8486
Primes = "0.5"

src/Symbolics.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ import SymbolicLimits
4747

4848
using ADTypes: ADTypes
4949

50+
import OffsetArrays
51+
5052
@reexport using SymbolicUtils
5153
RuntimeGeneratedFunctions.init(@__MODULE__)
5254

src/linear_algebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ function _linear_expansion(t, x)
329329
arrt, idxst... = arguments(t)
330330
isequal(arrt, arrx) && return (0, t, true)
331331

332-
indexed_t = Symbolics.scalarize(arrt)[idxst...]
332+
indexed_t = OffsetArrays.Origin(map(first, axes(arrt)))(Symbolics.scalarize(arrt))[idxst...]
333333
# when indexing a registered function/callable symbolic
334334
# scalarizing and indexing leads to the same symbolic variable
335335
# which causes a StackOverflowError without this

test/linear_solver.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,8 @@ a, b, islinear = Symbolics.linear_expansion(D(x) - x, x)
7373
@test !Symbolics.linear_expansion(z([x...]), x[1])[3]
7474
@test !Symbolics.linear_expansion(z(collect(Symbolics.unwrap(x))), x[1])[3]
7575
@test !Symbolics.linear_expansion(z([x, 2x]), x[1])[3]
76+
77+
@variables x[0:2]
78+
a, b, islin = Symbolics.linear_expansion(x[0] - z(x[1]), z(x[1]))
79+
@test islin && isequal(a, -1) && isequal(b, x[0])
7680
end

0 commit comments

Comments
 (0)