1414
1515env_settings = EnvSettings()
1616
17- def list_in(a, b):
18- return any(map(lambda x: b[x:x + len(a)] == a, range(len(b) - len(a) + 1)))
19-
def len_in(a, b):
    """Count how many times sequence ``a`` occurs as a contiguous run in ``b``.

    Every start index of ``b`` is tested, so overlapping occurrences are
    counted separately (``[1, 1]`` occurs twice in ``[1, 1, 1]``).

    Args:
        a: The sub-sequence to look for. It is compared with ``==`` against
            slices of ``b``, so it should be the same sequence type as ``b``.
        b: The sequence to search in.

    Returns:
        The number of start positions in ``b`` at which ``a`` matches.
        This is 0 when ``a`` is longer than ``b``. When ``a`` is empty it is
        ``len(b) + 1``, because every empty slice of ``b`` compares equal.
    """
    width = len(a)
    # Booleans sum as 0/1; a generator avoids materialising a temporary list.
    return sum(b[i:i + width] == a for i in range(len(b) - width + 1))
2219
@@ -31,17 +28,17 @@ def mix_ends(graph: BaseGraph, ends: list[int], cycles: list[int]):
3128 visited.append(c)
3229 return [e for e in cycles if e not in visited] + ends
3330
34- def all_paths(graph: BaseGraph, start: int, visited: list, repeats: int):
35- global visited_list
31+ # def all_paths(graph: BaseGraph, start: int, visited: list, repeats: int):
32+ # global visited_list
3633
37- # if len(visited) < 1 or len_in([visited[-1],start],visited) < repeats:
38- if len(visited) < repeats or not list_in(visited[-repeats:]+[start],visited):
39- # print("LEN: ", len(visited))
40- visited.append(start)
41- for edge in graph.edge_by_source(start):
42- # print("TARGET: ", edge['target'])
43- all_paths(graph, edge['target'], visited.copy(), repeats)
44- visited_list.append(visited)
34+ # # if len(visited) < 1 or len_in([visited[-1],start],visited) < repeats:
35+ # if len(visited) < repeats or not list_in(visited[-repeats:]+[start],visited):
36+ # # print("LEN: ", len(visited))
37+ # visited.append(start)
38+ # for edge in graph.edge_by_source(start):
39+ # # print("TARGET: ", edge['target'])
40+ # all_paths(graph, edge['target'], visited.copy(), repeats)
41+ # visited_list.append(visited)
4542
4643def all_combinations(path: list, start: dict, next: int, visited: list):
4744 global visited_list
@@ -90,10 +87,12 @@ def get_utts(seq: list[list[dict]]) -> set[tuple[str]]:
9087 return set(res)
9188
def dialogue_edges(seq: list[list[dict]]) -> set[tuple[str, str, str]]:
    """Collect the distinct (assistant, user, next assistant) text triples.

    For each dialogue the assistant texts and the user texts are extracted
    separately, each in order of appearance. The i-th triple is made of the
    i-th assistant text, the i-th user text and the (i+1)-th assistant text.
    NOTE(review): pairing by index presumes that user and assistant turns
    alternate, starting with the assistant - confirm against the callers.

    Args:
        seq: Dialogues, each a list of message dicts that provide the keys
            ``'participant'`` and ``'text'``.

    Returns:
        The set of triples found across all dialogues. A dialogue with fewer
        than two assistant messages, or with no user messages, adds nothing.
    """
    edges = set()
    for dialogue in seq:
        assist_texts = [d['text'] for d in dialogue if d['participant']=='assistant']
        user_texts = [d['text'] for d in dialogue if d['participant']=='user']
        # zip stops at its shortest input; assist_texts[1:] already caps the
        # result at len(assist_texts) - 1 triples, so no extra slicing is needed.
        edges.update(zip(assist_texts, user_texts, assist_texts[1:]))
    return edges
9998
@@ -110,7 +109,7 @@ def get_dialogues(graph: BaseGraph, repeats: int, ends: list[int]) -> list[Dialo
110109 starts = [n for n in graph.graph_dict.get("nodes") if n["is_start"]]
111110 for s in starts:
112111 visited_list = [[]]
113- all_paths(graph, s['id'], [], repeats)
112+ graph. all_paths(s['id'], [], repeats)
114113 paths.extend(visited_list)
115114
116115 paths.sort()
@@ -148,16 +147,16 @@ def get_dialogues(graph: BaseGraph, repeats: int, ends: list[int]) -> list[Dialo
148147 dialogue = [el[1:] for el in visited_list if len(el)==len(f)+1]
149148 dialogues.extend(dialogue)
150149
151- # for d in dialogues:
152- # print("DGS: ", d)
153- # print("\n")
150+ for d in dialogues:
151+ print("DGS: ", d)
152+ print("\n")
154153 final = list(k for k,_ in itertools.groupby(dialogues))
155154 # print("BEFORE: ", len(final))
156155 final = remove_duplicated_utts(final)
157156 # print("AFTER: ", len(final))
158- # for f in final:
159- # print("FINAL: ", f)
160- # print("\n")
157+ for f in final:
158+ print("FINAL: ", f)
159+ print("\n")
161160 result = [Dialogue().from_list(seq) for seq in final]
162161 return result
163162
@@ -274,9 +273,6 @@ async def ainvoke(self, *args, **kwargs):
274273
275274# @AlgorithmRegistry.register(input_type=BaseGraph, output_type=Dialogue)
276275class RecursiveDialogueSampler(DialogueGenerator):
277- def _list_in(self, a: list, b: list) -> bool:
278- """Check if sequence a exists within sequence b."""
279- return any(map(lambda x: b[x : x + len(a)] == a, range(len(b) - len(a) + 1)))
280276
281277 def invoke(self, graph: BaseGraph, upper_limit: int) -> list[Dialogue]:
282278 global visited_list
0 commit comments