@@ -44,7 +44,7 @@ function openblas_getrf!(A::AbstractMatrix{<:ComplexF64};
44
44
ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
45
45
info = Ref {BlasInt} (),
46
46
check = false )
47
- __openblas_isavailable () ||
47
+ __openblas_isavailable () ||
48
48
error (" Error, OpenBLAS binary is missing but solve is being called. Report this issue" )
49
49
require_one_based_indexing (A)
50
50
check && chkfinite (A)
@@ -66,7 +66,7 @@ function openblas_getrf!(A::AbstractMatrix{<:ComplexF32};
66
66
ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
67
67
info = Ref {BlasInt} (),
68
68
check = false )
69
- __openblas_isavailable () ||
69
+ __openblas_isavailable () ||
70
70
error (" Error, OpenBLAS binary is missing but solve is being called. Report this issue" )
71
71
require_one_based_indexing (A)
72
72
check && chkfinite (A)
@@ -88,7 +88,7 @@ function openblas_getrf!(A::AbstractMatrix{<:Float64};
88
88
ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
89
89
info = Ref {BlasInt} (),
90
90
check = false )
91
- __openblas_isavailable () ||
91
+ __openblas_isavailable () ||
92
92
error (" Error, OpenBLAS binary is missing but solve is being called. Report this issue" )
93
93
require_one_based_indexing (A)
94
94
check && chkfinite (A)
@@ -110,7 +110,7 @@ function openblas_getrf!(A::AbstractMatrix{<:Float32};
110
110
ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
111
111
info = Ref {BlasInt} (),
112
112
check = false )
113
- __openblas_isavailable () ||
113
+ __openblas_isavailable () ||
114
114
error (" Error, OpenBLAS binary is missing but solve is being called. Report this issue" )
115
115
require_one_based_indexing (A)
116
116
check && chkfinite (A)
@@ -133,7 +133,7 @@ function openblas_getrs!(trans::AbstractChar,
133
133
ipiv:: AbstractVector{BlasInt} ,
134
134
B:: AbstractVecOrMat{<:ComplexF64} ;
135
135
info = Ref {BlasInt} ())
136
- __openblas_isavailable () ||
136
+ __openblas_isavailable () ||
137
137
error (" Error, OpenBLAS binary is missing but solve is being called. Report this issue" )
138
138
require_one_based_indexing (A, ipiv, B)
139
139
LinearAlgebra. LAPACK. chktrans (trans)
@@ -160,7 +160,7 @@ function openblas_getrs!(trans::AbstractChar,
160
160
ipiv:: AbstractVector{BlasInt} ,
161
161
B:: AbstractVecOrMat{<:ComplexF32} ;
162
162
info = Ref {BlasInt} ())
163
- __openblas_isavailable () ||
163
+ __openblas_isavailable () ||
164
164
error (" Error, OpenBLAS binary is missing but solve is being called. Report this issue" )
165
165
require_one_based_indexing (A, ipiv, B)
166
166
LinearAlgebra. LAPACK. chktrans (trans)
@@ -187,7 +187,7 @@ function openblas_getrs!(trans::AbstractChar,
187
187
ipiv:: AbstractVector{BlasInt} ,
188
188
B:: AbstractVecOrMat{<:Float64} ;
189
189
info = Ref {BlasInt} ())
190
- __openblas_isavailable () ||
190
+ __openblas_isavailable () ||
191
191
error (" Error, OpenBLAS binary is missing but solve is being called. Report this issue" )
192
192
require_one_based_indexing (A, ipiv, B)
193
193
LinearAlgebra. LAPACK. chktrans (trans)
@@ -214,7 +214,7 @@ function openblas_getrs!(trans::AbstractChar,
214
214
ipiv:: AbstractVector{BlasInt} ,
215
215
B:: AbstractVecOrMat{<:Float32} ;
216
216
info = Ref {BlasInt} ())
217
- __openblas_isavailable () ||
217
+ __openblas_isavailable () ||
218
218
error (" Error, OpenBLAS binary is missing but solve is being called. Report this issue" )
219
219
require_one_based_indexing (A, ipiv, B)
220
220
LinearAlgebra. LAPACK. chktrans (trans)
260
260
261
261
function SciMLBase. solve! (cache:: LinearCache , alg:: OpenBLASLUFactorization ;
262
262
kwargs... )
263
- __openblas_isavailable () ||
263
+ __openblas_isavailable () ||
264
264
error (" Error, OpenBLAS binary is missing but solve is being called. Report this issue" )
265
265
A = cache. A
266
266
A = convert (AbstractMatrix, A)
@@ -292,3 +292,82 @@ function SciMLBase.solve!(cache::LinearCache, alg::OpenBLASLUFactorization;
292
292
SciMLBase. build_linear_solution (
293
293
alg, cache. u, nothing , cache; retcode = ReturnCode. Success)
294
294
end
295
+
296
# Mixed precision OpenBLAS implementation

# Aliasing defaults for `OpenBLAS32MixedLUFactorization`: both methods return
# `false` regardless of the two trailing arguments.
function default_alias_A(::OpenBLAS32MixedLUFactorization, ::Any, ::Any)
    return false
end

function default_alias_b(::OpenBLAS32MixedLUFactorization, ::Any, ::Any)
    return false
end
299
+
300
# Tuple of an LU instance built from an empty (0x0) `Float32` matrix and a fresh
# `Ref{BlasInt}`.
#
# A `let` block is used so the temporary matrix stays local. A top-level `begin`
# block does not open a new scope, so temporaries assigned inside it would be
# created as non-const module globals.
const PREALLOCATED_OPENBLAS32_LU = let A = rand(Float32, 0, 0)
    (ArrayInterface.lu_instance(A), Ref{BlasInt}())
end
304
+
305
# Build the initial cache value for `OpenBLAS32MixedLUFactorization`.
#
# Only `eltype(A)` is inspected; every other argument is accepted and ignored.
# Returns a tuple `(lu_instance, info)` where `lu_instance` comes from
# `ArrayInterface.lu_instance` applied to an empty 32-bit matrix and `info` is a
# fresh `Ref{BlasInt}`.
function LinearSolve.init_cacheval(alg::OpenBLAS32MixedLUFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::LinearVerbosity,
        assumptions::OperatorAssumptions)
    # Complex element types get a ComplexF32 prototype, everything else Float32.
    T32 = eltype(A) <: Complex ? ComplexF32 : Float32
    prototype = rand(T32, 0, 0)
    return (ArrayInterface.lu_instance(prototype), Ref{BlasInt}())
end
316
+
317
# Solve the cached linear system through OpenBLAS using a 32-bit LU factorization.
#
# When `cache.isfresh` is set, `cache.A` is converted element-wise to `Float32`
# (`ComplexF32` when its element type is complex), factorized with
# `openblas_getrf!`, and the resulting `(LU, info)` pair is stored in
# `cache.cacheval`. The right-hand side `cache.b` is converted to the same 32-bit
# type, solved in place with `openblas_getrs!`, and the result is converted to
# `eltype(cache.u)` and written into `cache.u`.
#
# Returns the result of `SciMLBase.build_linear_solution`, with
# `ReturnCode.Failure` when the factorization is not successful and
# `ReturnCode.Success` otherwise. Raises via `error` when
# `__openblas_isavailable()` is false.
function SciMLBase.solve!(cache::LinearCache, alg::OpenBLAS32MixedLUFactorization;
        kwargs...)
    __openblas_isavailable() ||
        error("Error, OpenBLAS binary is missing but solve is being called. Report this issue")
    A = convert(AbstractMatrix, cache.A)

    # Working precision: complex inputs use ComplexF32, everything else Float32.
    T32 = eltype(A) <: Complex ? ComplexF32 : Float32

    if cache.isfresh
        cacheval = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
        # Broadcasting allocates a new 32-bit matrix, so `cache.A` itself is not
        # the array handed to the in-place factorization.
        A_f32 = T32.(A)
        res = openblas_getrf!(A_f32; ipiv = cacheval[1].ipiv, info = cacheval[2])
        fact = LU(res[1:3]...), res[4]
        cache.cacheval = fact

        if !LinearAlgebra.issuccess(fact[1])
            return SciMLBase.build_linear_solution(
                alg, cache.u, nothing, cache; retcode = ReturnCode.Failure)
        end
        cache.isfresh = false
    end

    A_lu, info = @get_cacheval(cache, :OpenBLAS32MixedLUFactorization)
    require_one_based_indexing(cache.u, cache.b)
    m, n = size(A_lu, 1), size(A_lu, 2)

    # The broadcast already yields a freshly allocated 32-bit buffer, so it is
    # used directly as the in/out argument of the solve; an additional `copy`
    # would only duplicate it.
    rhs_f32 = T32.(cache.b)
    openblas_getrs!('N', A_lu.factors, A_lu.ipiv, rhs_f32; info)

    # Convert back to the precision of `cache.u`.
    T = eltype(cache.u)
    if m > n
        # Only the first `n` entries are written back; a view avoids allocating
        # a temporary slice.
        cache.u .= T.(view(rhs_f32, 1:n))
    else
        cache.u .= T.(rhs_f32)
    end

    return SciMLBase.build_linear_solution(
        alg, cache.u, nothing, cache; retcode = ReturnCode.Success)
end
0 commit comments