@@ -26,6 +26,46 @@ function appleaccelerate_isavailable()
26
26
return true
27
27
end
28
28
29
+ function aa_getrf! (A:: AbstractMatrix{<:ComplexF64} ;
30
+ ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 ))),
31
+ info = Ref {Cint} (),
32
+ check = false )
33
+ require_one_based_indexing (A)
34
+ check && chkfinite (A)
35
+ chkstride1 (A)
36
+ m, n = size (A)
37
+ lda = max (1 , stride (A, 2 ))
38
+ if isempty (ipiv)
39
+ ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 )))
40
+ end
41
+ ccall ((" zgetrf_" , libacc), Cvoid,
42
+ (Ref{Cint}, Ref{Cint}, Ptr{ComplexF64},
43
+ Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
44
+ m, n, A, lda, ipiv, info)
45
+ info[] < 0 && throw (ArgumentError (" Invalid arguments sent to LAPACK dgetrf_" ))
46
+ A, ipiv, BlasInt (info[]), info # Error code is stored in LU factorization type
47
+ end
48
+
49
+ function aa_getrf! (A:: AbstractMatrix{<:ComplexF32} ;
50
+ ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 ))),
51
+ info = Ref {Cint} (),
52
+ check = false )
53
+ require_one_based_indexing (A)
54
+ check && chkfinite (A)
55
+ chkstride1 (A)
56
+ m, n = size (A)
57
+ lda = max (1 , stride (A, 2 ))
58
+ if isempty (ipiv)
59
+ ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 )))
60
+ end
61
+ ccall ((" cgetrf_" , libacc), Cvoid,
62
+ (Ref{Cint}, Ref{Cint}, Ptr{ComplexF32},
63
+ Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
64
+ m, n, A, lda, ipiv, info)
65
+ info[] < 0 && throw (ArgumentError (" Invalid arguments sent to LAPACK dgetrf_" ))
66
+ A, ipiv, BlasInt (info[]), info # Error code is stored in LU factorization type
67
+ end
68
+
29
69
function aa_getrf! (A:: AbstractMatrix{<:Float64} ;
30
70
ipiv = similar (A, Cint, min (size (A, 1 ), size (A, 2 ))),
31
71
info = Ref {Cint} (),
@@ -67,6 +107,55 @@ function aa_getrf!(A::AbstractMatrix{<:Float32};
67
107
A, ipiv, BlasInt (info[]), info # Error code is stored in LU factorization type
68
108
end
69
109
110
+ function aa_getrs! (trans:: AbstractChar ,
111
+ A:: AbstractMatrix{<:ComplexF64} ,
112
+ ipiv:: AbstractVector{Cint} ,
113
+ B:: AbstractVecOrMat{<:ComplexF64} ;
114
+ info = Ref {Cint} ())
115
+ require_one_based_indexing (A, ipiv, B)
116
+ LinearAlgebra. LAPACK. chktrans (trans)
117
+ chkstride1 (A, B, ipiv)
118
+ n = LinearAlgebra. checksquare (A)
119
+ if n != size (B, 1 )
120
+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
121
+ end
122
+ if n != length (ipiv)
123
+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
124
+ end
125
+ nrhs = size (B, 2 )
126
+ ccall ((" zgetrs_" , libacc), Cvoid,
127
+ (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF64}, Ref{Cint},
128
+ Ptr{Cint}, Ptr{ComplexF64}, Ref{Cint}, Ptr{Cint}, Clong),
129
+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
130
+ 1 )
131
+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
132
+ end
133
+
134
+ function aa_getrs! (trans:: AbstractChar ,
135
+ A:: AbstractMatrix{<:ComplexF32} ,
136
+ ipiv:: AbstractVector{Cint} ,
137
+ B:: AbstractVecOrMat{<:ComplexF32} ;
138
+ info = Ref {Cint} ())
139
+ require_one_based_indexing (A, ipiv, B)
140
+ LinearAlgebra. LAPACK. chktrans (trans)
141
+ chkstride1 (A, B, ipiv)
142
+ n = LinearAlgebra. checksquare (A)
143
+ if n != size (B, 1 )
144
+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
145
+ end
146
+ if n != length (ipiv)
147
+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
148
+ end
149
+ nrhs = size (B, 2 )
150
+ ccall ((" cgetrs_" , libacc), Cvoid,
151
+ (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{ComplexF32}, Ref{Cint},
152
+ Ptr{Cint}, Ptr{ComplexF32}, Ref{Cint}, Ptr{Cint}, Clong),
153
+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
154
+ 1 )
155
+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
156
+ B
157
+ end
158
+
70
159
function aa_getrs! (trans:: AbstractChar ,
71
160
A:: AbstractMatrix{<:Float64} ,
72
161
ipiv:: AbstractVector{Cint} ,
@@ -128,12 +217,20 @@ else
128
217
nothing
129
218
end
130
219
131
- function LinearSolve. init_cacheval (alg:: AppleAccelerateLUFactorization , A, b, u, Pl, Pr,
220
+ function LinearSolve. init_cacheval (alg:: AppleAccelerateLUFactorization , A:: AbstractMatrix{<:Float64} , b:: AbstractArray{<:Float64} , u, Pl, Pr,
132
221
maxiters:: Int , abstol, reltol, verbose:: Bool ,
133
222
assumptions:: OperatorAssumptions )
134
223
PREALLOCATED_APPLE_LU
135
224
end
136
225
226
+ function LinearSolve. init_cacheval (alg:: AppleAccelerateLUFactorization , A, b, u, Pl, Pr,
227
+ maxiters:: Int , abstol, reltol, verbose:: Bool ,
228
+ assumptions:: OperatorAssumptions )
229
+ A = rand (eltype (A), 0 , 0 )
230
+ luinst = ArrayInterface. lu_instance (A)
231
+ LU (luinst. factors, similar (A, Cint, 0 ), luinst. info), Ref {Cint} ()
232
+ end
233
+
137
234
function SciMLBase. solve! (cache:: LinearCache , alg:: AppleAccelerateLUFactorization ;
138
235
kwargs... )
139
236
A = cache. A
0 commit comments