You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
return qr_compact!(A, Q, R, default_backend(qr_compact!, A; kwargs...))
11
+
return qr_compact!(A, Q, R, default_backend(qr_compact!, A; kwargs...); kwargs...)
12
12
end
13
13
14
14
function default_backend(::typeof(qr_full!), A::AbstractMatrix; kwargs...)
@@ -20,4 +20,113 @@ end
20
20
21
21
function default_qr_backend(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
22
22
return LAPACKBackend()
23
-
end
23
+
end
24
+
25
+
function check_qr_full_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
26
+
m, n = size(A)
27
+
size(Q) == (m, m) ||
28
+
throw(DimensionMismatch("Full unitary matrix `Q` must be square with equal number of rows as A"))
29
+
isempty(R) || size(R) == (m, n) ||
30
+
throw(DimensionMismatch("Upper triangular matrix `R` must have size equal to A"))
31
+
returnnothing
32
+
end
33
+
function check_qr_compact_input(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix)
34
+
m, n = size(A)
35
+
if n <= m
36
+
size(Q) == (m, n) ||
37
+
throw(DimensionMismatch("Isometric `Q` must have size equal to A"))
38
+
isempty(R) || size(R) == (n, n) ||
39
+
throw(DimensionMismatch("Upper triangular matrix `R` must be square with equal number of columns as A"))
40
+
else
41
+
check_qr_full_input(A, Q, R)
42
+
end
43
+
end
44
+
45
+
function qr_full!(A::AbstractMatrix,
46
+
Q::AbstractMatrix,
47
+
R::AbstractMatrix,
48
+
backend::LAPACKBackend;
49
+
positive=false,
50
+
pivoted=false,
51
+
blocksize=((pivoted || A === Q) ?1: YALAPACK.default_qr_blocksize(A)))
52
+
check_qr_full_input(A, Q, R)
53
+
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
54
+
return Q, R
55
+
end
56
+
57
+
function qr_compact!(A::AbstractMatrix,
58
+
Q::AbstractMatrix,
59
+
R::AbstractMatrix,
60
+
backend::LAPACKBackend;
61
+
positive=false,
62
+
pivoted=false,
63
+
blocksize=((pivoted || A === Q) ?1: YALAPACK.default_qr_blocksize(A)))
64
+
check_qr_compact_input(A, Q, R)
65
+
_unsafe_qr!(A, Q, R; positive=positive, pivoted=pivoted, blocksize=blocksize)
66
+
return Q, R
67
+
end
68
+
69
+
function _unsafe_qr!(A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
70
+
positive=false,
71
+
pivoted=false,
72
+
blocksize=((pivoted || A === Q) ?1: YALAPACK.default_qr_blocksize(A)))
73
+
m, n = size(A)
74
+
minmn = min(m, n)
75
+
computeR = length(R) >0
76
+
inplaceQ = Q === A
77
+
78
+
if pivoted && (blocksize >1)
79
+
throw(ArgumentError("LAPACK does not provide a blocked implementation for a pivoted QR decomposition"))
80
+
end
81
+
if inplaceQ && (computeR || positive || blocksize >1|| m < n)
82
+
throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required, and using the unblocked algorithm (`blocksize=1`) with `positive=false`"))
83
+
end
84
+
85
+
if blocksize >1
86
+
nb = min(minmn, blocksize)
87
+
if computeR # first use R as space for T
88
+
A, T = YALAPACK.geqrt!(A, view(R, 1:nb, 1:minmn))
89
+
else
90
+
A, T = YALAPACK.geqrt!(A, similar(A, nb, minmn))
91
+
end
92
+
Q = YALAPACK.gemqrt!('L', 'N', A, T, one!(Q))
93
+
else
94
+
if pivoted
95
+
A, τ, jpvt = YALAPACK.geqp3!(A)
96
+
else
97
+
A, τ = YALAPACK.geqrf!(A)
98
+
end
99
+
if inplaceQ
100
+
Q = YALAPACK.orgqr!(A, τ)
101
+
else
102
+
Q = YALAPACK.ormqr!('L', 'N', A, τ, one!(Q))
103
+
end
104
+
end
105
+
106
+
if positive # already fix Q even if we do not need R
107
+
@inbounds for j in1:minmn
108
+
s = safesign(A[j, j])
109
+
@simd for i in1:m
110
+
Q[i, j] *= s
111
+
end
112
+
end
113
+
end
114
+
115
+
if computeR
116
+
R̃ = triu!(view(A, axes(R)...))
117
+
if positive
118
+
@inbounds for j in n:-1:1
119
+
@simd for i in1:min(minmn, j)
120
+
R̃[i, j] = R̃[i, j] * conj(safesign(R̃[i, i]))
121
+
end
122
+
end
123
+
end
124
+
if!pivoted
125
+
copyto!(R, R̃)
126
+
else
127
+
# probably very inefficient in terms of memory access
0 commit comments