@@ -71,28 +71,28 @@ def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topi
7171
7272 return all_dialogues
7373
74- async def ainvoke (self , * args , ** kwargs ):
74+ async def ainvoke (self , * args , ** kwargs ):
7575 return self .invoke (* args , ** kwargs )
7676
7777
7878@AlgorithmRegistry .register (input_type = BaseGraph , output_type = Dialogue )
7979class DialoguePathSampler (DialogueGenerator ):
8080 def invoke (self , graph : BaseGraph , start_node : int = 1 , end_node : int = - 1 , topic = "" ) -> list [Dialogue ]:
8181 nx_graph = graph .graph
82-
82+
8383 # Find all nodes with no outgoing edges (end nodes)
8484 end_nodes = [node for node in nx_graph .nodes () if nx_graph .out_degree (node ) == 0 ]
8585 dialogues = []
8686 # If no end nodes found, return empty list
8787 if not end_nodes :
8888 return []
89-
89+
9090 all_paths = []
9191 # Get paths from start node to each end node
9292 for end in end_nodes :
9393 paths = list (nx .all_simple_paths (nx_graph , source = start_node , target = end ))
9494 all_paths .extend (paths )
95-
95+
9696 for path in all_paths :
9797 dialogue_turns = []
9898 # Process each node and edge in the path
@@ -101,59 +101,69 @@ def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topi
101101 current_node = path [i ]
102102 assistant_utterance = random .choice (nx_graph .nodes [current_node ]["utterances" ])
103103 dialogue_turns .append ({"text" : assistant_utterance , "participant" : "assistant" })
104-
104+
105105 # Add user utterance from edge (if not at last node)
106106 if i < len (path ) - 1 :
107107 next_node = path [i + 1 ]
108108 edge_data = nx_graph .edges [current_node , next_node ]
109- user_utterance = (
110- random .choice (edge_data ["utterances" ])
111- if isinstance (edge_data ["utterances" ], list )
112- else edge_data ["utterances" ]
113- )
109+ user_utterance = random .choice (edge_data ["utterances" ]) if isinstance (edge_data ["utterances" ], list ) else edge_data ["utterances" ]
114110 dialogue_turns .append ({"text" : user_utterance , "participant" : "user" })
115-
111+
116112 dialogues .append (Dialogue ().from_list (dialogue_turns ))
117-
113+
118114 return dialogues
119-
115+
120116 async def ainvoke (self , * args , ** kwargs ):
121117 return self .invoke (* args , ** kwargs )
122-
118+
123119
124120@AlgorithmRegistry .register (input_type = BaseGraph , output_type = Dialogue )
125121class RecursiveDialogueSampler (DialogueGenerator ):
126122 def _list_in (self , a : list , b : list ) -> bool :
127123 """Check if sequence a exists within sequence b."""
128- return any (map (lambda x : b [x :x + len (a )] == a , range (len (b ) - len (a ) + 1 )))
129-
130-
124+ return any (map (lambda x : b [x : x + len (a )] == a , range (len (b ) - len (a ) + 1 )))
131125
132126 def invoke (self , graph : BaseGraph , start_node : int = 1 , end_node : int = - 1 , topic = "" ) -> list [Dialogue ]:
133127 starts = [n for n in graph .graph_dict .get ("nodes" ) if n ["is_start" ]]
134128 visitedList = [[]]
129+
135130 def all_paths (graph , start : int , visited : list ):
136131 # print("start: ", start, len(visitedList))
137- if len (visited ) < 2 or not self ._list_in (visited [- 2 :]+ [start ],visited ):
132+ if len (visited ) < 2 or not self ._list_in (visited [- 2 :] + [start ], visited ):
138133 visited .append (start )
139134 # print("visited:", visited)
140135 for edge in graph .edge_by_source (start ):
141136
142- # if [start,edge['target']] not in visited:
143- all_paths (graph , edge [' target' ], visited .copy ())
137+ # if [start,edge['target']] not in visited:
138+ all_paths (graph , edge [" target" ], visited .copy ())
144139 visitedList .append (visited )
145140
146- all_paths (graph , starts [0 ]['id' ], [])
141+ all_paths (graph , starts [0 ]["id" ], [])
147142 visitedList .sort ()
148- final = list (k for k ,_ in itertools .groupby (visitedList ))[1 :]
149-
150- dialogues = []
151- for nodes in final :
152- dialogues .append (Dialogue ().from_nodes_ids (graph = graph , node_list = nodes ))
143+ final = list (k for k , _ in itertools .groupby (visitedList ))[1 :]
144+ sources = list (set ([g ["source" ] for g in graph .graph_dict ["edges" ]]))
145+ ends = [g ["id" ] for g in graph .graph_dict ["nodes" ] if g ["id" ] not in sources ]
146+ node_paths = [f for f in final if f [- 1 ] in ends ]
147+ full_paths = []
148+ for p in node_paths :
149+ # print(p)
150+ path = []
151+ for idx , s in enumerate (p [:- 1 ]):
152+ path .append ({"participant" : "assistant" , "text" : graph .node_by_id (s )["utterances" ][0 ]})
153+ # path.append({"user": list(set(gr.edge_by_source(s)) & set(gr.edge_by_target(p[idx+1])))[0]['utterances']})
154+ sources = graph .edge_by_source (s )
155+ targets = graph .edge_by_target (p [idx + 1 ])
156+ # print("SOURCES: ", sources, s)
157+ # print("TARGETS: ", targets, p[idx+1])
158+ # targets = set([(e['source'],e['target']) for e in gr.edge_by_target(p[idx+1])])
159+ edge = [e for e in sources if e in targets ][0 ]
160+ path .append (({"participant" : "user" , "text" : edge ["utterances" ][0 ]}))
161+ path .append ({"participant" : "assistant" , "text" : graph .node_by_id (p [- 1 ])["utterances" ][0 ]})
162+ full_paths .append (path )
163+
164+ dialogues = [Dialogue ().from_list (i ) for i in full_paths ]
153165
154166 return dialogues
155167
156168 async def ainvoke (self , * args , ** kwargs ):
157169 return self .invoke (* args , ** kwargs )
158-
159-
0 commit comments