Skip to content

Commit 241c9c3

Browse files
committed
Fix tests
1 parent d5889b6 commit 241c9c3

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

src/KroneckerArrays.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ function KroneckerArray(a::AbstractArray, b::AbstractArray)
3838
ArgumentError("Kronecker product requires arrays of the same number of dimensions.")
3939
)
4040
end
41-
return KroneckerArray(Base.promote_eltype(a, b)...)
41+
elt = promote_type(eltype(a), eltype(b))
42+
return KroneckerArray(convert(AbstractArray{elt}, a), convert(AbstractArray{elt}, b))
4243
end
4344
const KroneckerMatrix{T,A<:AbstractMatrix{T},B<:AbstractMatrix{T}} = KroneckerArray{T,2,A,B}
4445
const KroneckerVector{T,A<:AbstractVector{T},B<:AbstractVector{T}} = KroneckerArray{T,1,A,B}
@@ -122,6 +123,9 @@ end
122123
function Base.:(==)(a::KroneckerArray, b::KroneckerArray)
123124
return a.a == b.a && a.b == b.b
124125
end
126+
function Base.isapprox(a::KroneckerArray, b::KroneckerArray; kwargs...)
127+
return isapprox(a.a, b.a; kwargs...) && isapprox(a.b, b.b; kwargs...)
128+
end
125129
function Base.iszero(a::KroneckerArray)
126130
return iszero(a.a) || iszero(a.b)
127131
end
@@ -162,10 +166,8 @@ function diagonal(a::KroneckerArray)
162166
return Diagonal(a.a) Diagonal(a.b)
163167
end
164168

165-
# TODO: Overload `similar` instead?
166-
function LinearAlgebra.matprod_dest(a::KroneckerArray, b::KroneckerArray, elt)
167-
return LinearAlgebra.matprod_dest(a.a, b.a, elt)
168-
LinearAlgebra.matprod_dest(a.b, b.b, elt)
169+
function Base.:*(a::KroneckerArray, b::KroneckerArray)
170+
return (a.a * b.a) (a.b * b.b)
169171
end
170172
function LinearAlgebra.mul!(c::KroneckerArray, a::KroneckerArray, b::KroneckerArray)
171173
mul!(c.a, a.a, b.a)

0 commit comments

Comments
 (0)