Skip to content

Commit 68f412b

Browse files
Add docstrings and remove unused modules (#51)
* working on docstrings and removing unused files
* add docstrings for Graph
* remove Dataset class, format
* Refactor: move find_cycle_ends function to helpers
1 parent 256ac1f commit 68f412b

File tree

24 files changed

+246
-418
lines changed

24 files changed

+246
-418
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from dialogue2graph.datasets.core import Dataset
1+
from dialogue2graph.datasets.complex_dialogues.generation import CycleGraphGenerator
22

3-
__all__ = ["Dataset"]
3+
__all__ = ["CycleGraphGenerator"]

dialogue2graph/datasets/core/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

dialogue2graph/datasets/core/dataset.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

dialogue2graph/metrics/no_llm_metrics/metrics.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,18 @@ def match_graph_triplets(G1: BaseGraph, G2: BaseGraph, change_to_original_ids=Fa
317317

318318

319319
def is_same_structure(G1: BaseGraph, G2: BaseGraph) -> bool:
320+
"""
321+
Check whether the underlying NetworkX graphs of G1 and G2 are isomorphic.
322+
323+
Args:
324+
G1: BaseGraph object containing the dialogue graph
325+
G2: BaseGraph object containing the dialogue graph
326+
"""
320327
g1 = G1.graph
321328
g2 = G2.graph
322329
return nx.is_isomorphic(g1, g2)
323330

324331

325-
def all_paths_sampled(G: BaseGraph, dialogue: Dialogue) -> bool:
326-
return True
327-
328-
329332
def _get_dialogue_triplets(seq: list[Dialogue]) -> set[tuple[str]]:
330333
"""Find all dialogue triplets with (source, edge, target) utterances"""
331334
result = []

dialogue2graph/pipelines/core/dialogue_sampling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
match_dg_triplets,
1010
match_dialogue_triplets,
1111
)
12-
from dialogue2graph.datasets.complex_dialogues.find_cycle_ends import find_cycle_ends
12+
from dialogue2graph.pipelines.helpers.find_cycle_ends import find_cycle_ends
1313
from langchain_core.language_models.chat_models import BaseChatModel
1414

1515
logging.basicConfig(level=logging.INFO)

dialogue2graph/pipelines/core/graph.py

Lines changed: 129 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,17 @@
99

1010

1111
class BaseGraph(BaseModel, abc.ABC):
12+
"""Base abstract class for graph representations of dialogues.
13+
14+
This class provides the interface for graph operations and manipulations.
15+
It inherits from both BaseModel for data validation and ABC for abstract methods.
16+
17+
Attributes:
18+
graph_dict (dict): Dictionary containing the graph structure with nodes and edges.
19+
graph (Optional[nx.Graph]): NetworkX graph instance.
20+
node_mapping (Optional[dict]): Mapping between original node IDs and internal representation.
21+
"""
22+
1223
graph_dict: dict
1324
graph: Optional[nx.Graph] = None
1425
node_mapping: Optional[dict] = None
@@ -77,8 +88,22 @@ def get_list_from_graph(self):
7788

7889

7990
class Graph(BaseGraph):
91+
"""Implementation of BaseGraph for dialogue graph operations.
92+
93+
This class provides concrete implementations for graph operations including
94+
loading, visualization, path finding, and graph manipulation methods.
95+
96+
Attributes:
97+
Inherits all attributes from BaseGraph.
98+
"""
99+
80100
def __init__(self, graph_dict: dict, **kwargs: Any):
81-
# Pass graph_dict to the parent class
101+
"""Initialize the Graph instance.
102+
103+
Args:
104+
graph_dict (dict): Dictionary containing the graph structure.
105+
**kwargs: Additional keyword arguments passed to parent class.
106+
"""
82107
super().__init__(graph_dict=graph_dict, **kwargs)
83108
if graph_dict:
84109
self.load_graph()
@@ -101,6 +126,11 @@ def check_edges(self, seq: list[list[int]]) -> bool:
101126
return seen == edge_set
102127

103128
def load_graph(self):
129+
"""Load graph from dictionary representation into NetworkX DiGraph.
130+
131+
Creates a directed graph from the graph_dict, handling node and edge attributes.
132+
Also creates node mapping if node IDs need renumbering.
133+
"""
104134
self.graph = nx.DiGraph()
105135
nodes = sorted([v["id"] for v in self.graph_dict["nodes"]])
106136
logging.debug(f"Nodes: {nodes}")
@@ -144,6 +174,11 @@ def load_graph(self):
144174
)
145175

146176
def visualise(self, *args, **kwargs):
177+
"""Visualize the graph using matplotlib and networkx.
178+
179+
Creates a visualization of the graph with nodes and edges labeled with utterances.
180+
Uses pygraphviz layout if available, falls back to kamada_kawai_layout.
181+
"""
147182
plt.figure(figsize=(17, 11)) # Make the plot bigger
148183
try:
149184
pos = nx.nx_agraph.pygraphviz_layout(self.graph)
@@ -173,6 +208,15 @@ def visualise(self, *args, **kwargs):
173208
plt.show()
174209

175210
def visualise_short(self, name, *args, **kwargs):
211+
"""Create a compact visualization of the graph.
212+
213+
Args:
214+
name (str): Title for the visualization.
215+
*args: Variable length argument list.
216+
**kwargs: Arbitrary keyword arguments.
217+
218+
Creates a simplified visualization showing only node IDs and utterance counts.
219+
"""
176220
try:
177221
pos = nx.nx_agraph.pygraphviz_layout(self.graph)
178222
except ImportError as e:
@@ -211,29 +255,71 @@ def visualise_short(self, name, *args, **kwargs):
211255
plt.show()
212256

213257
def find_nodes_by_utterance(self, utterance: str) -> list[dict]:
258+
"""Find nodes containing a specific utterance.
259+
260+
Args:
261+
utterance (str): The utterance to search for.
262+
263+
Returns:
264+
list[dict]: List of nodes containing the utterance.
265+
"""
214266
return [
215267
node for node in self.graph_dict["nodes"] if utterance in node["utterances"]
216268
]
217269

218270
def find_edges_by_utterance(self, utterance: str) -> list[dict]:
271+
"""Find edges containing a specific utterance.
272+
273+
Args:
274+
utterance (str): The utterance to search for.
275+
276+
Returns:
277+
list[dict]: List of edges containing the utterance.
278+
"""
219279
return [
220280
edge for edge in self.graph_dict["edges"] if utterance in edge["utterances"]
221281
]
222282

223283
def get_nodes_by_id(self, id: int):
284+
"""Retrieve a node by its ID.
285+
286+
Args:
287+
id (int): The ID of the node to retrieve.
288+
289+
Returns:
290+
dict: The node with the specified ID if found, None otherwise.
291+
"""
224292
for node in self.graph_dict["nodes"]:
225293
if node["id"] == id:
226294
return node
227295

228296
def get_edges_by_source(self, id: int):
297+
"""Get all edges originating from a specific node.
298+
299+
Args:
300+
id (int): The ID of the source node.
301+
302+
Returns:
303+
list[dict]: List of edges with the specified source node.
304+
"""
229305
return [edge for edge in self.graph_dict["edges"] if edge["source"] == id]
230306

231307
def get_edges_by_target(self, id: int):
308+
"""Get all edges targeting a specific node.
309+
310+
Args:
311+
id (int): The ID of the target node.
312+
313+
Returns:
314+
list[dict]: List of edges with the specified target node.
315+
"""
232316
return [edge for edge in self.graph_dict["edges"] if edge["target"] == id]
233317

234318
def match_edges_nodes(self) -> bool:
235-
"""Checks whether source and target
236-
of all the edges correspond to nodes
319+
"""Verify that all edge endpoints correspond to existing nodes.
320+
321+
Returns:
322+
bool: True if all edge endpoints match existing nodes, False otherwise.
237323
"""
238324
graph = self.graph_dict
239325

@@ -248,6 +334,13 @@ def match_edges_nodes(self) -> bool:
248334
return nodes_set == edges_set
249335

250336
def remove_duplicated_edges(self) -> BaseGraph:
337+
"""Remove duplicate edges between the same node pairs.
338+
339+
Combines utterances from duplicate edges into a single edge.
340+
341+
Returns:
342+
BaseGraph: New graph instance with duplicate edges removed.
343+
"""
251344
graph = self.graph_dict
252345
edges = graph["edges"]
253346
node_couples = [(e["source"], e["target"]) for e in edges]
@@ -269,6 +362,12 @@ def remove_duplicated_edges(self) -> BaseGraph:
269362
return Graph(self.graph_dict)
270363

271364
def remove_duplicated_nodes(self) -> BaseGraph | None:
365+
"""Remove duplicate nodes based on their utterances.
366+
367+
Returns:
368+
BaseGraph | None: New graph instance with duplicate nodes removed,
369+
or None if invalid state is detected.
370+
"""
272371
graph = self.graph_dict
273372
nodes = graph["nodes"].copy()
274373
edges = graph["edges"].copy()
@@ -301,18 +400,16 @@ def remove_duplicated_nodes(self) -> BaseGraph | None:
301400
def get_all_paths(
302401
self, start_node_id: int, visited_nodes: list[int], repeats_limit: int
303402
) -> list[list[int]]:
304-
"""Recursion to find all the graph paths consisting of nodes ids
305-
which start from node with id=start_node_id
306-
and do not repeat last repeats_limit elements of the visited_nodes
403+
"""Find all possible paths in the graph from a starting node.
307404
308405
Args:
309-
visited_nodes: a path traveled so far
310-
repeats_limit: recursion stopper with maximum length
311-
of finishing sequence not to repeat on the path
406+
start_node_id (int): ID of the starting node.
407+
visited_nodes (list[int]): List of nodes already visited in the current path.
408+
repeats_limit (int): Length of the trailing sequence of visited nodes that must not repeat on the path; a repeat stops the recursion.
312409
313-
Returns: list of found paths
410+
Returns:
411+
list[list[int]]: List of all valid paths found.
314412
"""
315-
316413
if len(visited_nodes) >= repeats_limit and self._is_seq_in(
317414
visited_nodes[-repeats_limit:] + [start_node_id], visited_nodes
318415
):
@@ -332,11 +429,15 @@ def get_all_paths(
332429
def find_paths(
333430
self, start_node_id: int, end_node_id: int, visited_nodes: list[int]
334431
) -> list[list[int]]:
335-
"""Recursion to find paths from start_node_id
336-
where end_node_id on the path stops recursion
432+
"""Find all paths between two nodes in the graph.
433+
337434
Args:
338-
visited_nodes: a path traveled so far
339-
Returns: list of all paths from start_node_id which probably could be finishing by end_node_id
435+
start_node_id (int): ID of the starting node.
436+
end_node_id (int): ID of the target node.
437+
visited_nodes (list[int]): List of nodes already visited.
438+
439+
Returns:
440+
list[list[int]]: Paths traversed from start_node_id; reaching end_node_id stops the recursion, so a returned path may, but is not guaranteed to, end at end_node_id.
340441
"""
341442
visited_paths = [[]]
342443

@@ -357,11 +458,13 @@ def find_paths(
357458
return visited_paths
358459

359460
def get_ends(self) -> list[int]:
360-
"""Find finishing nodes which have no outgoing edges
461+
"""Find all terminal nodes in the graph.
462+
463+
Terminal nodes are those with no outgoing edges.
464+
361465
Returns:
362-
list of finishing nodes ids
466+
list[int]: List of IDs of terminal nodes.
363467
"""
364-
365468
graph = self.graph_dict
366469
sources = list(set([g["source"] for g in graph["edges"]]))
367470
finishes = [g["id"] for g in graph["nodes"] if g["id"] not in sources]
@@ -379,9 +482,10 @@ def get_ends(self) -> list[int]:
379482
return finishes
380483

381484
def get_list_from_nodes(self) -> list[str]:
382-
"""Method to form auxiliary list from the graph nodes
485+
"""Create a list of concatenated utterances from all nodes.
486+
383487
Returns:
384-
list of concatenations of all nodes utterances
488+
list[str]: List where each element is the concatenated utterances of a node.
385489
"""
386490
graph = self.graph_dict
387491
result = []
@@ -395,10 +499,12 @@ def get_list_from_nodes(self) -> list[str]:
395499
return result
396500

397501
def get_list_from_graph(self) -> tuple[list[str], int]:
398-
"""Method to form auxiliary data from the graph
502+
"""Create a list of concatenated utterances from nodes and their edges.
503+
399504
Returns:
400-
res_list: concatenation of utterances of every node and its outgoing edges
401-
n_edges: total number of utterances in all edges
505+
tuple[list[str], int]: Tuple containing:
506+
- list with, for every node, the concatenated utterances of the node and its outgoing edges
507+
- total number of utterances in edges
402508
"""
403509
graph = self.graph_dict
404510
res_list = []

dialogue2graph/pipelines/core/pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from typing import Union
2+
from typing import Union, Tuple, Any
33
from pydantic import BaseModel, Field
44
from dialogue2graph.pipelines.core.algorithms import (
55
DialogAugmentation,
@@ -25,7 +25,9 @@ class BasePipeline(BaseModel):
2525
def _validate_pipeline(self):
2626
pass
2727

28-
def invoke(self, raw_data: PipelineRawDataType, enable_evals=False):
28+
def invoke(
29+
self, raw_data: PipelineRawDataType, enable_evals=False
30+
) -> Tuple[Any, PipelineReport]:
2931
data: PipelineDataType = RawDGParser().invoke(raw_data)
3032
report = PipelineReport(service=self.name)
3133
st_time = time.time()

dialogue2graph/pipelines/cycled_graphs/dialogue.py

Whitespace-only changes.

dialogue2graph/pipelines/cycled_graphs/graph.py

Whitespace-only changes.

dialogue2graph/pipelines/cycled_graphs/pipeline.py

Whitespace-only changes.

0 commit comments

Comments
 (0)