77from __future__ import annotations
88
99import argparse
10- from collections .abc import Generator
10+ import warnings
11+ from collections .abc import Generator , Sequence
1112from typing import Any
1213
1314import astroid
@@ -27,10 +28,28 @@ class DiaDefGenerator:
2728 def __init__ (self , linker : Linker , handler : DiadefsHandler ) -> None :
2829 """Common Diagram Handler initialization."""
2930 self .config = handler .config
31+ self .args = handler .args
3032 self .module_names : bool = False
3133 self ._set_default_options ()
3234 self .linker = linker
3335 self .classdiagram : ClassDiagram # defined by subclasses
36+ # Only pre-calculate depths if user has requested a max_depth
37+ if handler .config .max_depth is not None :
38+ # Detect which of the args are leaf nodes
39+ leaf_nodes = self .get_leaf_nodes ()
40+
41+ # Emit a warning if any of the args are not leaf nodes
42+ diff = set (self .args ).difference (set (leaf_nodes ))
43+ if len (diff ) > 0 :
44+ warnings .warn (
45+ "Detected nested names within the specified packages. "
46+ f"The following packages: { sorted (diff )} will be ignored for "
47+ f"depth calculations, using only: { sorted (leaf_nodes )} as the base for limiting "
48+ "package depth." ,
49+ stacklevel = 2 ,
50+ )
51+
52+ self .args_depths = {module : module .count ("." ) for module in leaf_nodes }
3453
3554 def get_title (self , node : nodes .ClassDef ) -> str :
3655 """Get title for objects."""
@@ -39,6 +58,22 @@ def get_title(self, node: nodes.ClassDef) -> str:
3958 title = f"{ node .root ().name } .{ title } "
4059 return title # type: ignore[no-any-return]
4160
61+ def get_leaf_nodes (self ) -> list [str ]:
62+ """
63+ Get the leaf nodes from the list of args in the generator.
64+
65+ A leaf node is one that is not a prefix (with an extra dot) of any other node.
66+ """
67+ leaf_nodes = [
68+ module
69+ for module in self .args
70+ if not any (
71+ other != module and other .startswith (module + "." )
72+ for other in self .args
73+ )
74+ ]
75+ return leaf_nodes
76+
4277 def _set_option (self , option : bool | None ) -> bool :
4378 """Activate some options if not explicitly deactivated."""
4479 # if we have a class diagram, we want more information by default;
@@ -67,6 +102,30 @@ def _get_levels(self) -> tuple[int, int]:
67102 """Help function for search levels."""
68103 return self .anc_level , self .association_level
69104
105+ def _should_include_by_depth (self , node : nodes .NodeNG ) -> bool :
106+ """Check if a node should be included based on depth.
107+
108+ A node will be included if it is at or below the max_depth relative to the
109+ specified base packages. A node is considered to be a base package if it is the
110+ deepest package in the list of specified packages. In other words the base nodes
111+ are the leaf nodes of the specified package tree.
112+ """
113+ # If max_depth is not set, include all nodes
114+ if self .config .max_depth is None :
115+ return True
116+
117+ # Calculate the absolute depth of the node
118+ name = node .root ().name
119+ absolute_depth = name .count ("." )
120+
121+ # Retrieve the base depth to compare against
122+ relative_depth = next (
123+ (v for k , v in self .args_depths .items () if name .startswith (k )), None
124+ )
125+ return relative_depth is not None and bool (
126+ (absolute_depth - relative_depth ) <= self .config .max_depth
127+ )
128+
70129 def show_node (self , node : nodes .ClassDef ) -> bool :
71130 """Determine if node should be shown based on config."""
72131 if node .root ().name == "builtins" :
@@ -75,7 +134,8 @@ def show_node(self, node: nodes.ClassDef) -> bool:
75134 if is_stdlib_module (node .root ().name ):
76135 return self .config .show_stdlib # type: ignore[no-any-return]
77136
78- return True
137+ # Filter node by depth
138+ return self ._should_include_by_depth (node )
79139
80140 def add_class (self , node : nodes .ClassDef ) -> None :
81141 """Visit one class and add it to diagram."""
@@ -163,7 +223,7 @@ def visit_module(self, node: nodes.Module) -> None:
163223
164224 add this class to the package diagram definition
165225 """
166- if self .pkgdiagram :
226+ if self .pkgdiagram and self . _should_include_by_depth ( node ) :
167227 self .linker .visit (node )
168228 self .pkgdiagram .add_object (node .name , node )
169229
@@ -177,7 +237,7 @@ def visit_classdef(self, node: nodes.ClassDef) -> None:
177237
178238 def visit_importfrom (self , node : nodes .ImportFrom ) -> None :
179239 """Visit astroid.ImportFrom and catch modules for package diagram."""
180- if self .pkgdiagram :
240+ if self .pkgdiagram and self . _should_include_by_depth ( node ) :
181241 self .pkgdiagram .add_from_depend (node , node .modname )
182242
183243
@@ -208,8 +268,9 @@ def class_diagram(self, project: Project, klass: nodes.ClassDef) -> ClassDiagram
208268class DiadefsHandler :
209269 """Get diagram definitions from user (i.e. xml files) or generate them."""
210270
211- def __init__ (self , config : argparse .Namespace ) -> None :
271+ def __init__ (self , config : argparse .Namespace , args : Sequence [ str ] ) -> None :
212272 self .config = config
273+ self .args = args
213274
214275 def get_diadefs (self , project : Project , linker : Linker ) -> list [ClassDiagram ]:
215276 """Get the diagram's configuration data.
0 commit comments