Skip to content

Commit 49576c8

Browse files
Merge master into scipy_submodule
2 parents 1b1f6c4 + 285c313 commit 49576c8

File tree

8 files changed

+65
-20
lines changed

8 files changed

+65
-20
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ This release is compatible with NumPy 2.3.3.
7474
* Fixed tests for the rounding functions to depend on minimum required numpy version [#2589](https://github.com/IntelPython/dpnp/pull/2589)
7575
* Fixed tests for the ufuncs to depend on minimum required numpy version [#2590](https://github.com/IntelPython/dpnp/pull/2590)
7676
* Added missing permission definition in `Autoupdate pre-commit` GitHub workflow [#2591](https://github.com/IntelPython/dpnp/pull/2591)
77+
* Resolved issue with the cyclic import in `linalg` submodule [#2608](https://github.com/IntelPython/dpnp/pull/2608)
7778

7879
### Security
7980

dpnp/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,16 @@
7070
from .dpnp_iface_utils import *
7171
from .dpnp_iface_utils import __all__ as _ifaceutils__all__
7272
from ._version import get_versions
73+
from . import linalg as linalg
7374
from . import scipy as scipy
7475

7576
__all__ = _iface__all__
7677
__all__ += _ifaceutils__all__
7778
__all__ += ["scipy"]
7879

80+
# add submodules
81+
__all__ += ["linalg"]
82+
7983

8084
__version__ = get_versions()["version"]
8185
del get_versions

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,9 @@ PYBIND11_MODULE(_lapack_impl, m)
8383
.value("C", oneapi::mkl::transpose::C)
8484
.export_values(); // Optional, allows access like `Transpose.N`
8585

86-
// Register a custom LinAlgError exception in the dpnp.linalg submodule
87-
py::module_ linalg_module = py::module_::import("dpnp.linalg");
88-
py::register_exception<lapack_ext::LinAlgError>(
89-
linalg_module, "LinAlgError", PyExc_ValueError);
86+
// Register a LinAlgError exception in the current submodule
87+
py::register_exception<lapack_ext::LinAlgError>(m, "LinAlgError",
88+
PyExc_ValueError);
9089

9190
init_dispatch_vectors();
9291
init_dispatch_tables();

dpnp/dpnp_iface.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
from dpnp.dpnp_algo import *
5353
from dpnp.dpnp_array import dpnp_array
5454
from dpnp.fft import *
55-
from dpnp.linalg import *
5655
from dpnp.memory import *
5756
from dpnp.random import *
5857

dpnp/linalg/__init__.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,44 @@
3434
"""
3535

3636

37-
from dpnp.linalg.dpnp_iface_linalg import *
38-
from dpnp.linalg.dpnp_iface_linalg import __all__ as __all__linalg
37+
from .dpnp_iface_linalg import (
38+
LinAlgError,
39+
)
40+
from .dpnp_iface_linalg import __all__ as __all__linalg
41+
from .dpnp_iface_linalg import (
42+
cholesky,
43+
cond,
44+
cross,
45+
det,
46+
diagonal,
47+
eig,
48+
eigh,
49+
eigvals,
50+
eigvalsh,
51+
inv,
52+
lstsq,
53+
lu_factor,
54+
lu_solve,
55+
matmul,
56+
matrix_norm,
57+
matrix_power,
58+
matrix_rank,
59+
matrix_transpose,
60+
multi_dot,
61+
norm,
62+
outer,
63+
pinv,
64+
qr,
65+
slogdet,
66+
solve,
67+
svd,
68+
svdvals,
69+
tensordot,
70+
tensorinv,
71+
tensorsolve,
72+
trace,
73+
vecdot,
74+
vector_norm,
75+
)
3976

4077
__all__ = __all__linalg

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838

3939
# pylint: disable=invalid-name
4040
# pylint: disable=no-member
41+
# pylint: disable=no-name-in-module
4142

4243
from typing import NamedTuple
4344

4445
import numpy
4546
from dpctl.tensor._numpy_helper import normalize_axis_tuple
4647

4748
import dpnp
49+
from dpnp.backend.extensions.lapack._lapack_impl import LinAlgError
4850

4951
from .dpnp_utils_linalg import (
5052
assert_2d,
@@ -68,6 +70,7 @@
6870
)
6971

7072
__all__ = [
73+
"LinAlgError",
7174
"cholesky",
7275
"cond",
7376
"cross",
@@ -101,6 +104,9 @@
101104
"vector_norm",
102105
]
103106

107+
# Need to set the module explicitly, since exposed by LAPACK pybind11 extension
108+
LinAlgError.__module__ = "dpnp.linalg"
109+
104110

105111
# pylint:disable=missing-class-docstring
106112
class EigResult(NamedTuple):
@@ -2183,7 +2189,7 @@ def tensorsolve(a, b, axes=None):
21832189
prod = numpy.prod(old_shape)
21842190

21852191
if a.size != prod**2:
2186-
raise dpnp.linalg.LinAlgError(
2192+
raise LinAlgError(
21872193
"Input arrays must satisfy the requirement "
21882194
"prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
21892195
)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
import dpnp
5151
import dpnp.backend.extensions.lapack._lapack_impl as li
5252
from dpnp.dpnp_utils import get_usm_allocations
53-
from dpnp.linalg import LinAlgError as LinAlgError
5453

5554
__all__ = [
5655
"assert_2d",
@@ -818,7 +817,7 @@ def _check_lapack_dev_info(dev_info, error_msg=None):
818817
if any(dev_info):
819818
error_msg = error_msg or "Singular matrix"
820819

821-
raise LinAlgError(error_msg)
820+
raise li.LinAlgError(error_msg)
822821

823822

824823
def _common_type(*arrays):
@@ -1737,7 +1736,7 @@ def assert_2d(*arrays):
17371736

17381737
for a in arrays:
17391738
if a.ndim != 2:
1740-
raise LinAlgError(
1739+
raise li.LinAlgError(
17411740
f"{a.ndim}-dimensional array given. The input "
17421741
"array must be exactly two-dimensional"
17431742
)
@@ -1764,7 +1763,7 @@ def assert_stacked_2d(*arrays):
17641763

17651764
for a in arrays:
17661765
if a.ndim < 2:
1767-
raise LinAlgError(
1766+
raise li.LinAlgError(
17681767
f"{a.ndim}-dimensional array given. The input "
17691768
"array must be at least two-dimensional"
17701769
)
@@ -1800,7 +1799,7 @@ def assert_stacked_square(*arrays):
18001799
for a in arrays:
18011800
m, n = a.shape[-2:]
18021801
if m != n:
1803-
raise LinAlgError(
1802+
raise li.LinAlgError(
18041803
"Last 2 dimensions of the input array must be square"
18051804
)
18061805

@@ -1944,7 +1943,7 @@ def dpnp_cond(x, p=None):
19441943
"""Compute the condition number of a matrix."""
19451944

19461945
if _is_empty_2d(x):
1947-
raise LinAlgError("cond is not defined on empty arrays")
1946+
raise li.LinAlgError("cond is not defined on empty arrays")
19481947
if p is None or p == 2 or p == -2:
19491948
s = dpnp.linalg.svd(x, compute_uv=False)
19501949
if p == -2:
@@ -2198,15 +2197,15 @@ def dpnp_lstsq(a, b, rcond=None):
21982197
"""
21992198

22002199
if b.ndim > 2:
2201-
raise LinAlgError(
2200+
raise li.LinAlgError(
22022201
f"{b.ndim}-dimensional array given. The input "
22032202
"array must be exactly two-dimensional"
22042203
)
22052204

22062205
m, n = a.shape[-2:]
22072206
m2 = b.shape[0]
22082207
if m != m2:
2209-
raise LinAlgError("Incompatible dimensions")
2208+
raise li.LinAlgError("Incompatible dimensions")
22102209

22112210
u, s, vh = dpnp_svd(a, full_matrices=False, related_arrays=[b])
22122211

@@ -2318,20 +2317,20 @@ def dpnp_multi_dot(n, arrays, out=None):
23182317
"""Compute dot product of two or more arrays in a single function call."""
23192318

23202319
if not arrays[0].ndim in [1, 2]:
2321-
raise LinAlgError(
2320+
raise li.LinAlgError(
23222321
f"{arrays[0].ndim}-dimensional array given. "
23232322
"First array must be 1-D or 2-D."
23242323
)
23252324

23262325
if not arrays[-1].ndim in [1, 2]:
2327-
raise LinAlgError(
2326+
raise li.LinAlgError(
23282327
f"{arrays[-1].ndim}-dimensional array given. "
23292328
"Last array must be 1-D or 2-D."
23302329
)
23312330

23322331
for arr in arrays[1:-1]:
23332332
if arr.ndim != 2:
2334-
raise LinAlgError(
2333+
raise li.LinAlgError(
23352334
f"{arr.ndim}-dimensional array given. Inner arrays must be 2-D."
23362335
)
23372336

environments/build_conda_pkg.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ channels:
33
- conda-forge
44
dependencies:
55
- python=3.13
6-
- conda-build=25.7.0
6+
- conda-build=25.9.0

0 commit comments

Comments
 (0)