Skip to content

Commit a6aa223

Browse files
committed
Fix more tests
1 parent 727ae88 commit a6aa223

File tree

3 files changed

+70
-20
lines changed

3 files changed

+70
-20
lines changed

src/fillarrays/linearalgebra.jl

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,75 @@
1+
using FillArrays: Eye, SquareEye
2+
using LinearAlgebra: LinearAlgebra, mul!, pinv
3+
4+
function check_mul_axes(a::AbstractMatrix, b::AbstractMatrix)
5+
return axes(a, 2) == axes(b, 1) || throw(DimensionMismatch("Incompatible matrix sizes."))
6+
end
7+
8+
function _mul(a::Eye, b::Eye)
9+
check_mul_axes(a, b)
10+
T = promote_type(eltype(a), eltype(b))
11+
return Eye{T}((axes(a, 1), axes(b, 2)))
12+
end
13+
function _mul(a::SquareEye, b::SquareEye)
14+
check_mul_axes(a, b)
15+
return Diagonal(diagview(a) .* diagview(b))
16+
end
17+
118
for f in MATRIX_FUNCTIONS
219
@eval begin
3-
function Base.$f(a::SquareEyeKronecker)
20+
function Base.$f(a::EyeKronecker)
21+
LinearAlgebra.checksquare(a.a)
422
return a.a $f(a.b)
523
end
6-
function Base.$f(a::KroneckerSquareEye)
24+
function Base.$f(a::KroneckerEye)
25+
LinearAlgebra.checksquare(a.b)
726
return $f(a.a) a.b
827
end
9-
function Base.$f(a::SquareEyeSquareEye)
28+
function Base.$f(a::EyeEye)
29+
LinearAlgebra.checksquare(a)
1030
return throw(ArgumentError("`$($f)` on `Eye ⊗ Eye` is not supported."))
1131
end
1232
end
1333
end
1434

15-
function LinearAlgebra.pinv(a::SquareEyeKronecker; kwargs...)
35+
function LinearAlgebra.mul!(
36+
c::EyeKronecker, a::EyeKronecker, b::EyeKronecker, α::Number, β::Number
37+
)
38+
iszero(β) ||
39+
iszero(c) ||
40+
throw(
41+
ArgumentError(
42+
"Can't multiple KroneckerArrays with nonzero β and nonzero destination."
43+
),
44+
)
45+
check_mul_axes(a.a, b.a)
46+
mul!(c.b, a.b, b.b, α, β)
47+
return c
48+
end
49+
function LinearAlgebra.mul!(
50+
c::KroneckerEye, a::KroneckerEye, b::KroneckerEye, α::Number, β::Number
51+
)
52+
iszero(β) ||
53+
iszero(c) ||
54+
throw(
55+
ArgumentError(
56+
"Can't multiple KroneckerArrays with nonzero β and nonzero destination."
57+
),
58+
)
59+
check_mul_axes(a.b, b.b)
60+
mul!(c.a, a.a, b.a, α, β)
61+
return c
62+
end
63+
function LinearAlgebra.mul!(c::EyeEye, a::EyeEye, b::EyeEye, α::Number, β::Number)
64+
return throw(ArgumentError("Can't multiple `Eye ⊗ Eye` in-place."))
65+
end
66+
67+
function LinearAlgebra.pinv(a::EyeKronecker; kwargs...)
1668
return a.a pinv(a.b; kwargs...)
1769
end
18-
function LinearAlgebra.pinv(a::KroneckerSquareEye; kwargs...)
70+
function LinearAlgebra.pinv(a::KroneckerEye; kwargs...)
1971
return pinv(a.a; kwargs...) a.b
2072
end
21-
function LinearAlgebra.pinv(a::SquareEyeSquareEye; kwargs...)
73+
function LinearAlgebra.pinv(a::EyeEye; kwargs...)
2274
return a
2375
end

src/linearalgebra.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ function LinearAlgebra.diag(a::KroneckerArray)
2424
return copy(diagview(a))
2525
end
2626

27+
# Allows customizing multiplication for specific types
28+
# such as `Eye * Eye`, which doesn't return `Eye`.
2729
function _mul(a::AbstractArray, b::AbstractArray)
2830
return a * b
2931
end
@@ -32,13 +34,6 @@ function Base.:*(a::KroneckerArray, b::KroneckerArray)
3234
return _mul(a.a, b.a) _mul(a.b, b.b)
3335
end
3436

35-
function _mul!!(c::AbstractArray, a::AbstractArray, b::AbstractArray)
36-
return _mul!!(c, a, b, true, false)
37-
end
38-
function _mul!!(c::AbstractArray, a::AbstractArray, b::AbstractArray, α::Number, β::Number)
39-
return mul!(c, a, b, true, false)
40-
end
41-
4237
function LinearAlgebra.mul!(
4338
c::KroneckerArray, a::KroneckerArray, b::KroneckerArray, α::Number, β::Number
4439
)
@@ -49,13 +44,15 @@ function LinearAlgebra.mul!(
4944
"Can't multiple KroneckerArrays with nonzero β and nonzero destination."
5045
),
5146
)
52-
_mul!!(c.a, a.a, b.a)
53-
_mul!!(c.b, a.b, b.b, α, β)
47+
mul!(c.a, a.a, b.a)
48+
mul!(c.b, a.b, b.b, α, β)
5449
return c
5550
end
51+
5652
function LinearAlgebra.tr(a::KroneckerArray)
5753
return tr(a.a) tr(a.b)
5854
end
55+
5956
function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
6057
return norm(a.a, p) norm(a.b, p)
6158
end

test/test_blocksparsearrays.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,9 @@ end
102102
@test a[Block(1, 2)] == dev(Eye(2, 3) zeros(elt, 2, 3))
103103
@test a[Block(1, 2)] isa valtype(d)
104104

105-
@test_broken b = a * a
106-
## b = a * a
107-
## @test typeof(b) === typeof(a)
108-
## @test Array(b) ≈ Array(a) * Array(a)
105+
b = @constinferred a * a
106+
@test typeof(b) === typeof(a)
107+
@test Array(b) Array(a) * Array(a)
109108

110109
# Type inference is broken for this operation.
111110
# b = @constinferred a + a
@@ -127,9 +126,11 @@ end
127126

128127
@test @constinferred(norm(a)) norm(Array(a))
129128

129+
b = @constinferred exp(a)
130+
@test Array(b) exp(Array(a))
131+
130132
# Broken operations
131133
@test_broken inv(a)
132-
@test_broken exp(a)
133134
@test_broken svd_compact(a)
134135
@test_broken a[Block.(1:2), Block(2)]
135136
end

0 commit comments

Comments
 (0)