Skip to content

Commit 01fc3a7

Browse files
committed
Make LinAlgError exception to be exposed by LAPACK pybind11 extension
1 parent c98fa8b commit 01fc3a7

File tree

4 files changed

+23
-16
lines changed

4 files changed

+23
-16
lines changed

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ PYBIND11_MODULE(_lapack_impl, m)
8383
.value("C", oneapi::mkl::transpose::C)
8484
.export_values(); // Optional, allows access like `Transpose.N`
8585

86-
// Register a custom LinAlgError exception in the dpnp.linalg submodule
87-
py::module_ linalg_module = py::module_::import("dpnp.linalg");
88-
py::register_exception<lapack_ext::LinAlgError>(
89-
linalg_module, "LinAlgError", PyExc_ValueError);
86+
// Register a LinAlgError exception in the current submodule
87+
py::register_exception<lapack_ext::LinAlgError>(m, "LinAlgError",
88+
PyExc_ValueError);
9089

9190
init_dispatch_vectors();
9291
init_dispatch_tables();

dpnp/linalg/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
"""
3535

3636

37+
from .dpnp_iface_linalg import (
38+
LinAlgError,
39+
)
3740
from .dpnp_iface_linalg import __all__ as __all__linalg
3841
from .dpnp_iface_linalg import (
3942
cholesky,

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838

3939
# pylint: disable=invalid-name
4040
# pylint: disable=no-member
41+
# pylint: disable=no-name-in-module
4142

4243
from typing import NamedTuple
4344

4445
import numpy
4546
from dpctl.tensor._numpy_helper import normalize_axis_tuple
4647

4748
import dpnp
49+
from dpnp.backend.extensions.lapack._lapack_impl import LinAlgError
4850

4951
from .dpnp_utils_linalg import (
5052
assert_2d,
@@ -70,6 +72,7 @@
7072
)
7173

7274
__all__ = [
75+
"LinAlgError",
7376
"cholesky",
7477
"cond",
7578
"cross",
@@ -105,6 +108,9 @@
105108
"vector_norm",
106109
]
107110

111+
# Need to set the module explicitly, since exposed by LAPACK pybind11 extension
112+
LinAlgError.__module__ = "dpnp.linalg"
113+
108114

109115
# pylint:disable=missing-class-docstring
110116
class EigResult(NamedTuple):
@@ -2330,7 +2336,7 @@ def tensorsolve(a, b, axes=None):
23302336
prod = numpy.prod(old_shape)
23312337

23322338
if a.size != prod**2:
2333-
raise dpnp.linalg.LinAlgError(
2339+
raise LinAlgError(
23342340
"Input arrays must satisfy the requirement "
23352341
"prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
23362342
)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
import dpnp
5151
import dpnp.backend.extensions.lapack._lapack_impl as li
5252
from dpnp.dpnp_utils import get_usm_allocations
53-
from dpnp.linalg import LinAlgError as LinAlgError
5453

5554
__all__ = [
5655
"assert_2d",
@@ -943,7 +942,7 @@ def _check_lapack_dev_info(dev_info, error_msg=None):
943942
if any(dev_info):
944943
error_msg = error_msg or "Singular matrix"
945944

946-
raise LinAlgError(error_msg)
945+
raise li.LinAlgError(error_msg)
947946

948947

949948
def _common_type(*arrays):
@@ -1879,7 +1878,7 @@ def assert_2d(*arrays):
18791878

18801879
for a in arrays:
18811880
if a.ndim != 2:
1882-
raise LinAlgError(
1881+
raise li.LinAlgError(
18831882
f"{a.ndim}-dimensional array given. The input "
18841883
"array must be exactly two-dimensional"
18851884
)
@@ -1906,7 +1905,7 @@ def assert_stacked_2d(*arrays):
19061905

19071906
for a in arrays:
19081907
if a.ndim < 2:
1909-
raise LinAlgError(
1908+
raise li.LinAlgError(
19101909
f"{a.ndim}-dimensional array given. The input "
19111910
"array must be at least two-dimensional"
19121911
)
@@ -1942,7 +1941,7 @@ def assert_stacked_square(*arrays):
19421941
for a in arrays:
19431942
m, n = a.shape[-2:]
19441943
if m != n:
1945-
raise LinAlgError(
1944+
raise li.LinAlgError(
19461945
"Last 2 dimensions of the input array must be square"
19471946
)
19481947

@@ -2086,7 +2085,7 @@ def dpnp_cond(x, p=None):
20862085
"""Compute the condition number of a matrix."""
20872086

20882087
if _is_empty_2d(x):
2089-
raise LinAlgError("cond is not defined on empty arrays")
2088+
raise li.LinAlgError("cond is not defined on empty arrays")
20902089
if p is None or p == 2 or p == -2:
20912090
s = dpnp.linalg.svd(x, compute_uv=False)
20922091
if p == -2:
@@ -2340,15 +2339,15 @@ def dpnp_lstsq(a, b, rcond=None):
23402339
"""
23412340

23422341
if b.ndim > 2:
2343-
raise LinAlgError(
2342+
raise li.LinAlgError(
23442343
f"{b.ndim}-dimensional array given. The input "
23452344
"array must be exactly two-dimensional"
23462345
)
23472346

23482347
m, n = a.shape[-2:]
23492348
m2 = b.shape[0]
23502349
if m != m2:
2351-
raise LinAlgError("Incompatible dimensions")
2350+
raise li.LinAlgError("Incompatible dimensions")
23522351

23532352
u, s, vh = dpnp_svd(a, full_matrices=False, related_arrays=[b])
23542353

@@ -2669,20 +2668,20 @@ def dpnp_multi_dot(n, arrays, out=None):
26692668
"""Compute dot product of two or more arrays in a single function call."""
26702669

26712670
if not arrays[0].ndim in [1, 2]:
2672-
raise LinAlgError(
2671+
raise li.LinAlgError(
26732672
f"{arrays[0].ndim}-dimensional array given. "
26742673
"First array must be 1-D or 2-D."
26752674
)
26762675

26772676
if not arrays[-1].ndim in [1, 2]:
2678-
raise LinAlgError(
2677+
raise li.LinAlgError(
26792678
f"{arrays[-1].ndim}-dimensional array given. "
26802679
"Last array must be 1-D or 2-D."
26812680
)
26822681

26832682
for arr in arrays[1:-1]:
26842683
if arr.ndim != 2:
2685-
raise LinAlgError(
2684+
raise li.LinAlgError(
26862685
f"{arr.ndim}-dimensional array given. Inner arrays must be 2-D."
26872686
)
26882687

0 commit comments

Comments
 (0)