Skip to content

Commit fadac92

Browse files
authored
Improve inference for generic classes (PEP 695) (#2433)
1 parent e43e045 commit fadac92

File tree

6 files changed

+60
-8
lines changed

6 files changed

+60
-8
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ Release date: TBA
2222
Refs pylint-dev/#9626
2323
Refs pylint-dev/#9623
2424

25+
* Improve inference for generic classes using the PEP 695 syntax (Python 3.12).
26+
27+
Closes pylint-dev/#9406
28+
2529

2630
What's New in astroid 3.2.1?
2731
============================

astroid/brain/brain_typing.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,20 @@ def infer_typing_attr(
196196
return node.infer(context=ctx)
197197

198198

199+
def _looks_like_generic_class_pep695(node: ClassDef) -> bool:
200+
"""Check if class is using type parameter. Python 3.12+."""
201+
return len(node.type_params) > 0
202+
203+
204+
def infer_typing_generic_class_pep695(
205+
node: ClassDef, ctx: context.InferenceContext | None = None
206+
) -> Iterator[ClassDef]:
207+
"""Add __class_getitem__ for generic classes. Python 3.12+."""
208+
func_to_add = _extract_single_node(CLASS_GETITEM_TEMPLATE)
209+
node.locals["__class_getitem__"] = [func_to_add]
210+
return iter([node])
211+
212+
199213
def _looks_like_typedDict( # pylint: disable=invalid-name
200214
node: FunctionDef | ClassDef,
201215
) -> bool:
@@ -490,3 +504,8 @@ def register(manager: AstroidManager) -> None:
490504

491505
if PY312_PLUS:
492506
register_module_extender(manager, "typing", _typing_transform)
507+
manager.register_transform(
508+
ClassDef,
509+
inference_tip(infer_typing_generic_class_pep695),
510+
_looks_like_generic_class_pep695,
511+
)

astroid/nodes/scoped_nodes/scoped_nodes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2192,7 +2192,10 @@ def scope_lookup(
21922192
and name in AstroidManager().builtins_module
21932193
)
21942194
if (
2195-
any(node == base or base.parent_of(node) for base in self.bases)
2195+
any(
2196+
node == base or base.parent_of(node) and not self.type_params
2197+
for base in self.bases
2198+
)
21962199
or lookup_upper_frame
21972200
):
21982201
# Handle the case where we have either a name

astroid/protocols.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -924,8 +924,7 @@ def generic_type_assigned_stmts(
924924
context: InferenceContext | None = None,
925925
assign_path: None = None,
926926
) -> Generator[nodes.NodeNG, None, None]:
927-
"""Return empty generator (return -> raises StopIteration) so inferred value
928-
is Uninferable.
927+
"""Hack. Return any Node so inference doesn't fail
928+
when evaluating __class_getitem__. Revert if it's causing issues.
929929
"""
930-
return
931-
yield
930+
yield nodes.Const(None)

tests/brain/test_brain.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,24 @@ def test_typing_generic_subscriptable(self):
627627
assert isinstance(inferred, nodes.ClassDef)
628628
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)
629629

630+
@test_utils.require_version(minver="3.12")
631+
def test_typing_generic_subscriptable_pep695(self):
632+
"""Test class using type parameters is subscriptable with __class_getitem__ (added in PY312)"""
633+
node = builder.extract_node(
634+
"""
635+
class Foo[T]: ...
636+
class Bar[T](Foo[T]): ...
637+
"""
638+
)
639+
inferred = next(node.infer())
640+
assert isinstance(inferred, nodes.ClassDef)
641+
assert inferred.name == "Bar"
642+
assert isinstance(inferred.getattr("__class_getitem__")[0], nodes.FunctionDef)
643+
ancestors = list(inferred.ancestors())
644+
assert len(ancestors) == 2
645+
assert ancestors[0].name == "Foo"
646+
assert ancestors[1].name == "object"
647+
630648
@test_utils.require_version(minver="3.9")
631649
def test_typing_annotated_subscriptable(self):
632650
"""Test typing.Annotated is subscriptable with __class_getitem__"""

tests/test_protocols.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,20 +425,29 @@ def test_assigned_stmts_type_var():
425425
assign_stmts = extract_node("type Point[T] = tuple[float, float]")
426426
type_var: nodes.TypeVar = assign_stmts.type_params[0]
427427
assigned = next(type_var.name.assigned_stmts())
428-
assert assigned is Uninferable
428+
# Hack so inference doesn't fail when evaluating __class_getitem__
429+
# Revert if it's causing issues.
430+
assert isinstance(assigned, nodes.Const)
431+
assert assigned.value is None
429432

430433
@staticmethod
431434
def test_assigned_stmts_type_var_tuple():
432435
"""The result is 'Uninferable' and no exception is raised."""
433436
assign_stmts = extract_node("type Alias[*Ts] = tuple[*Ts]")
434437
type_var_tuple: nodes.TypeVarTuple = assign_stmts.type_params[0]
435438
assigned = next(type_var_tuple.name.assigned_stmts())
436-
assert assigned is Uninferable
439+
# Hack so inference doesn't fail when evaluating __class_getitem__
440+
# Revert if it's causing issues.
441+
assert isinstance(assigned, nodes.Const)
442+
assert assigned.value is None
437443

438444
@staticmethod
439445
def test_assigned_stmts_param_spec():
440446
"""The result is 'Uninferable' and no exception is raised."""
441447
assign_stmts = extract_node("type Alias[**P] = Callable[P, int]")
442448
param_spec: nodes.ParamSpec = assign_stmts.type_params[0]
443449
assigned = next(param_spec.name.assigned_stmts())
444-
assert assigned is Uninferable
450+
# Hack so inference doesn't fail when evaluating __class_getitem__
451+
# Revert if it's causing issues.
452+
assert isinstance(assigned, nodes.Const)
453+
assert assigned.value is None

0 commit comments

Comments
 (0)