Skip to content

Commit e5278ec

Browse files
Implement numba solve for assume_a = "sym"
1 parent cace929 commit e5278ec

File tree

3 files changed

+375
-33
lines changed

3 files changed

+375
-33
lines changed

pytensor/link/numba/dispatch/_LAPACK.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,20 @@ def numba_xlange(cls, dtype):
198198
)
199199
return functype(lapack_ptr)
200200

201+
@classmethod
202+
def numba_xlamch(cls, dtype):
203+
"""
204+
Determine machine precision for floating point arithmetic.
205+
"""
206+
207+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "lamch")
208+
output_dtype = _get_output_ctype(dtype)
209+
functype = ctypes.CFUNCTYPE(
210+
output_dtype, # Output
211+
_ptr_int, # CMACH
212+
)
213+
return functype(lapack_ptr)
214+
201215
@classmethod
202216
def numba_xgecon(cls, dtype):
203217
"""
@@ -225,7 +239,7 @@ def numba_xgetrf(cls, dtype):
225239
"""
226240
Compute partial pivoting LU factorization of a general M-by-N matrix A using row interchanges.
227241
228-
Called by scipy.linalg.solve when assume_a == "gen"
242+
Called by scipy.linalg.lu_factor
229243
"""
230244
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrf")
231245
functype = ctypes.CFUNCTYPE(
@@ -245,7 +259,7 @@ def numba_xgetrs(cls, dtype):
245259
Solve a system of linear equations A @ X = B or A.T @ X = B with a general N-by-N matrix A using the LU
246260
factorization computed by numba_getrf.
247261
248-
Called by scipy.linalg.solve when assume_a == "gen"
262+
Called by scipy.linalg.lu_solve
249263
"""
250264
...
251265
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
@@ -262,3 +276,97 @@ def numba_xgetrs(cls, dtype):
262276
_ptr_int, # INFO
263277
)
264278
return functype(lapack_ptr)
279+
280+
@classmethod
281+
def numba_xsysv(cls, dtype):
282+
"""
283+
Solve a system of linear equations A @ X = B with a symmetric matrix A using the factorization computed by
284+
sytrf (LDL or UDU).
285+
286+
Called by scipy.linalg.solve when assume_a == "sym"
287+
"""
288+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sysv")
289+
functype = ctypes.CFUNCTYPE(
290+
None,
291+
_ptr_int, # UPLO
292+
_ptr_int, # N
293+
_ptr_int, # NRHS
294+
float_pointer, # A
295+
_ptr_int, # LDA
296+
_ptr_int, # IPIV
297+
float_pointer, # B
298+
_ptr_int, # LDB
299+
float_pointer, # WORK
300+
_ptr_int, # LWORK
301+
_ptr_int, # INFO
302+
)
303+
return functype(lapack_ptr)
304+
305+
@classmethod
306+
def numba_xsycon(cls, dtype):
307+
"""
308+
Estimates the reciprocal of the condition number of a symmetric matrix A using the factorization computed by
309+
sytrf (LDL or UDU).
310+
311+
Called by scipy.linalg.solve when assume_a == "sym"
312+
"""
313+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "sycon")
314+
315+
functype = ctypes.CFUNCTYPE(
316+
None,
317+
_ptr_int, # UPLO
318+
_ptr_int, # N
319+
float_pointer, # A
320+
_ptr_int, # LDA
321+
_ptr_int, # IPIV
322+
float_pointer, # ANORM
323+
float_pointer, # RCOND
324+
float_pointer, # WORK
325+
_ptr_int, # IWORK
326+
_ptr_int, # INFO
327+
)
328+
return functype(lapack_ptr)
329+
330+
@classmethod
331+
def numba_xpocon(cls, dtype):
332+
"""
333+
Estimates the reciprocal of the condition number of a positive definite matrix A using the Cholesky factorization
334+
computed by potrf.
335+
336+
Called by scipy.linalg.solve when assume_a == "pos"
337+
"""
338+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "pocon")
339+
functype = ctypes.CFUNCTYPE(
340+
None,
341+
_ptr_int, # UPLO
342+
_ptr_int, # N
343+
float_pointer, # A
344+
_ptr_int, # LDA
345+
float_pointer, # ANORM
346+
float_pointer, # RCOND
347+
float_pointer, # WORK
348+
_ptr_int, # IWORK
349+
_ptr_int, # INFO
350+
)
351+
return functype(lapack_ptr)
352+
353+
@classmethod
354+
def numba_xposv(cls, dtype):
355+
"""
356+
Solve a system of linear equations A @ X = B with a symmetric positive definite matrix A using the Cholesky
357+
factorization computed by potrf.
358+
"""
359+
360+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "posv")
361+
functype = ctypes.CFUNCTYPE(
362+
None,
363+
_ptr_int, # UPLO
364+
_ptr_int, # N
365+
_ptr_int, # NRHS
366+
float_pointer, # A
367+
_ptr_int, # LDA
368+
float_pointer, # B
369+
_ptr_int, # LDB
370+
_ptr_int, # INFO
371+
)
372+
return functype(lapack_ptr)

0 commit comments

Comments
 (0)