Skip to content

Commit 3d8ce96

Browse files
Merge pull request #1398 from AayushSabharwal/as/fast-substitute-array
fix: fix `fast_substitute` folding array of symbolics
2 parents eb3b5f6 + 9271d62 commit 3d8ce96

File tree

5 files changed

+22
-7
lines changed

5 files changed

+22
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ StaticArraysCore = "1.4"
9292
SymPy = "2.2"
9393
SymbolicIndexingInterface = "0.3.14"
9494
SymbolicLimits = "0.2.2"
95-
SymbolicUtils = "3.7"
95+
SymbolicUtils = "3.10"
9696
TermInterface = "2"
9797
julia = "1.10"
9898

src/solver/preprocess.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ function _filter_poly(expr, var)
135135
subs[i_var] = im
136136
expr = unwrap(expr1 + i_var * expr2)
137137

138-
args = arguments(expr)
138+
args = map(unwrap, arguments(expr))
139139
oper = operation(expr)
140140
return subs, term(oper, args...)
141141
end
@@ -208,7 +208,7 @@ function _filter_poly(expr, var)
208208
end
209209
end
210210

211-
args = arguments(expr)
211+
args = map(unwrap, arguments(expr))
212212
oper = operation(expr)
213213
expr = term(oper, args...)
214214
return subs, expr

src/variable.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ function fast_substitute(expr, subs; operator = Nothing)
606606
args = let canfold = canfold
607607
map(args) do x
608608
x′ = fast_substitute(x, subs; operator)
609-
canfold[] = canfold[] && !(x′ isa Symbolic)
609+
canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′))
610610
x′
611611
end
612612
end
@@ -633,7 +633,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing)
633633
args = let canfold = canfold
634634
map(args) do x
635635
x′ = fast_substitute(x, pair; operator)
636-
canfold[] = canfold[] && !(x′ isa Symbolic)
636+
canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′))
637637
x′
638638
end
639639
end
@@ -645,6 +645,13 @@ function fast_substitute(expr, pair::Pair; operator = Nothing)
645645
metadata(expr))
646646
end
647647

648+
function is_array_of_symbolics(x)
649+
symbolic_type(x) == ArraySymbolic() && return true
650+
symbolic_type(x) == ScalarSymbolic() && return false
651+
x isa AbstractArray &&
652+
any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x)
653+
end
654+
648655
function getparent(x, val=_fail)
649656
maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing)
650657
if maybe_parent !== nothing

test/arrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ end
388388
lapu = wrap(lapu)
389389
lapv = wrap(lapv)
390390

391-
f, g = build_function(dtu, u, v, t, expression=Val{false})
391+
f, g = build_function(dtu, u, v, t, expression=Val{false}, nanmath = false)
392392
du = zeros(Num, 8, 8)
393393
f(du, u,v,t)
394394
@test isequal(collect(du), collect(dtu))

test/utils.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,12 @@ end
153153
test_nested_derivative = Dx(Dt(Dt(u)))
154154
result = diff2term(Symbolics.value(test_nested_derivative))
155155
@test typeof(result) === Symbolics.BasicSymbolic{Real}
156-
end
156+
end
157+
158+
@testset "`fast_substitute` inside array symbolics" begin
159+
@variables x y z
160+
@register_symbolic foo(a::AbstractArray, b)
161+
ex = foo([x, y], z)
162+
ex2 = Symbolics.fixpoint_sub(ex, Dict(y => 1.0, z => 2.0))
163+
@test isequal(ex2, foo([x, 1.0], 2.0))
164+
end

0 commit comments

Comments
 (0)