Skip to content

Commit 8727724

Browse files
Improve StaticArray + BandedMatrices support (update of #131) (#152)
* Don't promote to `Matrix`. * Add StaticArrays tests, fix A / B::PDiagMat * Improve `StaticArray` support and add tests * Bump version * Fix merge errors * Bump version * Add tests with BandedMatrices * Fix test error with old Julia versions Co-authored-by: chriselrod <[email protected]>
1 parent 087d59f commit 8727724

File tree

5 files changed

+86
-13
lines changed

5 files changed

+86
-13
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "PDMats"
22
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
3-
version = "0.11.6"
3+
version = "0.11.7"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -11,7 +11,9 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
1111
julia = "1"
1212

1313
[extras]
14+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
15+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1416
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1517

1618
[targets]
17-
test = ["Test"]
19+
test = ["BandedMatrices", "StaticArrays", "Test"]

src/pdiagmat.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,25 +139,25 @@ end
139139

140140
### tri products
141141

142-
function X_A_Xt(a::PDiagMat, x::StridedMatrix)
142+
function X_A_Xt(a::PDiagMat, x::AbstractMatrix)
143143
@check_argdims dim(a) == size(x, 2)
144144
z = x .* sqrt.(permutedims(a.diag))
145145
z * transpose(z)
146146
end
147147

148-
function Xt_A_X(a::PDiagMat, x::StridedMatrix)
148+
function Xt_A_X(a::PDiagMat, x::AbstractMatrix)
149149
@check_argdims dim(a) == size(x, 1)
150150
z = x .* sqrt.(a.diag)
151151
transpose(z) * z
152152
end
153153

154-
function X_invA_Xt(a::PDiagMat, x::StridedMatrix)
154+
function X_invA_Xt(a::PDiagMat, x::AbstractMatrix)
155155
@check_argdims dim(a) == size(x, 2)
156156
z = x ./ sqrt.(permutedims(a.diag))
157157
z * transpose(z)
158158
end
159159

160-
function Xt_invA_X(a::PDiagMat, x::StridedMatrix)
160+
function Xt_invA_X(a::PDiagMat, x::AbstractMatrix)
161161
@check_argdims dim(a) == size(x, 1)
162162
z = x ./ sqrt.(a.diag)
163163
transpose(z) * z

src/pdmat.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ function PDMat(mat::AbstractMatrix,chol::Cholesky{T,S}) where {T,S}
1616
PDMat{T,S}(d, convert(S, mat), chol)
1717
end
1818

19-
PDMat(mat::Matrix) = PDMat(mat, cholesky(mat))
20-
PDMat(mat::Symmetric) = PDMat(Matrix(mat))
19+
PDMat(mat::AbstractMatrix) = PDMat(mat, cholesky(mat))
2120
PDMat(fac::Cholesky) = PDMat(Matrix(fac), fac)
2221

2322
### Conversion
@@ -94,25 +93,25 @@ invquad!(r::AbstractArray, a::PDMat, x::StridedMatrix) = colwise_dot!(r, x, a.ma
9493

9594
### tri products
9695

97-
function X_A_Xt(a::PDMat, x::StridedMatrix)
96+
function X_A_Xt(a::PDMat, x::AbstractMatrix)
9897
@check_argdims dim(a) == size(x, 2)
9998
z = x * chol_lower(a.chol)
10099
return z * transpose(z)
101100
end
102101

103-
function Xt_A_X(a::PDMat, x::StridedMatrix)
102+
function Xt_A_X(a::PDMat, x::AbstractMatrix)
104103
@check_argdims dim(a) == size(x, 1)
105104
z = chol_upper(a.chol) * x
106105
return transpose(z) * z
107106
end
108107

109-
function X_invA_Xt(a::PDMat, x::StridedMatrix)
108+
function X_invA_Xt(a::PDMat, x::AbstractMatrix)
110109
@check_argdims dim(a) == size(x, 2)
111110
z = x / chol_upper(a.chol)
112111
return z * transpose(z)
113112
end
114113

115-
function Xt_invA_X(a::PDMat, x::StridedMatrix)
114+
function Xt_invA_X(a::PDMat, x::AbstractMatrix)
116115
@check_argdims dim(a) == size(x, 1)
117116
z = chol_lower(a.chol) \ x
118117
return transpose(z) * z

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
include("testutils.jl")
2-
tests = ["pdmtypes", "addition", "generics", "kron", "chol"]
2+
tests = ["pdmtypes", "addition", "generics", "kron", "chol", "specialarrays"]
33
println("Running tests ...")
44

55
for t in tests

test/specialarrays.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
using BandedMatrices
2+
using StaticArrays
3+
4+
@testset "Special matrix types" begin
5+
@testset "StaticArrays" begin
6+
# Full matrix
7+
S = (x -> x * x')(@SMatrix(randn(4, 7)))
8+
PDS = PDMat(S)
9+
@test PDS isa PDMat{Float64, <:SMatrix{4, 4, Float64}}
10+
@test isbits(PDS)
11+
12+
# Diagonal matrix
13+
D = PDiagMat(@SVector(rand(4)))
14+
@test D isa PDiagMat{Float64, <:SVector{4, Float64}}
15+
16+
x = @SVector rand(4)
17+
X = @SMatrix rand(10, 4)
18+
Y = @SMatrix rand(4, 10)
19+
20+
for A in (PDS, D)
21+
@test A * x isa SVector{4, Float64}
22+
@test A * x Matrix(A) * Vector(x)
23+
24+
@test A * Y isa SMatrix{4, 10, Float64}
25+
@test A * Y Matrix(A) * Matrix(Y)
26+
27+
@test X / A isa SMatrix{10, 4, Float64}
28+
@test X / A Matrix(X) / Matrix(A)
29+
30+
@test A \ x isa SVector{4, Float64}
31+
@test A \ x Matrix(A) \ Vector(x)
32+
33+
@test A \ Y isa SMatrix{4, 10, Float64}
34+
@test A \ Y Matrix(A) \ Matrix(Y)
35+
36+
@test X_A_Xt(A, X) isa SMatrix{10, 10, Float64}
37+
@test X_A_Xt(A, X) Matrix(X) * Matrix(A) * Matrix(X)'
38+
39+
@test X_invA_Xt(A, X) isa SMatrix{10, 10, Float64}
40+
@test X_invA_Xt(A, X) Matrix(X) * (Matrix(A) \ Matrix(X)')
41+
42+
@test Xt_A_X(A, Y) isa SMatrix{10, 10, Float64}
43+
@test Xt_A_X(A, Y) Matrix(Y)' * Matrix(A) * Matrix(Y)
44+
45+
@test Xt_invA_X(A, Y) isa SMatrix{10, 10, Float64}
46+
@test Xt_invA_X(A, Y) Matrix(Y)' * (Matrix(A) \ Matrix(Y))
47+
end
48+
end
49+
50+
@testset "BandedMatrices" begin
51+
# Full matrix
52+
A = Symmetric(BandedMatrix(Eye(5), (1, 1)))
53+
P = PDMat(A)
54+
@test P isa PDMat{Float64, <:BandedMatrix{Float64}}
55+
56+
x = rand(5)
57+
X = rand(2, 5)
58+
Y = rand(5, 2)
59+
@test P * x A * x
60+
@test P * Y A * Y
61+
# Right division with Cholesky requires https://github.com/JuliaLang/julia/pull/32594
62+
if VERSION >= v"1.3.0-DEV.562"
63+
@test X / P X / A
64+
end
65+
@test P \ x A \ x
66+
@test P \ Y A \ Y
67+
@test X_A_Xt(P, X) X * A * X'
68+
@test X_invA_Xt(P, X) X * (A \ X')
69+
@test Xt_A_X(P, Y) Y' * A * Y
70+
@test Xt_invA_X(P, Y) Y' * (A \ Y)
71+
end
72+
end

0 commit comments

Comments
 (0)