11"""
22Positive definite diagonal matrix.
33"""
4- struct PDiagMat{T<: Real ,V<: AbstractVector } <: AbstractPDMat{T}
4+ struct PDiagMat{T<: Real ,V<: AbstractVector{T} } <: AbstractPDMat{T}
55 dim:: Int
66 diag:: V
7- inv_diag:: V
8-
9- PDiagMat {T,S} (d:: Int ,v:: AbstractVector ,inv_v:: AbstractVector ) where {T,S} =
10- new {T,S} (d,v,inv_v)
11- end
12-
13- function PDiagMat (v:: AbstractVector ,inv_v:: AbstractVector )
14- @check_argdims length (v) == length (inv_v)
15- PDiagMat {eltype(v),typeof(v)} (length (v), v, inv_v)
167end
178
18- PDiagMat (v:: AbstractVector ) = PDiagMat (v, inv . (v))
9+ PDiagMat (v:: AbstractVector{<:Real} ) = PDiagMat {eltype(v),typeof (v)} ( length (v), v )
1910
2011# ## Conversion
2112Base. convert (:: Type{PDiagMat{T}} , a:: PDiagMat ) where {T<: Real } = PDiagMat (convert (AbstractArray{T}, a. diag))
5142* (a:: PDiagMat , c:: T ) where {T<: Real } = PDiagMat (a. diag * c)
5243* (a:: PDiagMat , x:: AbstractVector ) = a. diag .* x
5344* (a:: PDiagMat , x:: AbstractMatrix ) = a. diag .* x
54- \ (a:: PDiagMat , x:: AbstractVecOrMat ) = a . inv_diag .* x
55- / (x:: AbstractVecOrMat , a:: PDiagMat ) = a . inv_diag .* x
45+ \ (a:: PDiagMat , x:: AbstractVecOrMat ) = x ./ a . diag
46+ / (x:: AbstractVecOrMat , a:: PDiagMat ) = x ./ a . diag
5647Base. kron (A:: PDiagMat , B:: PDiagMat ) = PDiagMat ( vcat ([A. diag[i] * B. diag for i in 1 : dim (A)]. .. ) )
5748
5849# ## Algebra
5950
60- Base. inv (a:: PDiagMat ) = PDiagMat (a . inv_diag , a. diag)
51+ Base. inv (a:: PDiagMat ) = PDiagMat (map (inv , a. diag) )
6152function LinearAlgebra. logdet (a:: PDiagMat )
6253 diag = a. diag
6354 return isempty (diag) ? zero (log (zero (eltype (diag)))) : sum (log, diag)
@@ -71,9 +62,9 @@ LinearAlgebra.eigmin(a::PDiagMat) = minimum(a.diag)
7162function whiten! (r:: StridedVector , a:: PDiagMat , x:: StridedVector )
7263 n = dim (a)
7364 @check_argdims length (r) == length (x) == n
74- v = a. inv_diag
65+ v = a. diag
7566 for i = 1 : n
76- r[i] = x[i] * sqrt (v[i])
67+ r[i] = x[i] / sqrt (v[i])
7768 end
7869 return r
7970end
@@ -88,17 +79,21 @@ function unwhiten!(r::StridedVector, a::PDiagMat, x::StridedVector)
8879 return r
8980end
9081
91- whiten! (r:: StridedMatrix , a:: PDiagMat , x:: StridedMatrix ) =
92- broadcast! (* , r, x, sqrt .(a. inv_diag))
82+ function whiten! (r:: StridedMatrix , a:: PDiagMat , x:: StridedMatrix )
83+ r .= x ./ sqrt .(a. diag)
84+ return r
85+ end
9386
94- unwhiten! (r:: StridedMatrix , a:: PDiagMat , x:: StridedMatrix ) =
95- broadcast! (* , r, x, sqrt .(a. diag))
87+ function unwhiten! (r:: StridedMatrix , a:: PDiagMat , x:: StridedMatrix )
88+ r .= x .* sqrt .(a. diag)
89+ return r
90+ end
9691
9792
9893# ## quadratic forms
9994
10095quad (a:: PDiagMat , x:: AbstractVector ) = wsumsq (a. diag, x)
101- invquad (a:: PDiagMat , x:: AbstractVector ) = wsumsq (a. inv_diag , x)
96+ invquad (a:: PDiagMat , x:: AbstractVector ) = invwsumsq (a. diag , x)
10297
10398function quad! (r:: AbstractArray , a:: PDiagMat , x:: StridedMatrix )
10499 m, n = size (x)
@@ -116,12 +111,12 @@ end
116111
117112function invquad! (r:: AbstractArray , a:: PDiagMat , x:: StridedMatrix )
118113 m, n = size (x)
119- ainvd = a. inv_diag
120- @check_argdims m == length (ainvd ) && length (r) == n
114+ ad = a. diag
115+ @check_argdims m == length (ad ) && length (r) == n
121116 @inbounds for j = 1 : n
122- s = zero (promote_type (eltype (ainvd), eltype (x )))
117+ s = zero (zero (eltype (x)) / zero ( eltype (ad )))
123118 for i in 1 : m
124- s += ainvd[i] * abs2 (x[i,j])
119+ s += abs2 (x[i,j]) / ad[i]
125120 end
126121 r[j] = s
127122 end
132127# ## tri products
133128
134129function X_A_Xt (a:: PDiagMat , x:: StridedMatrix )
135- z = x .* reshape ( sqrt .(a. diag) , 1 , dim (a))
130+ z = x .* sqrt .(reshape ( a. diag, 1 , dim (a) ))
136131 z * transpose (z)
137132end
138133
@@ -142,11 +137,11 @@ function Xt_A_X(a::PDiagMat, x::StridedMatrix)
142137end
143138
144139function X_invA_Xt (a:: PDiagMat , x:: StridedMatrix )
145- z = x .* reshape ( sqrt .(a . inv_diag) , 1 , dim (a))
140+ z = x ./ sqrt .(reshape (a . diag , 1 , dim (a) ))
146141 z * transpose (z)
147142end
148143
149144function Xt_invA_X (a:: PDiagMat , x:: StridedMatrix )
150- z = x .* sqrt .(a. inv_diag )
145+ z = x ./ sqrt .(a. diag )
151146 transpose (z) * z
152147end
0 commit comments