|
18 | 18 |
|
19 | 19 | import astroid
|
20 | 20 | from astroid import nodes
|
| 21 | +from astroid.typing import InferenceResult |
21 | 22 |
|
22 | 23 | from pylint import constants
|
23 | 24 | from pylint.checkers.utils import safe_infer
|
@@ -427,15 +428,64 @@ def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
|
427 | 428 | # Type annotation only (x: P) -> Association
|
428 | 429 | # BUT only if there's no actual assignment (to avoid duplicates)
|
429 | 430 | if isinstance(node.parent, nodes.AnnAssign) and node.parent.value is None:
|
| 431 | + inferred_types = utils.infer_node(node) |
| 432 | + element_types = extract_element_types(inferred_types) |
| 433 | + |
| 434 | + # Resolve nodes to actual class definitions |
| 435 | + resolved_types = resolve_to_class_def(element_types) |
| 436 | + |
430 | 437 | current = set(parent.associations_type[node.attrname])
|
431 |
| - parent.associations_type[node.attrname] = list( |
432 |
| - current | utils.infer_node(node) |
433 |
| - ) |
| 438 | + parent.associations_type[node.attrname] = list(current | resolved_types) |
434 | 439 | return
|
435 | 440 |
|
436 | 441 | # Everything else is also association (fallback)
|
437 | 442 | current = set(parent.associations_type[node.attrname])
|
438 |
| - parent.associations_type[node.attrname] = list(current | utils.infer_node(node)) |
| 443 | + inferred_types = utils.infer_node(node) |
| 444 | + element_types = extract_element_types(inferred_types) |
| 445 | + |
| 446 | + # Resolve Name nodes to actual class definitions |
| 447 | + resolved_types = resolve_to_class_def(element_types) |
| 448 | + parent.associations_type[node.attrname] = list(current | resolved_types) |
| 449 | + |
| 450 | + |
| 451 | +def resolve_to_class_def(types: set[nodes.NodeNG]) -> set[nodes.ClassDef]: |
| 452 | + """Resolve a set of nodes to ClassDef nodes.""" |
| 453 | + class_defs = set() |
| 454 | + for node in types: |
| 455 | + if isinstance(node, nodes.ClassDef): |
| 456 | + class_defs.add(node) |
| 457 | + elif isinstance(node, nodes.Name): |
| 458 | + inferred = safe_infer(node) |
| 459 | + if isinstance(inferred, nodes.ClassDef): |
| 460 | + class_defs.add(inferred) |
| 461 | + return class_defs |
| 462 | + |
| 463 | + |
| 464 | +def extract_element_types(inferred_types: set[InferenceResult]) -> set[nodes.NodeNG]: |
| 465 | + """Extract element types in case the inferred type is a container. |
| 466 | +
|
| 467 | + This function checks if the inferred type is a container type (like list, dict, etc.) |
| 468 | + and extracts the element type(s) from it. If the inferred type is a direct type (like a class), |
| 469 | + it adds that type directly to the set of element types it returns. |
| 470 | + """ |
| 471 | + element_types = set() |
| 472 | + |
| 473 | + for inferred_type in inferred_types: |
| 474 | + if isinstance(inferred_type, nodes.Subscript): |
| 475 | + slice_node = inferred_type.slice |
| 476 | + |
| 477 | + # Handle both Tuple (dict[K,V]) and single element (list[T]) |
| 478 | + elements = ( |
| 479 | + slice_node.elts if isinstance(slice_node, nodes.Tuple) else [slice_node] |
| 480 | + ) |
| 481 | + |
| 482 | + for elt in elements: |
| 483 | + if isinstance(elt, (nodes.Name, nodes.ClassDef)): |
| 484 | + element_types.add(elt) |
| 485 | + else: |
| 486 | + element_types.add(inferred_type) |
| 487 | + |
| 488 | + return element_types |
439 | 489 |
|
440 | 490 |
|
441 | 491 | def project_from_files(
|
|
0 commit comments