Skip to content

Commit 2f3a83e

Browse files
authored
Remove inv_value and inv_diag (#146)
* Remove `inv_value` and `inv_diag` * Fix compatibility with older Julia versions * Add test * Update README.md
1 parent f424507 commit 2f3a83e

File tree

6 files changed

+89
-48
lines changed

6 files changed

+89
-48
lines changed

README.md

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,14 @@ PDMat(chol) # with the Cholesky factorization
5050
* `PDiagMat`: diagonal matrix, defined as
5151

5252
```julia
53-
struct PDiagMat{T<:Real,V<:AbstractVector} <: AbstractPDMat{T}
53+
struct PDiagMat{T<:Real,V<:AbstractVector{T}} <: AbstractPDMat{T}
5454
dim::Int # matrix dimension
5555
diag::V # the vector of diagonal elements
56-
inv_diag::V # the element-wise inverse of diag
5756
end
5857

5958
# Constructors
6059

61-
PDiagMat(v,inv_v) # with the vector of diagonal elements and its inverse
6260
PDiagMat(v) # with the vector of diagonal elements
63-
# inv_diag will be computed upon construction
6461
```
6562

6663

@@ -70,13 +67,11 @@ PDiagMat(v) # with the vector of diagonal elements
7067
struct ScalMat{T<:Real} <: AbstractPDMat{T}
7168
dim::Int # matrix dimension
7269
value::T # diagonal value (shared by all diagonal elements)
73-
inv_value::T # inv(value)
7470
end
7571

7672

7773
# Constructors
7874

79-
ScalMat(d, v, inv_v) # with dimension d, diagonal value v and its inverse inv_v
8075
ScalMat(d, v) # with dimension d and diagonal value v
8176
```
8277

src/deprecates.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,6 @@ using Base: @deprecate
1111
@deprecate full(x::AbstractPDMat) Matrix(x)
1212

1313
@deprecate CholType Cholesky
14+
15+
@deprecate ScalMat(d::Int, x::Real, inv_x::Real) ScalMat(d, x)
16+
@deprecate PDiagMat(v::AbstractVector, inv_v::AbstractVector) PDiagMat(v)

src/pdiagmat.jl

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,12 @@
11
"""
22
Positive definite diagonal matrix.
33
"""
4-
struct PDiagMat{T<:Real,V<:AbstractVector} <: AbstractPDMat{T}
4+
struct PDiagMat{T<:Real,V<:AbstractVector{T}} <: AbstractPDMat{T}
55
dim::Int
66
diag::V
7-
inv_diag::V
8-
9-
PDiagMat{T,S}(d::Int,v::AbstractVector,inv_v::AbstractVector) where {T,S} =
10-
new{T,S}(d,v,inv_v)
11-
end
12-
13-
function PDiagMat(v::AbstractVector,inv_v::AbstractVector)
14-
@check_argdims length(v) == length(inv_v)
15-
PDiagMat{eltype(v),typeof(v)}(length(v), v, inv_v)
167
end
178

18-
PDiagMat(v::AbstractVector) = PDiagMat(v, inv.(v))
9+
PDiagMat(v::AbstractVector{<:Real}) = PDiagMat{eltype(v),typeof(v)}(length(v), v)
1910

2011
### Conversion
2112
Base.convert(::Type{PDiagMat{T}}, a::PDiagMat) where {T<:Real} = PDiagMat(convert(AbstractArray{T}, a.diag))
@@ -51,13 +42,13 @@ end
5142
*(a::PDiagMat, c::T) where {T<:Real} = PDiagMat(a.diag * c)
5243
*(a::PDiagMat, x::AbstractVector) = a.diag .* x
5344
*(a::PDiagMat, x::AbstractMatrix) = a.diag .* x
54-
\(a::PDiagMat, x::AbstractVecOrMat) = a.inv_diag .* x
55-
/(x::AbstractVecOrMat, a::PDiagMat) = a.inv_diag .* x
45+
\(a::PDiagMat, x::AbstractVecOrMat) = x ./ a.diag
46+
/(x::AbstractVecOrMat, a::PDiagMat) = x ./ a.diag
5647
Base.kron(A::PDiagMat, B::PDiagMat) = PDiagMat( vcat([A.diag[i] * B.diag for i in 1:dim(A)]...) )
5748

5849
### Algebra
5950

60-
Base.inv(a::PDiagMat) = PDiagMat(a.inv_diag, a.diag)
51+
Base.inv(a::PDiagMat) = PDiagMat(map(inv, a.diag))
6152
function LinearAlgebra.logdet(a::PDiagMat)
6253
diag = a.diag
6354
return isempty(diag) ? zero(log(zero(eltype(diag)))) : sum(log, diag)
@@ -71,9 +62,9 @@ LinearAlgebra.eigmin(a::PDiagMat) = minimum(a.diag)
7162
function whiten!(r::StridedVector, a::PDiagMat, x::StridedVector)
7263
n = dim(a)
7364
@check_argdims length(r) == length(x) == n
74-
v = a.inv_diag
65+
v = a.diag
7566
for i = 1:n
76-
r[i] = x[i] * sqrt(v[i])
67+
r[i] = x[i] / sqrt(v[i])
7768
end
7869
return r
7970
end
@@ -88,17 +79,21 @@ function unwhiten!(r::StridedVector, a::PDiagMat, x::StridedVector)
8879
return r
8980
end
9081

91-
whiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) =
92-
broadcast!(*, r, x, sqrt.(a.inv_diag))
82+
function whiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix)
83+
r .= x ./ sqrt.(a.diag)
84+
return r
85+
end
9386

94-
unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix) =
95-
broadcast!(*, r, x, sqrt.(a.diag))
87+
function unwhiten!(r::StridedMatrix, a::PDiagMat, x::StridedMatrix)
88+
r .= x .* sqrt.(a.diag)
89+
return r
90+
end
9691

9792

9893
### quadratic forms
9994

10095
quad(a::PDiagMat, x::AbstractVector) = wsumsq(a.diag, x)
101-
invquad(a::PDiagMat, x::AbstractVector) = wsumsq(a.inv_diag, x)
96+
invquad(a::PDiagMat, x::AbstractVector) = invwsumsq(a.diag, x)
10297

10398
function quad!(r::AbstractArray, a::PDiagMat, x::StridedMatrix)
10499
m, n = size(x)
@@ -116,12 +111,12 @@ end
116111

117112
function invquad!(r::AbstractArray, a::PDiagMat, x::StridedMatrix)
118113
m, n = size(x)
119-
ainvd = a.inv_diag
120-
@check_argdims m == length(ainvd) && length(r) == n
114+
ad = a.diag
115+
@check_argdims m == length(ad) && length(r) == n
121116
@inbounds for j = 1:n
122-
s = zero(promote_type(eltype(ainvd), eltype(x)))
117+
s = zero(zero(eltype(x)) / zero(eltype(ad)))
123118
for i in 1:m
124-
s += ainvd[i] * abs2(x[i,j])
119+
s += abs2(x[i,j]) / ad[i]
125120
end
126121
r[j] = s
127122
end
@@ -132,7 +127,7 @@ end
132127
### tri products
133128

134129
function X_A_Xt(a::PDiagMat, x::StridedMatrix)
135-
z = x .* reshape(sqrt.(a.diag), 1, dim(a))
130+
z = x .* sqrt.(reshape(a.diag, 1, dim(a)))
136131
z * transpose(z)
137132
end
138133

@@ -142,11 +137,11 @@ function Xt_A_X(a::PDiagMat, x::StridedMatrix)
142137
end
143138

144139
function X_invA_Xt(a::PDiagMat, x::StridedMatrix)
145-
z = x .* reshape(sqrt.(a.inv_diag), 1, dim(a))
140+
z = x ./ sqrt.(reshape(a.diag, 1, dim(a)))
146141
z * transpose(z)
147142
end
148143

149144
function Xt_invA_X(a::PDiagMat, x::StridedMatrix)
150-
z = x .* sqrt.(a.inv_diag)
145+
z = x ./ sqrt.(a.diag)
151146
transpose(z) * z
152147
end

src/scalmat.jl

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@ Scaling matrix.
44
struct ScalMat{T<:Real} <: AbstractPDMat{T}
55
dim::Int
66
value::T
7-
inv_value::T
87
end
98

10-
ScalMat(d::Int,v::Real) = ScalMat{typeof(inv(v))}(d, v, inv(v))
11-
129
### Conversion
13-
Base.convert(::Type{ScalMat{T}}, a::ScalMat) where {T<:Real} = ScalMat(a.dim, T(a.value), T(a.inv_value))
10+
Base.convert(::Type{ScalMat{T}}, a::ScalMat) where {T<:Real} = ScalMat(a.dim, T(a.value))
1411
Base.convert(::Type{AbstractArray{T}}, a::ScalMat) where {T<:Real} = convert(ScalMat{T}, a)
1512

1613
### Basics
@@ -44,13 +41,13 @@ end
4441
/(a::ScalMat{T}, c::T) where {T<:Real} = ScalMat(a.dim, a.value / c)
4542
*(a::ScalMat, x::AbstractVector) = a.value * x
4643
*(a::ScalMat, x::AbstractMatrix) = a.value * x
47-
\(a::ScalMat, x::AbstractVecOrMat) = a.inv_value * x
48-
/(x::AbstractVecOrMat, a::ScalMat) = a.inv_value * x
44+
\(a::ScalMat, x::AbstractVecOrMat) = x / a.value
45+
/(x::AbstractVecOrMat, a::ScalMat) = x / a.value
4946
Base.kron(A::ScalMat, B::ScalMat) = ScalMat( dim(A) * dim(B), A.value * B.value )
5047

5148
### Algebra
5249

53-
Base.inv(a::ScalMat) = ScalMat(a.dim, a.inv_value, a.value)
50+
Base.inv(a::ScalMat) = ScalMat(a.dim, inv(a.value))
5451
LinearAlgebra.logdet(a::ScalMat) = a.dim * log(a.value)
5552
LinearAlgebra.eigmax(a::ScalMat) = a.value
5653
LinearAlgebra.eigmin(a::ScalMat) = a.value
@@ -60,7 +57,7 @@ LinearAlgebra.eigmin(a::ScalMat) = a.value
6057

6158
function whiten!(r::StridedVecOrMat, a::ScalMat, x::StridedVecOrMat)
6259
@check_argdims dim(a) == size(x, 1)
63-
mul!(r, x, sqrt(a.inv_value))
60+
_ldiv!(r, sqrt(a.value), x)
6461
end
6562

6663
function unwhiten!(r::StridedVecOrMat, a::ScalMat, x::StridedVecOrMat)
@@ -72,10 +69,10 @@ end
7269
### quadratic forms
7370

7471
quad(a::ScalMat, x::AbstractVector) = sum(abs2, x) * a.value
75-
invquad(a::ScalMat, x::AbstractVector) = sum(abs2, x) * a.inv_value
72+
invquad(a::ScalMat, x::AbstractVector) = sum(abs2, x) / a.value
7673

7774
quad!(r::AbstractArray, a::ScalMat, x::StridedMatrix) = colwise_sumsq!(r, x, a.value)
78-
invquad!(r::AbstractArray, a::ScalMat, x::StridedMatrix) = colwise_sumsq!(r, x, a.inv_value)
75+
invquad!(r::AbstractArray, a::ScalMat, x::StridedMatrix) = colwise_sumsqinv!(r, x, a.value)
7976

8077

8178
### tri products
@@ -92,10 +89,10 @@ end
9289

9390
function X_invA_Xt(a::ScalMat, x::StridedMatrix)
9491
@check_argdims dim(a) == size(x, 2)
95-
lmul!(a.inv_value, x * transpose(x))
92+
_rdiv!(x * transpose(x), a.value)
9693
end
9794

9895
function Xt_invA_X(a::ScalMat, x::StridedMatrix)
9996
@check_argdims dim(a) == size(x, 1)
100-
lmul!(a.inv_value, transpose(x) * x)
97+
_rdiv!(transpose(x) * x, a.value)
10198
end

src/utils.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,15 @@ function wsumsq(w::AbstractVector, a::AbstractVector)
5959
return s
6060
end
6161

62+
function invwsumsq(w::AbstractVector, a::AbstractVector)
63+
@check_argdims(length(a) == length(w))
64+
s = zero(zero(eltype(a)) / zero(eltype(w)))
65+
for i = 1:length(a)
66+
@inbounds s += abs2(a[i]) / w[i]
67+
end
68+
return s
69+
end
70+
6271
function colwise_dot!(r::AbstractArray, a::AbstractMatrix, b::AbstractMatrix)
6372
n = length(r)
6473
@check_argdims n == size(a, 2) == size(b, 2) && size(a, 1) == size(b, 1)
@@ -84,3 +93,37 @@ function colwise_sumsq!(r::AbstractArray, a::AbstractMatrix, c::Real)
8493
end
8594
return r
8695
end
96+
97+
function colwise_sumsqinv!(r::AbstractArray, a::AbstractMatrix, c::Real)
98+
n = length(r)
99+
@check_argdims n == size(a, 2)
100+
for j = 1:n
101+
v = zero(eltype(a))
102+
@simd for i = 1:size(a, 1)
103+
@inbounds v += abs2(a[i, j])
104+
end
105+
r[j] = v / c
106+
end
107+
return r
108+
end
109+
110+
# `rdiv!(::AbstractArray, ::Number)` was introduced in Julia 1.2
111+
# https://github.com/JuliaLang/julia/pull/31179
112+
@static if VERSION < v"1.2.0-DEV.385"
113+
function _rdiv!(X::AbstractArray, s::Number)
114+
@simd for I in eachindex(X)
115+
@inbounds X[I] /= s
116+
end
117+
X
118+
end
119+
else
120+
_rdiv!(X::AbstractArray, s::Number) = rdiv!(X, s)
121+
end
122+
123+
# `ldiv!(::AbstractArray, ::Number, ::AbstractArray)` was introduced in Julia 1.4
124+
# https://github.com/JuliaLang/julia/pull/33806
125+
@static if VERSION < v"1.4.0-DEV.635"
126+
_ldiv!(Y::AbstractArray, s::Number, X::AbstractArray) = Y .= s .\ X
127+
else
128+
_ldiv!(Y::AbstractArray, s::Number, X::AbstractArray) = ldiv!(Y, s, X)
129+
end

test/pdmtypes.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ using Test
77
m = Matrix{T}(I, 2, 2)
88
@test PDMat(m, cholesky(m)).mat == PDMat(Symmetric(m)).mat == PDMat(m).mat == PDMat(cholesky(m)).mat
99
d = ones(T,2)
10-
@test PDiagMat(d,d).inv_diag == PDiagMat(d).inv_diag
10+
@test @test_deprecated(PDiagMat(d, d)) == PDiagMat(d)
1111
x = one(T)
12-
@test ScalMat(2,x,x).inv_value == ScalMat(2,x).inv_value
12+
@test @test_deprecated(ScalMat(2, x, x)) == ScalMat(2, x)
1313
s = SparseMatrixCSC{T}(I, 2, 2)
1414
@test PDSparseMat(s, cholesky(s)).mat == PDSparseMat(s).mat == PDSparseMat(cholesky(s)).mat
1515
end
@@ -75,4 +75,12 @@ using Test
7575
@testset "convert Matrix type to the same Cholesky type (#117)" begin
7676
@test PDMat([1 0; 0 1]) == [1.0 0.0; 0.0 1.0]
7777
end
78+
79+
# https://github.com/JuliaStats/PDMats.jl/pull/141
80+
@testset "PDiagMat with range" begin
81+
v = 0.1:0.1:0.5
82+
d = PDiagMat(v)
83+
@test d isa PDiagMat{eltype(v),typeof(v)}
84+
@test d.diag === v
85+
end
7886
end

0 commit comments

Comments
 (0)