Skip to content

Commit 756b014

Browse files
Merge pull request #1346 from AayushSabharwal/as/symbolic-type
fix: fix `symbolic_type` for unwrapped array symbolics
2 parents 1305b7e + a5dfae8 commit 756b014

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/variable.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,14 @@ getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val)
489489

490490
SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic()
491491
SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic()
492+
function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: Symbolic{S}}
493+
ArraySymbolic()
494+
end
495+
# need this otherwise the `::Type{<:BasicSymbolic}` method in SymbolicUtils is
496+
# more specific
497+
function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: BasicSymbolic{S}}
498+
ArraySymbolic()
499+
end
492500

493501
SymbolicIndexingInterface.hasname(x::Union{Num,Arr}) = hasname(unwrap(x))
494502

test/symbolic_indexing_interface_trait.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using SymbolicIndexingInterface
1010
@test symbolic_type(x) == ScalarSymbolic()
1111
@variables y[1:3]
1212
@test symbolic_type(y) == ArraySymbolic()
13+
@test symbolic_type(Symbolics.unwrap(y)) == ArraySymbolic()
1314
@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))
1415
@test symbolic_type(Symbolics.unwrap(y .* y)) == ArraySymbolic()
1516
@variables z(..)

0 commit comments

Comments
 (0)