Skip to content

Commit 58f3856

Browse files
fix: fix promote_symtype for non-Array arrays
1 parent a97c2de commit 58f3856

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

src/methods.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,16 @@ for f in vcat(diadic, [+, -, *, ^, Base.add_sum, Base.mul_prod])
128128
return S
129129
elseif S <: T
130130
return T
131+
elseif T <: AbstractArray && !(T <: Array)
132+
return promote_symtype($f, Array{eltype(T), ndims(T)::Int}, S)
133+
elseif S <: AbstractArray && !(S <: Array)
134+
return promote_symtype($f, T, Array{eltype(S), ndims(S)::Int})
131135
elseif $(f === (*) || f === Base.mul_prod) && T <: AbstractMatrix && S <: AbstractVecOrMat
132-
return Array{promote_symtype(*, T.parameters[1]::TypeT, S.parameters[1]::TypeT), S.parameters[2]}
136+
return Array{promote_symtype(*, T.parameters[1]::TypeT, S.parameters[1]::TypeT), S.parameters[2]::Int}
133137
elseif $(f === (*) || f === Base.mul_prod) && T <: AbstractArray && S <: Number
134-
return Array{promote_symtype(*, T.parameters[1]::TypeT, S), T.parameters[2]}
138+
return Array{promote_symtype(*, T.parameters[1]::TypeT, S), T.parameters[2]::Int}
135139
elseif $(f === (*) || f === Base.mul_prod) && T <: Number && S <: AbstractArray
136-
return Array{promote_symtype(*, T, S.parameters[1]::TypeT), S.parameters[2]}
140+
return Array{promote_symtype(*, T, S.parameters[1]::TypeT), S.parameters[2]::Int}
137141
elseif $(f === (+) || f === Base.add_sum || f === (-)) && T <: AbstractArray && S <: AbstractArray
138142
nd = T.parameters[2]::Int
139143
@assert nd == S.parameters[2]::Int
@@ -191,6 +195,10 @@ for f in [/, \]
191195
return Real
192196
elseif T <: Rational && S <: Integer
193197
return Real
198+
elseif T <: AbstractArray && !(T <: Array)
199+
return promote_symtype($f, Array{eltype(T), ndims(T)::Int}, S)
200+
elseif S <: AbstractArray && !(S <: Array)
201+
return promote_symtype($f, T, Array{eltype(S), ndims(S)::Int})
194202
elseif $(f === (\)) && T <: Number && S <: AbstractArray
195203
return Array{promote_symtype(/, T, S.parameters[1]::TypeT), S.parameters[2]::Int}
196204
elseif $(f === (\)) && T <: AbstractVector && S <: AbstractVector

test/basics.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import MultivariatePolynomials as MP
55
using Setfield
66
using Test, ReferenceTests
77
import LinearAlgebra
8+
using SparseArrays
89

910
include("utils.jl")
1011

@@ -1179,3 +1180,12 @@ end
11791180
end
11801181
end
11811182
end
1183+
1184+
@testset "promote_symtype with sparse operations" begin
1185+
S = sprand(5, 5, 0.1)
1186+
@syms z[1:5, 1:5]::Real
1187+
@test SymbolicUtils.promote_symtype(*, typeof(S), symtype(z)) == Matrix{Real}
1188+
@test SymbolicUtils.promote_symtype(+, typeof(S), symtype(z)) == Matrix{Real}
1189+
@test SymbolicUtils.promote_symtype(\, typeof(S), symtype(z)) == Matrix{Real}
1190+
@test SymbolicUtils.promote_symtype(/, typeof(S), symtype(z)) == Matrix{Real}
1191+
end

0 commit comments

Comments
 (0)