@@ -41,7 +41,27 @@ function aa_getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, Cint, min(siz
41
41
Ref{Cint}, Ptr{Cint}, Ptr{Cint}),
42
42
m, n, A, lda, ipiv, info)
43
43
info[] < 0 && throw (ArgumentError (" Invalid arguments sent to LAPACK dgetrf_" ))
44
- A, Vector {BlasInt} (ipiv), BlasInt (info[]) # Error code is stored in LU factorization type
44
+ A, ipiv, BlasInt (info[]), info # Error code is stored in LU factorization type
45
+ end
46
+
47
+ function aa_getrs! (trans:: AbstractChar , A:: AbstractMatrix{<:Float64} , ipiv:: AbstractVector{Cint} , B:: AbstractVecOrMat{<:Float64} ; info = Ref {Cint} ())
48
+ require_one_based_indexing (A, ipiv, B)
49
+ LinearAlgebra. LAPACK. chktrans (trans)
50
+ chkstride1 (A, B, ipiv)
51
+ n = LinearAlgebra. checksquare (A)
52
+ if n != size (B, 1 )
53
+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
54
+ end
55
+ if n != length (ipiv)
56
+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
57
+ end
58
+ nrhs = size (B, 2 )
59
+ ccall ((" dgetrs_" , libacc), Cvoid,
60
+ (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
61
+ Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
62
+ trans, n, size (B,2 ), A, max (1 ,stride (A,2 )), ipiv, B, max (1 ,stride (B,2 )), info, 1 )
63
+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
64
+ B
45
65
end
46
66
47
67
default_alias_A (:: AppleAccelerateLUFactorization , :: Any , :: Any ) = false
@@ -50,7 +70,8 @@ default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false
50
70
function LinearSolve. init_cacheval (alg:: AppleAccelerateLUFactorization , A, b, u, Pl, Pr,
51
71
maxiters:: Int , abstol, reltol, verbose:: Bool ,
52
72
assumptions:: OperatorAssumptions )
53
- ArrayInterface. lu_instance (convert (AbstractMatrix, A))
73
+ luinst = ArrayInterface. lu_instance (convert (AbstractMatrix, A))
74
+ LU (luinst. factors,similar (A, Cint, 0 ), luinst. info), Ref {Cint} ()
54
75
end
55
76
56
77
function SciMLBase. solve! (cache:: LinearCache , alg:: AppleAccelerateLUFactorization ;
@@ -59,10 +80,23 @@ function SciMLBase.solve!(cache::LinearCache, alg::AppleAccelerateLUFactorizatio
59
80
A = convert (AbstractMatrix, A)
60
81
if cache. isfresh
61
82
cacheval = @get_cacheval (cache, :AppleAccelerateLUFactorization )
62
- fact = LU (aa_getrf! (A; ipiv = cacheval. ipiv)... )
83
+ res = aa_getrf! (A; ipiv = cacheval[1 ]. ipiv, info = cacheval[2 ])
84
+ fact = LU (res[1 : 3 ]. .. ), res[4 ]
63
85
cache. cacheval = fact
64
86
cache. isfresh = false
65
87
end
66
- y = ldiv! (cache. u, @get_cacheval (cache, :AppleAccelerateLUFactorization ), cache. b)
67
- SciMLBase. build_linear_solution (alg, y, nothing , cache)
88
+
89
+ A, info = @get_cacheval (cache, :AppleAccelerateLUFactorization )
90
+ LinearAlgebra. require_one_based_indexing (cache. u, cache. b)
91
+ m, n = size (A, 1 ), size (A, 2 )
92
+ if m > n
93
+ Bc = copy (cache. b)
94
+ aa_getrs! (' N' , A. factors, A. ipiv, Bc; info)
95
+ return copyto! (cache. u, 1 , Bc, 1 , n)
96
+ else
97
+ copyto! (cache. u, cache. b)
98
+ aa_getrs! (' N' , A. factors, A. ipiv, cache. u; info)
99
+ end
100
+
101
+ SciMLBase. build_linear_solution (alg, cache. u, nothing , cache)
68
102
end
0 commit comments