@@ -27,6 +27,63 @@ function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(siz
27
27
A, ipiv, info[], info # Error code is stored in LU factorization type
28
28
end
29
29
30
+ function getrf! (A:: AbstractMatrix{<:Float32} ; ipiv = similar (A, BlasInt, min (size (A,1 ),size (A,2 ))), info = Ref {BlasInt} (), check = false )
31
+ require_one_based_indexing (A)
32
+ check && chkfinite (A)
33
+ chkstride1 (A)
34
+ m, n = size (A)
35
+ lda = max (1 ,stride (A, 2 ))
36
+ if isempty (ipiv)
37
+ ipiv = similar (A, BlasInt, min (size (A,1 ),size (A,2 )))
38
+ end
39
+ ccall ((@blasfunc (sgetrf_), MKL_jll. libmkl_rt), Cvoid,
40
+ (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32},
41
+ Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
42
+ m, n, A, lda, ipiv, info)
43
+ chkargsok (info[])
44
+ A, ipiv, info[], info # Error code is stored in LU factorization type
45
+ end
46
+
47
+ function getrs! (trans:: AbstractChar , A:: AbstractMatrix{<:Float64} , ipiv:: AbstractVector{Cint} , B:: AbstractVecOrMat{<:Float64} ; info = Ref {Cint} ())
48
+ require_one_based_indexing (A, ipiv, B)
49
+ LinearAlgebra. LAPACK. chktrans (trans)
50
+ chkstride1 (A, B, ipiv)
51
+ n = LinearAlgebra. checksquare (A)
52
+ if n != size (B, 1 )
53
+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
54
+ end
55
+ if n != length (ipiv)
56
+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
57
+ end
58
+ nrhs = size (B, 2 )
59
+ ccall ((" dgetrs_" , MKL_jll. libmkl_rt), Cvoid,
60
+ (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
61
+ Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
62
+ trans, n, size (B,2 ), A, max (1 ,stride (A,2 )), ipiv, B, max (1 ,stride (B,2 )), info, 1 )
63
+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
64
+ B
65
+ end
66
+
67
+ function getrs! (trans:: AbstractChar , A:: AbstractMatrix{<:Float32} , ipiv:: AbstractVector{Cint} , B:: AbstractVecOrMat{<:Float32} ; info = Ref {Cint} ())
68
+ require_one_based_indexing (A, ipiv, B)
69
+ LinearAlgebra. LAPACK. chktrans (trans)
70
+ chkstride1 (A, B, ipiv)
71
+ n = LinearAlgebra. checksquare (A)
72
+ if n != size (B, 1 )
73
+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
74
+ end
75
+ if n != length (ipiv)
76
+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
77
+ end
78
+ nrhs = size (B, 2 )
79
+ ccall ((" sgetrs_" , MKL_jll. libmkl_rt), Cvoid,
80
+ (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float32}, Ref{Cint},
81
+ Ptr{Cint}, Ptr{Float32}, Ref{Cint}, Ptr{Cint}, Clong),
82
+ trans, n, size (B,2 ), A, max (1 ,stride (A,2 )), ipiv, B, max (1 ,stride (B,2 )), info, 1 )
83
+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
84
+ B
85
+ end
86
+
30
87
default_alias_A (:: MKLLUFactorization , :: Any , :: Any ) = false
31
88
default_alias_b (:: MKLLUFactorization , :: Any , :: Any ) = false
32
89
@@ -47,8 +104,20 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization;
47
104
cache. cacheval = fact
48
105
cache. isfresh = false
49
106
end
50
- y = ldiv! (cache. u, @get_cacheval (cache, :MKLLUFactorization )[1 ], cache. b)
51
- SciMLBase. build_linear_solution (alg, y, nothing , cache)
107
+
108
+ A, info = @get_cacheval (cache, :MKLLUFactorization )
109
+ LinearAlgebra. require_one_based_indexing (cache. u, cache. b)
110
+ m, n = size (A, 1 ), size (A, 2 )
111
+ if m > n
112
+ Bc = copy (cache. b)
113
+ getrs! (' N' , A. factors, A. ipiv, Bc; info)
114
+ return copyto! (cache. u, 1 , Bc, 1 , n)
115
+ else
116
+ copyto! (cache. u, cache. b)
117
+ getrs! (' N' , A. factors, A. ipiv, cache. u; info)
118
+ end
119
+
120
+ SciMLBase. build_linear_solution (alg, cache. u, nothing , cache)
52
121
end
53
122
54
123
end
0 commit comments