99
1010
1111class BaseGraph (BaseModel , abc .ABC ):
12+ """Base abstract class for graph representations of dialogues.
13+
14+ This class provides the interface for graph operations and manipulations.
15+ It inherits from both BaseModel for data validation and ABC for abstract methods.
16+
17+ Attributes:
18+ graph_dict (dict): Dictionary containing the graph structure with nodes and edges.
19+ graph (Optional[nx.Graph]): NetworkX graph instance.
20+ node_mapping (Optional[dict]): Mapping between original node IDs and internal representation.
21+ """
22+
1223 graph_dict : dict
1324 graph : Optional [nx .Graph ] = None
1425 node_mapping : Optional [dict ] = None
@@ -77,8 +88,22 @@ def get_list_from_graph(self):
7788
7889
7990class Graph (BaseGraph ):
91+ """Implementation of BaseGraph for dialogue graph operations.
92+
93+ This class provides concrete implementations for graph operations including
94+ loading, visualization, path finding, and graph manipulation methods.
95+
96+ Attributes:
97+ Inherits all attributes from BaseGraph.
98+ """
99+
80100 def __init__ (self , graph_dict : dict , ** kwargs : Any ):
81- # Pass graph_dict to the parent class
101+ """Initialize the Graph instance.
102+
103+ Args:
104+ graph_dict (dict): Dictionary containing the graph structure.
105+ **kwargs: Additional keyword arguments passed to parent class.
106+ """
82107 super ().__init__ (graph_dict = graph_dict , ** kwargs )
83108 if graph_dict :
84109 self .load_graph ()
@@ -101,6 +126,11 @@ def check_edges(self, seq: list[list[int]]) -> bool:
101126 return seen == edge_set
102127
103128 def load_graph (self ):
129+ """Load graph from dictionary representation into NetworkX DiGraph.
130+
131+ Creates a directed graph from the graph_dict, handling node and edge attributes.
132+ Also creates node mapping if node IDs need renumbering.
133+ """
104134 self .graph = nx .DiGraph ()
105135 nodes = sorted ([v ["id" ] for v in self .graph_dict ["nodes" ]])
106136 logging .debug (f"Nodes: { nodes } " )
@@ -144,6 +174,11 @@ def load_graph(self):
144174 )
145175
146176 def visualise (self , * args , ** kwargs ):
177+ """Visualize the graph using matplotlib and networkx.
178+
179+ Creates a visualization of the graph with nodes and edges labeled with utterances.
180+ Uses pygraphviz layout if available, falls back to kamada_kawai_layout.
181+ """
147182 plt .figure (figsize = (17 , 11 )) # Make the plot bigger
148183 try :
149184 pos = nx .nx_agraph .pygraphviz_layout (self .graph )
@@ -173,6 +208,15 @@ def visualise(self, *args, **kwargs):
173208 plt .show ()
174209
175210 def visualise_short (self , name , * args , ** kwargs ):
211+ """Create a compact visualization of the graph.
212+
213+ Args:
214+ name (str): Title for the visualization.
215+ *args: Variable length argument list.
216+ **kwargs: Arbitrary keyword arguments.
217+
218+ Creates a simplified visualization showing only node IDs and utterance counts.
219+ """
176220 try :
177221 pos = nx .nx_agraph .pygraphviz_layout (self .graph )
178222 except ImportError as e :
@@ -211,29 +255,71 @@ def visualise_short(self, name, *args, **kwargs):
211255 plt .show ()
212256
213257 def find_nodes_by_utterance (self , utterance : str ) -> list [dict ]:
258+ """Find nodes containing a specific utterance.
259+
260+ Args:
261+ utterance (str): The utterance to search for.
262+
263+ Returns:
264+ list[dict]: List of nodes containing the utterance.
265+ """
214266 return [
215267 node for node in self .graph_dict ["nodes" ] if utterance in node ["utterances" ]
216268 ]
217269
218270 def find_edges_by_utterance (self , utterance : str ) -> list [dict ]:
271+ """Find edges containing a specific utterance.
272+
273+ Args:
274+ utterance (str): The utterance to search for.
275+
276+ Returns:
277+ list[dict]: List of edges containing the utterance.
278+ """
219279 return [
220280 edge for edge in self .graph_dict ["edges" ] if utterance in edge ["utterances" ]
221281 ]
222282
223283 def get_nodes_by_id (self , id : int ):
284+ """Retrieve a node by its ID.
285+
286+ Args:
287+ id (int): The ID of the node to retrieve.
288+
289+ Returns:
290+ dict: The node with the specified ID if found, None otherwise.
291+ """
224292 for node in self .graph_dict ["nodes" ]:
225293 if node ["id" ] == id :
226294 return node
227295
228296 def get_edges_by_source (self , id : int ):
297+ """Get all edges originating from a specific node.
298+
299+ Args:
300+ id (int): The ID of the source node.
301+
302+ Returns:
303+ list[dict]: List of edges with the specified source node.
304+ """
229305 return [edge for edge in self .graph_dict ["edges" ] if edge ["source" ] == id ]
230306
231307 def get_edges_by_target (self , id : int ):
308+ """Get all edges targeting a specific node.
309+
310+ Args:
311+ id (int): The ID of the target node.
312+
313+ Returns:
314+ list[dict]: List of edges with the specified target node.
315+ """
232316 return [edge for edge in self .graph_dict ["edges" ] if edge ["target" ] == id ]
233317
234318 def match_edges_nodes (self ) -> bool :
235- """Checks whether source and target
236- of all the edges correspond to nodes
319+ """Verify that all edge endpoints correspond to existing nodes.
320+
321+ Returns:
322+ bool: True if all edge endpoints match existing nodes, False otherwise.
237323 """
238324 graph = self .graph_dict
239325
@@ -248,6 +334,13 @@ def match_edges_nodes(self) -> bool:
248334 return nodes_set == edges_set
249335
250336 def remove_duplicated_edges (self ) -> BaseGraph :
337+ """Remove duplicate edges between the same node pairs.
338+
339+ Combines utterances from duplicate edges into a single edge.
340+
341+ Returns:
342+ BaseGraph: New graph instance with duplicate edges removed.
343+ """
251344 graph = self .graph_dict
252345 edges = graph ["edges" ]
253346 node_couples = [(e ["source" ], e ["target" ]) for e in edges ]
@@ -269,6 +362,12 @@ def remove_duplicated_edges(self) -> BaseGraph:
269362 return Graph (self .graph_dict )
270363
271364 def remove_duplicated_nodes (self ) -> BaseGraph | None :
365+ """Remove duplicate nodes based on their utterances.
366+
367+ Returns:
368+ BaseGraph | None: New graph instance with duplicate nodes removed,
369+ or None if invalid state is detected.
370+ """
272371 graph = self .graph_dict
273372 nodes = graph ["nodes" ].copy ()
274373 edges = graph ["edges" ].copy ()
@@ -301,18 +400,16 @@ def remove_duplicated_nodes(self) -> BaseGraph | None:
301400 def get_all_paths (
302401 self , start_node_id : int , visited_nodes : list [int ], repeats_limit : int
303402 ) -> list [list [int ]]:
304- """Recursion to find all the graph paths consisting of nodes ids
305- which start from node with id=start_node_id
306- and do not repeat last repeats_limit elements of the visited_nodes
403+ """Find all possible paths in the graph from a starting node.
307404
308405 Args:
309- visited_nodes: a path traveled so far
310- repeats_limit: recursion stopper with maximum length
311- of finishing sequence not to repeat on the path
406+ start_node_id (int): ID of the starting node.
407+ visited_nodes (list[int]): List of nodes already visited in the current path.
408+ repeats_limit (int): Maximum number of times a sequence can repeat.
312409
313- Returns: list of found paths
410+ Returns:
411+ list[list[int]]: List of all valid paths found.
314412 """
315-
316413 if len (visited_nodes ) >= repeats_limit and self ._is_seq_in (
317414 visited_nodes [- repeats_limit :] + [start_node_id ], visited_nodes
318415 ):
@@ -332,11 +429,15 @@ def get_all_paths(
332429 def find_paths (
333430 self , start_node_id : int , end_node_id : int , visited_nodes : list [int ]
334431 ) -> list [list [int ]]:
335- """Recursion to find paths from start_node_id
336- where end_node_id on the path stops recursion
432+ """Find all paths between two nodes in the graph.
433+
337434 Args:
338- visited_nodes: a path traveled so far
339- Returns: list of all paths from start_node_id which probably could be finishing by end_node_id
435+ start_node_id (int): ID of the starting node.
436+ end_node_id (int): ID of the target node.
437+ visited_nodes (list[int]): List of nodes already visited.
438+
439+ Returns:
440+ list[list[int]]: List of all paths found between start and end nodes.
340441 """
341442 visited_paths = [[]]
342443
@@ -357,11 +458,13 @@ def find_paths(
357458 return visited_paths
358459
359460 def get_ends (self ) -> list [int ]:
360- """Find finishing nodes which have no outgoing edges
461+ """Find all terminal nodes in the graph.
462+
463+ Terminal nodes are those with no outgoing edges.
464+
361465 Returns:
362- list of finishing nodes ids
466+ list[int]: List of IDs of terminal nodes.
363467 """
364-
365468 graph = self .graph_dict
366469 sources = list (set ([g ["source" ] for g in graph ["edges" ]]))
367470 finishes = [g ["id" ] for g in graph ["nodes" ] if g ["id" ] not in sources ]
@@ -379,9 +482,10 @@ def get_ends(self) -> list[int]:
379482 return finishes
380483
381484 def get_list_from_nodes (self ) -> list [str ]:
382- """Method to form auxiliary list from the graph nodes
485+ """Create a list of concatenated utterances from all nodes.
486+
383487 Returns:
384- list of concatenations of all nodes utterances
488+ list[str]: List where each element is the concatenated utterances of a node.
385489 """
386490 graph = self .graph_dict
387491 result = []
@@ -395,10 +499,12 @@ def get_list_from_nodes(self) -> list[str]:
395499 return result
396500
397501 def get_list_from_graph (self ) -> tuple [list [str ], int ]:
398- """Method to form auxiliary data from the graph
502+ """Create a list of concatenated utterances from nodes and their edges.
503+
399504 Returns:
400- res_list: concatenation of utterances of every node and its outgoing edges
401- n_edges: total number of utterances in all edges
505+ tuple[list[str], int]: Tuple containing:
506+ - list of concatenated utterances
507+ - total number of utterances in edges
402508 """
403509 graph = self .graph_dict
404510 res_list = []
0 commit comments