|
17 | 17 |
|
18 | 18 | import astroid
|
19 | 19 | from astroid import nodes
|
| 20 | +from astroid.typing import InferenceResult |
20 | 21 |
|
21 | 22 | from pylint import constants
|
22 | 23 | from pylint.checkers.utils import safe_infer
|
@@ -426,15 +427,64 @@ def handle(self, node: nodes.AssignAttr, parent: nodes.ClassDef) -> None:
|
426 | 427 | # Type annotation only (x: P) -> Association
|
427 | 428 | # BUT only if there's no actual assignment (to avoid duplicates)
|
428 | 429 | if isinstance(node.parent, nodes.AnnAssign) and node.parent.value is None:
|
| 430 | + inferred_types = utils.infer_node(node) |
| 431 | + element_types = extract_element_types(inferred_types) |
| 432 | + |
| 433 | + # Resolve nodes to actual class definitions |
| 434 | + resolved_types = resolve_to_class_def(element_types) |
| 435 | + |
429 | 436 | current = set(parent.associations_type[node.attrname])
|
430 |
| - parent.associations_type[node.attrname] = list( |
431 |
| - current | utils.infer_node(node) |
432 |
| - ) |
| 437 | + parent.associations_type[node.attrname] = list(current | resolved_types) |
433 | 438 | return
|
434 | 439 |
|
435 | 440 | # Everything else is also association (fallback)
|
436 | 441 | current = set(parent.associations_type[node.attrname])
|
437 |
| - parent.associations_type[node.attrname] = list(current | utils.infer_node(node)) |
| 442 | + inferred_types = utils.infer_node(node) |
| 443 | + element_types = extract_element_types(inferred_types) |
| 444 | + |
| 445 | + # Resolve Name nodes to actual class definitions |
| 446 | + resolved_types = resolve_to_class_def(element_types) |
| 447 | + parent.associations_type[node.attrname] = list(current | resolved_types) |
| 448 | + |
| 449 | + |
| 450 | +def resolve_to_class_def(types: set[nodes.NodeNG]) -> set[nodes.ClassDef]: |
| 451 | + """Resolve a set of nodes to ClassDef nodes.""" |
| 452 | + class_defs = set() |
| 453 | + for node in types: |
| 454 | + if isinstance(node, nodes.ClassDef): |
| 455 | + class_defs.add(node) |
| 456 | + elif isinstance(node, nodes.Name): |
| 457 | + inferred = safe_infer(node) |
| 458 | + if isinstance(inferred, nodes.ClassDef): |
| 459 | + class_defs.add(inferred) |
| 460 | + return class_defs |
| 461 | + |
| 462 | + |
| 463 | +def extract_element_types(inferred_types: set[InferenceResult]) -> set[nodes.NodeNG]: |
| 464 | + """Extract element types in case the inferred type is a container. |
| 465 | +
|
| 466 | + This function checks if the inferred type is a container type (like list, dict, etc.) |
| 467 | + and extracts the element type(s) from it. If the inferred type is a direct type (like a class), |
| 468 | + it adds that type directly to the set of element types it returns. |
| 469 | + """ |
| 470 | + element_types = set() |
| 471 | + |
| 472 | + for inferred_type in inferred_types: |
| 473 | + if isinstance(inferred_type, nodes.Subscript): |
| 474 | + slice_node = inferred_type.slice |
| 475 | + |
| 476 | + # Handle both Tuple (dict[K,V]) and single element (list[T]) |
| 477 | + elements = ( |
| 478 | + slice_node.elts if isinstance(slice_node, nodes.Tuple) else [slice_node] |
| 479 | + ) |
| 480 | + |
| 481 | + for elt in elements: |
| 482 | + if isinstance(elt, (nodes.Name, nodes.ClassDef)): |
| 483 | + element_types.add(elt) |
| 484 | + else: |
| 485 | + element_types.add(inferred_type) |
| 486 | + |
| 487 | + return element_types |
438 | 488 |
|
439 | 489 |
|
440 | 490 | def project_from_files(
|
|
0 commit comments