Skip to content

Commit a4b537f

Browse files
Quickfix of RecursiveDialogueSampler
1 parent c07ac04 commit a4b537f

File tree

5 files changed: 614 additions & 119 deletions

dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -71,28 +71,28 @@ def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topi
7171

7272
return all_dialogues
7373

74-
async def ainvoke(self, *args, **kwargs):
74+
async def ainvoke(self, *args, **kwargs):
7575
return self.invoke(*args, **kwargs)
7676

7777

7878
@AlgorithmRegistry.register(input_type=BaseGraph, output_type=Dialogue)
7979
class DialoguePathSampler(DialogueGenerator):
8080
def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topic="") -> list[Dialogue]:
8181
nx_graph = graph.graph
82-
82+
8383
# Find all nodes with no outgoing edges (end nodes)
8484
end_nodes = [node for node in nx_graph.nodes() if nx_graph.out_degree(node) == 0]
8585
dialogues = []
8686
# If no end nodes found, return empty list
8787
if not end_nodes:
8888
return []
89-
89+
9090
all_paths = []
9191
# Get paths from start node to each end node
9292
for end in end_nodes:
9393
paths = list(nx.all_simple_paths(nx_graph, source=start_node, target=end))
9494
all_paths.extend(paths)
95-
95+
9696
for path in all_paths:
9797
dialogue_turns = []
9898
# Process each node and edge in the path
@@ -101,59 +101,69 @@ def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topi
101101
current_node = path[i]
102102
assistant_utterance = random.choice(nx_graph.nodes[current_node]["utterances"])
103103
dialogue_turns.append({"text": assistant_utterance, "participant": "assistant"})
104-
104+
105105
# Add user utterance from edge (if not at last node)
106106
if i < len(path) - 1:
107107
next_node = path[i + 1]
108108
edge_data = nx_graph.edges[current_node, next_node]
109-
user_utterance = (
110-
random.choice(edge_data["utterances"])
111-
if isinstance(edge_data["utterances"], list)
112-
else edge_data["utterances"]
113-
)
109+
user_utterance = random.choice(edge_data["utterances"]) if isinstance(edge_data["utterances"], list) else edge_data["utterances"]
114110
dialogue_turns.append({"text": user_utterance, "participant": "user"})
115-
111+
116112
dialogues.append(Dialogue().from_list(dialogue_turns))
117-
113+
118114
return dialogues
119-
115+
120116
async def ainvoke(self, *args, **kwargs):
121117
return self.invoke(*args, **kwargs)
122-
118+
123119

124120
@AlgorithmRegistry.register(input_type=BaseGraph, output_type=Dialogue)
125121
class RecursiveDialogueSampler(DialogueGenerator):
126122
def _list_in(self, a: list, b: list) -> bool:
127123
"""Check if sequence a exists within sequence b."""
128-
return any(map(lambda x: b[x:x + len(a)] == a, range(len(b) - len(a) + 1)))
129-
130-
124+
return any(map(lambda x: b[x : x + len(a)] == a, range(len(b) - len(a) + 1)))
131125

132126
def invoke(self, graph: BaseGraph, start_node: int = 1, end_node: int = -1, topic="") -> list[Dialogue]:
133127
starts = [n for n in graph.graph_dict.get("nodes") if n["is_start"]]
134128
visitedList = [[]]
129+
135130
def all_paths(graph, start: int, visited: list):
136131
# print("start: ", start, len(visitedList))
137-
if len(visited) < 2 or not self._list_in(visited[-2:]+[start],visited):
132+
if len(visited) < 2 or not self._list_in(visited[-2:] + [start], visited):
138133
visited.append(start)
139134
# print("visited:", visited)
140135
for edge in graph.edge_by_source(start):
141136

142-
# if [start,edge['target']] not in visited:
143-
all_paths(graph, edge['target'], visited.copy())
137+
# if [start,edge['target']] not in visited:
138+
all_paths(graph, edge["target"], visited.copy())
144139
visitedList.append(visited)
145140

146-
all_paths(graph, starts[0]['id'], [])
141+
all_paths(graph, starts[0]["id"], [])
147142
visitedList.sort()
148-
final = list(k for k,_ in itertools.groupby(visitedList))[1:]
149-
150-
dialogues = []
151-
for nodes in final:
152-
dialogues.append(Dialogue().from_nodes_ids(graph=graph, node_list=nodes))
143+
final = list(k for k, _ in itertools.groupby(visitedList))[1:]
144+
sources = list(set([g["source"] for g in graph.graph_dict["edges"]]))
145+
ends = [g["id"] for g in graph.graph_dict["nodes"] if g["id"] not in sources]
146+
node_paths = [f for f in final if f[-1] in ends]
147+
full_paths = []
148+
for p in node_paths:
149+
# print(p)
150+
path = []
151+
for idx, s in enumerate(p[:-1]):
152+
path.append({"participant": "assistant", "text": graph.node_by_id(s)["utterances"][0]})
153+
# path.append({"user": list(set(gr.edge_by_source(s)) & set(gr.edge_by_target(p[idx+1])))[0]['utterances']})
154+
sources = graph.edge_by_source(s)
155+
targets = graph.edge_by_target(p[idx + 1])
156+
# print("SOURCES: ", sources, s)
157+
# print("TARGETS: ", targets, p[idx+1])
158+
# targets = set([(e['source'],e['target']) for e in gr.edge_by_target(p[idx+1])])
159+
edge = [e for e in sources if e in targets][0]
160+
path.append(({"participant": "user", "text": edge["utterances"][0]}))
161+
path.append({"participant": "assistant", "text": graph.node_by_id(p[-1])["utterances"][0]})
162+
full_paths.append(path)
163+
164+
dialogues = [Dialogue().from_list(i) for i in full_paths]
153165

154166
return dialogues
155167

156168
async def ainvoke(self, *args, **kwargs):
157169
return self.invoke(*args, **kwargs)
158-
159-

dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/dialogue.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,15 @@ def from_nodes_ids(cls, graph, node_list, validate: bool = True) -> "Dialogue":
5252
nodes_attributes = nx.get_node_attributes(graph.graph, "utterances")
5353
edges_attributes = nx.get_edge_attributes(graph.graph, "utterances")
5454
for node in range(len(node_list)):
55-
utts.append({
56-
"participant": "assistant",
57-
"text": nodes_attributes[node_list[node]][0]
58-
})
55+
utts.append({"participant": "assistant", "text": nodes_attributes[node_list[node]][0]})
5956
if node == len(node_list) - 1:
6057
if graph.graph.has_edge(node_list[node], node_list[0]):
61-
utts.append({
62-
"participant": "user",
63-
"text": edges_attributes[(node_list[node], node_list[0])][0]})
58+
utts.append({"participant": "user", "text": edges_attributes[(node_list[node], node_list[0])][0]})
6459
else:
65-
if graph.graph.has_edge(node_list[node], node_list[node+1]):
66-
utts.append({"participant": "user", "text": edges_attributes[(node_list[node], node_list[node+1])][0]})
67-
60+
if graph.graph.has_edge(node_list[node], node_list[node + 1]):
61+
utts.append({"participant": "user", "text": edges_attributes[(node_list[node], node_list[node + 1])][0]})
62+
6863
return cls(messages=utts, validate=validate)
69-
70-
71-
72-
7364

7465
def to_list(self) -> List[Dict[str, str]]:
7566
"""Converts Dialogue to a list of message dictionaries."""
@@ -97,10 +88,9 @@ def extend(self, messages: List[Union[DialogueMessage, Dict[str, str]]]) -> None
9788
new_messages = [msg if isinstance(msg, DialogueMessage) else DialogueMessage(**msg) for msg in messages]
9889
self.__validate(new_messages)
9990
self.messages.extend(new_messages)
100-
91+
10192
def __validate(self, messages):
102-
"""Ensure that messages meets expectations.
103-
"""
93+
"""Ensure that messages meets expectations."""
10494
if not messages:
10595
return
10696

dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/graph.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def edges_by_utterance(self):
3636
def node_by_id(self):
3737
raise NotImplementedError
3838

39+
@abc.abstractmethod
40+
def edge_by_target(self):
41+
raise NotImplementedError
42+
3943

4044
class Graph(BaseGraph):
4145

@@ -87,15 +91,18 @@ def visualise(self, *args, **kwargs):
8791
plt.show()
8892

8993
def nodes_by_utterance(self, utterance: str) -> list[dict]:
90-
return [node for node in self.graph_dict['nodes'] if utterance in node['utterances']]
91-
94+
return [node for node in self.graph_dict["nodes"] if utterance in node["utterances"]]
95+
9296
def edges_by_utterance(self, utterance: str) -> list[dict]:
93-
return [edge for edge in self.graph_dict['edges'] if utterance in edge['utterances']]
94-
97+
return [edge for edge in self.graph_dict["edges"] if utterance in edge["utterances"]]
98+
9599
def node_by_id(self, id: int):
96-
for node in self.graph_dict['nodes']:
97-
if node['id'] == id:
100+
for node in self.graph_dict["nodes"]:
101+
if node["id"] == id:
98102
return node
99-
103+
100104
def edge_by_source(self, source: int):
101-
return [edge for edge in self.graph_dict['edges'] if source == edge['source']]
105+
return [edge for edge in self.graph_dict["edges"] if source == edge["source"]]
106+
107+
def edge_by_target(self, target: int):
108+
return [edge for edge in self.graph_dict["edges"] if target == edge["target"]]

dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/metrics/automatic_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ def all_utterances_present(G: BaseGraph, dialogues: list[Dialogue]) -> bool:
156156
if graph_utterances.issubset(dialogue_utterances):
157157
return True
158158
else:
159-
return False
160-
# return graph_utterances.difference(dialogue_utterances)
159+
# return False
160+
return graph_utterances.difference(dialogue_utterances)
161161

162162

163163
def all_roles_correct(D1: Dialogue, D2: Dialogue) -> bool:

experiments/2025.01.13_data_check_and_sampler_debugging/sampler.ipynb

Lines changed: 552 additions & 64 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)