Skip to content

Commit 13b0cd5

Browse files
committed
Resolve an issue with cyclic import in linalg submodule (#2608)
The PR reworks the import of `linalg` submodule to avoid a wildcard import. Also it resolves the issue reported by pylint: > Cyclic import (dpnp.linalg -> dpnp.linalg.dpnp_iface_linalg -> dpnp.linalg.dpnp_utils_linalg) (cyclic-import) and moves `LinAlgError` exception to be exposed by LAPACK pybind11 extension, because it is created there. While later the exception patched at python level to be set to "dpnp.linalg" submodule explicitly. Otherwise we have the import cycle like: > linalg/__init__.py -> dpnp_iface_linalg.py -> dpnp_utils_linalg.py -> linalg/__init__.py which might cause the import failure.
1 parent c9f8c3d commit 13b0cd5

File tree

7 files changed

+64
-19
lines changed

7 files changed

+64
-19
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ This release is compatible with NumPy 2.3.3.
7474
* Fixed tests for the rounding functions to depend on minimum required numpy version [#2589](https://github.com/IntelPython/dpnp/pull/2589)
7575
* Fixed tests for the ufuncs to depend on minimum required numpy version [#2590](https://github.com/IntelPython/dpnp/pull/2590)
7676
* Added missing permission definition in `Autoupdate pre-commit` GitHub workflow [#2591](https://github.com/IntelPython/dpnp/pull/2591)
77+
* Resolved issue with the cyclic import in `linalg` submodule [#2608](https://github.com/IntelPython/dpnp/pull/2608)
7778

7879
### Security
7980

dpnp/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,14 @@
7070
from .dpnp_iface_utils import *
7171
from .dpnp_iface_utils import __all__ as _ifaceutils__all__
7272
from ._version import get_versions
73+
from . import linalg as linalg
7374

7475
__all__ = _iface__all__
7576
__all__ += _ifaceutils__all__
7677

78+
# add submodules
79+
__all__ += ["linalg"]
80+
7781

7882
__version__ = get_versions()["version"]
7983
del get_versions

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ PYBIND11_MODULE(_lapack_impl, m)
8383
.value("C", oneapi::mkl::transpose::C)
8484
.export_values(); // Optional, allows access like `Transpose.N`
8585

86-
// Register a custom LinAlgError exception in the dpnp.linalg submodule
87-
py::module_ linalg_module = py::module_::import("dpnp.linalg");
88-
py::register_exception<lapack_ext::LinAlgError>(
89-
linalg_module, "LinAlgError", PyExc_ValueError);
86+
// Register a LinAlgError exception in the current submodule
87+
py::register_exception<lapack_ext::LinAlgError>(m, "LinAlgError",
88+
PyExc_ValueError);
9089

9190
init_dispatch_vectors();
9291
init_dispatch_tables();

dpnp/dpnp_iface.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
from dpnp.dpnp_algo import *
5353
from dpnp.dpnp_array import dpnp_array
5454
from dpnp.fft import *
55-
from dpnp.linalg import *
5655
from dpnp.memory import *
5756
from dpnp.random import *
5857
from dpnp.special import *

dpnp/linalg/__init__.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,44 @@
3434
"""
3535

3636

37-
from dpnp.linalg.dpnp_iface_linalg import *
38-
from dpnp.linalg.dpnp_iface_linalg import __all__ as __all__linalg
37+
from .dpnp_iface_linalg import (
38+
LinAlgError,
39+
)
40+
from .dpnp_iface_linalg import __all__ as __all__linalg
41+
from .dpnp_iface_linalg import (
42+
cholesky,
43+
cond,
44+
cross,
45+
det,
46+
diagonal,
47+
eig,
48+
eigh,
49+
eigvals,
50+
eigvalsh,
51+
inv,
52+
lstsq,
53+
lu_factor,
54+
lu_solve,
55+
matmul,
56+
matrix_norm,
57+
matrix_power,
58+
matrix_rank,
59+
matrix_transpose,
60+
multi_dot,
61+
norm,
62+
outer,
63+
pinv,
64+
qr,
65+
slogdet,
66+
solve,
67+
svd,
68+
svdvals,
69+
tensordot,
70+
tensorinv,
71+
tensorsolve,
72+
trace,
73+
vecdot,
74+
vector_norm,
75+
)
3976

4077
__all__ = __all__linalg

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838

3939
# pylint: disable=invalid-name
4040
# pylint: disable=no-member
41+
# pylint: disable=no-name-in-module
4142

4243
from typing import NamedTuple
4344

4445
import numpy
4546
from dpctl.tensor._numpy_helper import normalize_axis_tuple
4647

4748
import dpnp
49+
from dpnp.backend.extensions.lapack._lapack_impl import LinAlgError
4850

4951
from .dpnp_utils_linalg import (
5052
assert_2d,
@@ -70,6 +72,7 @@
7072
)
7173

7274
__all__ = [
75+
"LinAlgError",
7376
"cholesky",
7477
"cond",
7578
"cross",
@@ -105,6 +108,9 @@
105108
"vector_norm",
106109
]
107110

111+
# Need to set the module explicitly, since exposed by LAPACK pybind11 extension
112+
LinAlgError.__module__ = "dpnp.linalg"
113+
108114

109115
# pylint:disable=missing-class-docstring
110116
class EigResult(NamedTuple):
@@ -2330,7 +2336,7 @@ def tensorsolve(a, b, axes=None):
23302336
prod = numpy.prod(old_shape)
23312337

23322338
if a.size != prod**2:
2333-
raise dpnp.linalg.LinAlgError(
2339+
raise LinAlgError(
23342340
"Input arrays must satisfy the requirement "
23352341
"prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
23362342
)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
import dpnp
5151
import dpnp.backend.extensions.lapack._lapack_impl as li
5252
from dpnp.dpnp_utils import get_usm_allocations
53-
from dpnp.linalg import LinAlgError as LinAlgError
5453

5554
__all__ = [
5655
"assert_2d",
@@ -943,7 +942,7 @@ def _check_lapack_dev_info(dev_info, error_msg=None):
943942
if any(dev_info):
944943
error_msg = error_msg or "Singular matrix"
945944

946-
raise LinAlgError(error_msg)
945+
raise li.LinAlgError(error_msg)
947946

948947

949948
def _common_type(*arrays):
@@ -1879,7 +1878,7 @@ def assert_2d(*arrays):
18791878

18801879
for a in arrays:
18811880
if a.ndim != 2:
1882-
raise LinAlgError(
1881+
raise li.LinAlgError(
18831882
f"{a.ndim}-dimensional array given. The input "
18841883
"array must be exactly two-dimensional"
18851884
)
@@ -1906,7 +1905,7 @@ def assert_stacked_2d(*arrays):
19061905

19071906
for a in arrays:
19081907
if a.ndim < 2:
1909-
raise LinAlgError(
1908+
raise li.LinAlgError(
19101909
f"{a.ndim}-dimensional array given. The input "
19111910
"array must be at least two-dimensional"
19121911
)
@@ -1942,7 +1941,7 @@ def assert_stacked_square(*arrays):
19421941
for a in arrays:
19431942
m, n = a.shape[-2:]
19441943
if m != n:
1945-
raise LinAlgError(
1944+
raise li.LinAlgError(
19461945
"Last 2 dimensions of the input array must be square"
19471946
)
19481947

@@ -2086,7 +2085,7 @@ def dpnp_cond(x, p=None):
20862085
"""Compute the condition number of a matrix."""
20872086

20882087
if _is_empty_2d(x):
2089-
raise LinAlgError("cond is not defined on empty arrays")
2088+
raise li.LinAlgError("cond is not defined on empty arrays")
20902089
if p is None or p == 2 or p == -2:
20912090
s = dpnp.linalg.svd(x, compute_uv=False)
20922091
if p == -2:
@@ -2340,15 +2339,15 @@ def dpnp_lstsq(a, b, rcond=None):
23402339
"""
23412340

23422341
if b.ndim > 2:
2343-
raise LinAlgError(
2342+
raise li.LinAlgError(
23442343
f"{b.ndim}-dimensional array given. The input "
23452344
"array must be exactly two-dimensional"
23462345
)
23472346

23482347
m, n = a.shape[-2:]
23492348
m2 = b.shape[0]
23502349
if m != m2:
2351-
raise LinAlgError("Incompatible dimensions")
2350+
raise li.LinAlgError("Incompatible dimensions")
23522351

23532352
u, s, vh = dpnp_svd(a, full_matrices=False, related_arrays=[b])
23542353

@@ -2669,20 +2668,20 @@ def dpnp_multi_dot(n, arrays, out=None):
26692668
"""Compute dot product of two or more arrays in a single function call."""
26702669

26712670
if not arrays[0].ndim in [1, 2]:
2672-
raise LinAlgError(
2671+
raise li.LinAlgError(
26732672
f"{arrays[0].ndim}-dimensional array given. "
26742673
"First array must be 1-D or 2-D."
26752674
)
26762675

26772676
if not arrays[-1].ndim in [1, 2]:
2678-
raise LinAlgError(
2677+
raise li.LinAlgError(
26792678
f"{arrays[-1].ndim}-dimensional array given. "
26802679
"Last array must be 1-D or 2-D."
26812680
)
26822681

26832682
for arr in arrays[1:-1]:
26842683
if arr.ndim != 2:
2685-
raise LinAlgError(
2684+
raise li.LinAlgError(
26862685
f"{arr.ndim}-dimensional array given. Inner arrays must be 2-D."
26872686
)
26882687

0 commit comments

Comments
 (0)