Skip to content

Commit d8b95d6

Browse files
authored
Backport gh-2608 (#2610)
This PR backports of #2608 from development branch to `maintenance/0.19.x`.
1 parent c9f8c3d commit d8b95d6

File tree

7 files changed

+64
-19
lines changed

7 files changed

+64
-19
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ This release is compatible with NumPy 2.3.3.
7474
* Fixed tests for the rounding functions to depend on minimum required numpy version [#2589](https://github.com/IntelPython/dpnp/pull/2589)
7575
* Fixed tests for the ufuncs to depend on minimum required numpy version [#2590](https://github.com/IntelPython/dpnp/pull/2590)
7676
* Added missing permission definition in `Autoupdate pre-commit` GitHub workflow [#2591](https://github.com/IntelPython/dpnp/pull/2591)
77+
* Resolved issue with the cyclic import in `linalg` submodule [#2608](https://github.com/IntelPython/dpnp/pull/2608)
7778

7879
### Security
7980

dpnp/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,14 @@
7070
from .dpnp_iface_utils import *
7171
from .dpnp_iface_utils import __all__ as _ifaceutils__all__
7272
from ._version import get_versions
73+
from . import linalg as linalg
7374

7475
__all__ = _iface__all__
7576
__all__ += _ifaceutils__all__
7677

78+
# add submodules
79+
__all__ += ["linalg"]
80+
7781

7882
__version__ = get_versions()["version"]
7983
del get_versions

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ PYBIND11_MODULE(_lapack_impl, m)
8383
.value("C", oneapi::mkl::transpose::C)
8484
.export_values(); // Optional, allows access like `Transpose.N`
8585

86-
// Register a custom LinAlgError exception in the dpnp.linalg submodule
87-
py::module_ linalg_module = py::module_::import("dpnp.linalg");
88-
py::register_exception<lapack_ext::LinAlgError>(
89-
linalg_module, "LinAlgError", PyExc_ValueError);
86+
// Register a LinAlgError exception in the current submodule
87+
py::register_exception<lapack_ext::LinAlgError>(m, "LinAlgError",
88+
PyExc_ValueError);
9089

9190
init_dispatch_vectors();
9291
init_dispatch_tables();

dpnp/dpnp_iface.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
from dpnp.dpnp_algo import *
5353
from dpnp.dpnp_array import dpnp_array
5454
from dpnp.fft import *
55-
from dpnp.linalg import *
5655
from dpnp.memory import *
5756
from dpnp.random import *
5857
from dpnp.special import *

dpnp/linalg/__init__.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,44 @@
3434
"""
3535

3636

37-
from dpnp.linalg.dpnp_iface_linalg import *
38-
from dpnp.linalg.dpnp_iface_linalg import __all__ as __all__linalg
37+
from .dpnp_iface_linalg import (
38+
LinAlgError,
39+
)
40+
from .dpnp_iface_linalg import __all__ as __all__linalg
41+
from .dpnp_iface_linalg import (
42+
cholesky,
43+
cond,
44+
cross,
45+
det,
46+
diagonal,
47+
eig,
48+
eigh,
49+
eigvals,
50+
eigvalsh,
51+
inv,
52+
lstsq,
53+
lu_factor,
54+
lu_solve,
55+
matmul,
56+
matrix_norm,
57+
matrix_power,
58+
matrix_rank,
59+
matrix_transpose,
60+
multi_dot,
61+
norm,
62+
outer,
63+
pinv,
64+
qr,
65+
slogdet,
66+
solve,
67+
svd,
68+
svdvals,
69+
tensordot,
70+
tensorinv,
71+
tensorsolve,
72+
trace,
73+
vecdot,
74+
vector_norm,
75+
)
3976

4077
__all__ = __all__linalg

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838

3939
# pylint: disable=invalid-name
4040
# pylint: disable=no-member
41+
# pylint: disable=no-name-in-module
4142

4243
from typing import NamedTuple
4344

4445
import numpy
4546
from dpctl.tensor._numpy_helper import normalize_axis_tuple
4647

4748
import dpnp
49+
from dpnp.backend.extensions.lapack._lapack_impl import LinAlgError
4850

4951
from .dpnp_utils_linalg import (
5052
assert_2d,
@@ -70,6 +72,7 @@
7072
)
7173

7274
__all__ = [
75+
"LinAlgError",
7376
"cholesky",
7477
"cond",
7578
"cross",
@@ -105,6 +108,9 @@
105108
"vector_norm",
106109
]
107110

111+
# Need to set the module explicitly, since exposed by LAPACK pybind11 extension
112+
LinAlgError.__module__ = "dpnp.linalg"
113+
108114

109115
# pylint:disable=missing-class-docstring
110116
class EigResult(NamedTuple):
@@ -2330,7 +2336,7 @@ def tensorsolve(a, b, axes=None):
23302336
prod = numpy.prod(old_shape)
23312337

23322338
if a.size != prod**2:
2333-
raise dpnp.linalg.LinAlgError(
2339+
raise LinAlgError(
23342340
"Input arrays must satisfy the requirement "
23352341
"prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
23362342
)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
import dpnp
5151
import dpnp.backend.extensions.lapack._lapack_impl as li
5252
from dpnp.dpnp_utils import get_usm_allocations
53-
from dpnp.linalg import LinAlgError as LinAlgError
5453

5554
__all__ = [
5655
"assert_2d",
@@ -943,7 +942,7 @@ def _check_lapack_dev_info(dev_info, error_msg=None):
943942
if any(dev_info):
944943
error_msg = error_msg or "Singular matrix"
945944

946-
raise LinAlgError(error_msg)
945+
raise li.LinAlgError(error_msg)
947946

948947

949948
def _common_type(*arrays):
@@ -1879,7 +1878,7 @@ def assert_2d(*arrays):
18791878

18801879
for a in arrays:
18811880
if a.ndim != 2:
1882-
raise LinAlgError(
1881+
raise li.LinAlgError(
18831882
f"{a.ndim}-dimensional array given. The input "
18841883
"array must be exactly two-dimensional"
18851884
)
@@ -1906,7 +1905,7 @@ def assert_stacked_2d(*arrays):
19061905

19071906
for a in arrays:
19081907
if a.ndim < 2:
1909-
raise LinAlgError(
1908+
raise li.LinAlgError(
19101909
f"{a.ndim}-dimensional array given. The input "
19111910
"array must be at least two-dimensional"
19121911
)
@@ -1942,7 +1941,7 @@ def assert_stacked_square(*arrays):
19421941
for a in arrays:
19431942
m, n = a.shape[-2:]
19441943
if m != n:
1945-
raise LinAlgError(
1944+
raise li.LinAlgError(
19461945
"Last 2 dimensions of the input array must be square"
19471946
)
19481947

@@ -2086,7 +2085,7 @@ def dpnp_cond(x, p=None):
20862085
"""Compute the condition number of a matrix."""
20872086

20882087
if _is_empty_2d(x):
2089-
raise LinAlgError("cond is not defined on empty arrays")
2088+
raise li.LinAlgError("cond is not defined on empty arrays")
20902089
if p is None or p == 2 or p == -2:
20912090
s = dpnp.linalg.svd(x, compute_uv=False)
20922091
if p == -2:
@@ -2340,15 +2339,15 @@ def dpnp_lstsq(a, b, rcond=None):
23402339
"""
23412340

23422341
if b.ndim > 2:
2343-
raise LinAlgError(
2342+
raise li.LinAlgError(
23442343
f"{b.ndim}-dimensional array given. The input "
23452344
"array must be exactly two-dimensional"
23462345
)
23472346

23482347
m, n = a.shape[-2:]
23492348
m2 = b.shape[0]
23502349
if m != m2:
2351-
raise LinAlgError("Incompatible dimensions")
2350+
raise li.LinAlgError("Incompatible dimensions")
23522351

23532352
u, s, vh = dpnp_svd(a, full_matrices=False, related_arrays=[b])
23542353

@@ -2669,20 +2668,20 @@ def dpnp_multi_dot(n, arrays, out=None):
26692668
"""Compute dot product of two or more arrays in a single function call."""
26702669

26712670
if not arrays[0].ndim in [1, 2]:
2672-
raise LinAlgError(
2671+
raise li.LinAlgError(
26732672
f"{arrays[0].ndim}-dimensional array given. "
26742673
"First array must be 1-D or 2-D."
26752674
)
26762675

26772676
if not arrays[-1].ndim in [1, 2]:
2678-
raise LinAlgError(
2677+
raise li.LinAlgError(
26792678
f"{arrays[-1].ndim}-dimensional array given. "
26802679
"Last array must be 1-D or 2-D."
26812680
)
26822681

26832682
for arr in arrays[1:-1]:
26842683
if arr.ndim != 2:
2685-
raise LinAlgError(
2684+
raise li.LinAlgError(
26862685
f"{arr.ndim}-dimensional array given. Inner arrays must be 2-D."
26872686
)
26882687

0 commit comments

Comments
 (0)