Skip to content

Commit 2a16144

Browse files
Add brain for numpy core module einsumfunc. (#1656)
Co-authored-by: Daniël van Noord <[email protected]>
1 parent 58750a6 commit 2a16144

File tree

3 files changed

+89
-0
lines changed

3 files changed

+89
-0
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ Release date: TBA
9595
* Fix test for Python ``3.11``. In some instances ``err.__traceback__`` will
9696
be uninferable now.
9797

98+
* Add brain for numpy core module ``einsumfunc``.
99+
100+
Closes PyCQA/pylint#5821
101+
98102
* Infer the ``DictUnpack`` value for ``Dict.getitem`` calls.
99103

100104
Closes #1195
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2+
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
3+
# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
4+
5+
"""
6+
Astroid hooks for numpy.core.einsumfunc module:
7+
https://github.com/numpy/numpy/blob/main/numpy/core/einsumfunc.py
8+
"""
9+
10+
from astroid import nodes
11+
from astroid.brain.helpers import register_module_extender
12+
from astroid.builder import parse
13+
from astroid.manager import AstroidManager
14+
15+
16+
def numpy_core_einsumfunc_transform() -> nodes.Module:
17+
return parse(
18+
"""
19+
def einsum(*operands, out=None, optimize=False, **kwargs):
20+
return numpy.ndarray([0, 0])
21+
"""
22+
)
23+
24+
25+
register_module_extender(
26+
AstroidManager(), "numpy.core.einsumfunc", numpy_core_einsumfunc_transform
27+
)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2+
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
3+
# Copyright (c) https://github.com/PyCQA/astroid/blob/main/CONTRIBUTORS.txt
4+
5+
from __future__ import annotations
6+
7+
import pytest
8+
9+
from astroid import builder, nodes
10+
11+
try:
12+
import numpy # pylint: disable=unused-import
13+
14+
HAS_NUMPY = True
15+
except ImportError:
16+
HAS_NUMPY = False
17+
18+
19+
def _inferred_numpy_func_call(func_name: str, *func_args: str) -> nodes.FunctionDef:
20+
node = builder.extract_node(
21+
f"""
22+
import numpy as np
23+
func = np.{func_name:s}
24+
func({','.join(func_args):s})
25+
"""
26+
)
27+
return node.infer()
28+
29+
30+
@pytest.mark.skipif(not HAS_NUMPY, reason="This test requires the numpy library.")
31+
def test_numpy_function_calls_inferred_as_ndarray() -> None:
32+
"""
33+
Test that calls to numpy functions are inferred as numpy.ndarray
34+
"""
35+
method = "einsum"
36+
inferred_values = list(
37+
_inferred_numpy_func_call(method, "ii, np.arange(25).reshape(5, 5)")
38+
)
39+
40+
assert len(inferred_values) == 1, f"Too much inferred value for {method:s}"
41+
assert inferred_values[-1].pytype() in (
42+
".ndarray",
43+
), f"Illicit type for {method:s} ({inferred_values[-1].pytype()})"
44+
45+
46+
@pytest.mark.skipif(not HAS_NUMPY, reason="This test requires the numpy library.")
47+
def test_function_parameters() -> None:
48+
instance = builder.extract_node(
49+
f"""
50+
import numpy
51+
numpy.einsum #@
52+
"""
53+
)
54+
actual_args = instance.inferred()[0].args
55+
56+
assert actual_args.vararg == "operands"
57+
assert [arg.name for arg in actual_args.kwonlyargs] == ["out", "optimize"]
58+
assert actual_args.kwarg == "kwargs"

0 commit comments

Comments
 (0)