Skip to content

Commit beac28e

Browse files
Mypy appeasement
1 parent 9fba411 commit beac28e

File tree

1 file changed

+33
-30
lines changed

1 file changed

+33
-30
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,16 @@ def xlamch_impl(kind: str = "E") -> Callable[[str], float]:
361361

362362
def impl(kind: str = "E") -> float:
363363
KIND = val_to_int_ptr(ord(kind))
364-
return numba_lamch(KIND)
364+
return numba_lamch(KIND) # type: ignore
365365

366366
return impl
367367

368368

369369
def _xlange(A: np.ndarray, order: str | None = None) -> float:
370370
"""
371-
Placeholder for computing the norm of a matrix; used by linalg.solve. Not used by pytensor to numbify graphs.
371+
Placeholder for computing the norm of a matrix; used by linalg.solve. Will never be called in python mode.
372372
"""
373-
pass
373+
return # type: ignore
374374

375375

376376
@overload(_xlange)
@@ -388,19 +388,20 @@ def xlange_impl(
388388
numba_lange = _LAPACK().numba_xlange(dtype)
389389

390390
def impl(A: np.ndarray, order: str | None = None):
391-
_M, _N = np.int32(A.shape[-2:])
391+
_M, _N = np.int32(A.shape[-2:]) # type: ignore
392+
392393
A_copy = _copy_to_fortran_order(A)
393394

394-
M = val_to_int_ptr(_M)
395-
N = val_to_int_ptr(_N)
396-
LDA = val_to_int_ptr(_M)
395+
M = val_to_int_ptr(_M) # type: ignore
396+
N = val_to_int_ptr(_N) # type: ignore
397+
LDA = val_to_int_ptr(_M) # type: ignore
397398

398399
NORM = (
399400
val_to_int_ptr(ord(order))
400401
if order is not None
401402
else val_to_int_ptr(ord("1"))
402403
)
403-
WORK = np.empty(_M, dtype=dtype)
404+
WORK = np.empty(_M, dtype=dtype) # type: ignore
404405

405406
result = numba_lange(
406407
NORM, M, N, A_copy.view(w_type).ctypes, LDA, WORK.view(w_type).ctypes
@@ -416,7 +417,7 @@ def _xgecon(A: np.ndarray, A_norm: float, norm: str) -> tuple[np.ndarray, int]:
416417
Placeholder for computing the condition number of a matrix; used by linalg.solve. Not used by pytensor to numbify
417418
graphs.
418419
"""
419-
pass
420+
return # type: ignore
420421

421422

422423
@overload(_xgecon)
@@ -468,7 +469,7 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
468469
469470
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
470471
"""
471-
pass
472+
return # type: ignore
472473

473474

474475
@overload(_getrf)
@@ -484,17 +485,17 @@ def getrf_impl(
484485
def impl(
485486
A: np.ndarray, overwrite_a: bool = False
486487
) -> tuple[np.ndarray, np.ndarray, int]:
487-
_M, _N = np.int32(A.shape[-2:])
488+
_M, _N = np.int32(A.shape[-2:]) # type: ignore
488489

489490
if not overwrite_a:
490491
A_copy = _copy_to_fortran_order(A)
491492
else:
492493
A_copy = A
493494

494-
M = val_to_int_ptr(_M)
495-
N = val_to_int_ptr(_N)
496-
LDA = val_to_int_ptr(_M)
497-
IPIV = np.empty(_N, dtype=np.int32)
495+
M = val_to_int_ptr(_M) # type: ignore
496+
N = val_to_int_ptr(_N) # type: ignore
497+
LDA = val_to_int_ptr(_M) # type: ignore
498+
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
498499
INFO = val_to_int_ptr(0)
499500

500501
numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
@@ -512,7 +513,7 @@ def _getrs(
512513
513514
# TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
514515
"""
515-
pass
516+
return # type: ignore
516517

517518

518519
@overload(_getrs)
@@ -580,8 +581,9 @@ def _solve_gen(
580581
overwrite_b: bool,
581582
check_finite: bool,
582583
transposed: bool,
583-
) -> np.ndarray:
584-
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects."""
584+
):
585+
"""Thin wrapper around scipy.linalg.solve. Used as an overload target for numba to avoid unexpected side-effects
586+
for users who import pytensor."""
585587
return linalg.solve(
586588
A,
587589
B,
@@ -646,7 +648,7 @@ def _sysv(
646648
"""
647649
Placeholder for solving a linear system with a symmetric matrix; used by linalg.solve.
648650
"""
649-
pass
651+
return # type: ignore
650652

651653

652654
@overload(_sysv)
@@ -665,7 +667,7 @@ def sysv_impl(
665667
def impl(
666668
A: np.ndarray, B: np.ndarray, lower: bool, overwrite_a: bool, overwrite_b: bool
667669
):
668-
_LDA, _N = np.int32(A.shape[-2:])
670+
_LDA, _N = np.int32(A.shape[-2:]) # type: ignore
669671
_solve_check_input_shapes(A, B)
670672

671673
if not overwrite_a:
@@ -685,11 +687,11 @@ def impl(
685687
NRHS = 1 if B_is_1d else int(B.shape[-1])
686688

687689
UPLO = val_to_int_ptr(ord("L") if lower else ord("U"))
688-
N = val_to_int_ptr(_N)
690+
N = val_to_int_ptr(_N) # type: ignore
689691
NRHS = val_to_int_ptr(NRHS)
690-
LDA = val_to_int_ptr(_LDA)
691-
IPIV = np.empty(_N, dtype=np.int32)
692-
LDB = val_to_int_ptr(_N)
692+
LDA = val_to_int_ptr(_LDA) # type: ignore
693+
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
694+
LDB = val_to_int_ptr(_N) # type: ignore
693695
WORK = np.empty(1, dtype=dtype)
694696
LWORK = val_to_int_ptr(-1)
695697
INFO = val_to_int_ptr(0)
@@ -737,9 +739,10 @@ def impl(
737739

738740
def _sycon(A: np.ndarray, ipiv: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
739741
"""
740-
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve.
742+
Placeholder for computing the condition number of a symmetric matrix; used by linalg.solve. Never called in
743+
python mode.
741744
"""
742-
pass
745+
return # type: ignore
743746

744747

745748
@overload(_sycon)
@@ -791,7 +794,7 @@ def _solve_symmetric(
791794
overwrite_b: bool,
792795
check_finite: bool,
793796
transposed: bool,
794-
) -> np.ndarray:
797+
):
795798
"""Thin wrapper around scipy.linalg.solve for symmetric matrices. Used as an overload target for numba to avoid
796799
unexpected side-effects when users import pytensor."""
797800
return linalg.solve(
@@ -854,7 +857,7 @@ def _posv(
854857
"""
855858
Placeholder for solving a linear system with a positive-definite matrix; used by linalg.solve.
856859
"""
857-
pass
860+
return # type: ignore
858861

859862

860863
@overload(_posv)
@@ -936,7 +939,7 @@ def _pocon(A: np.ndarray, anorm: float) -> tuple[np.ndarray, int]:
936939
Placeholder for computing the condition number of a cholesky-factorized positive-definite matrix. Used by
937940
linalg.solve when assume_a = "pos".
938941
"""
939-
pass
942+
return # type: ignore
940943

941944

942945
@overload(_pocon)
@@ -987,7 +990,7 @@ def _solve_psd(
987990
overwrite_b: bool,
988991
check_finite: bool,
989992
transposed: bool,
990-
) -> np.ndarray:
993+
):
991994
"""Thin wrapper around scipy.linalg.solve for positive-definite matrices. Used as an overload target for numba to
992995
avoid unexpected side-effects when users import pytensor."""
993996
return linalg.solve(

0 commit comments

Comments
 (0)