Skip to content

Commit 376d1b8

Browse files
author
Yuriy Peshkichev
committed
dialog frequency
1 parent 1a0a9ea commit 376d1b8

File tree

11 files changed

+3821
-44
lines changed

11 files changed

+3821
-44
lines changed

dialog2graph/pipelines/core/dialog_sampling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ def remove_duplicated_paths(node_paths: list[list[int]]) -> list[list[int]]:
216216

217217

218218
def get_dialog_triplets(seq: list[list[dict]]) -> set[tuple[str]]:
219-
"""Find all dialog triplets with (source, edge, target) utterances
219+
"""Get all dialog triplets with (source, edge, target) utterances
220+
from sequence of dialogs
220221
221222
Args:
222223
seq: sequence of dialogs

dialog2graph/pipelines/core/graph.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import Optional, Any
1313
import matplotlib.pyplot as plt
1414
import abc
15+
import colorsys
1516

1617
from dialog2graph.utils.logger import Logger
1718

@@ -252,7 +253,7 @@ def visualise_short(self, name="", *args, **kwargs):
252253
try:
253254
pos = nx.nx_agraph.pygraphviz_layout(self.graph)
254255
except ImportError as e:
255-
pos = nx.kamada_kawai_layout(self.graph)
256+
pos = nx.spring_layout(self.graph)
256257
logger.warning(
257258
f"{e}.\nInstall pygraphviz from http://pygraphviz.github.io/ .\nFalling back to default layout."
258259
)
@@ -290,29 +291,48 @@ def visualise_interactive(self, *args, **kwargs) -> gv._internal.plotting.data_s
290291

291292
"""
292293
Visualises the graph using interactive visualisation library "gravis".
294+
293295
Returns:
294296
A figure object representing the interactive graph visualization.
295297
"""
296-
graph = self.graph.copy()
297-
298-
node_labels = nx.get_node_attributes(graph, "utterances")
299-
edge_labels = nx.get_edge_attributes(graph, "utterances")
300-
301-
for edge_id in graph.edges:
302-
edge = graph.edges[edge_id]
303-
edge['label'] = len(edge_labels[edge_id])
304-
edge['hover'] = edge_labels[edge_id]
305-
306-
for node_id in graph.nodes:
307-
node = graph.nodes[node_id]
308-
node['label'] = f"{node_id}:{len(node_labels[node_id])}"
309-
node['hover'] = node_labels[node_id]
310-
298+
graph = {"graph": {}}
299+
if "frequency" in self.graph_dict["nodes"][0]:
300+
node_rgb = [colorsys.hsv_to_rgb(node["frequency"]/30, 1.0, 1.0) for node in self.graph_dict["nodes"]]
301+
node_colors = ["#%02x%02x%02x" % tuple([round(255*x) for x in rgb]) for rgb in node_rgb]
302+
node_frequency = [node["frequency"] for node in self.graph_dict["nodes"]]
303+
else:
304+
node_colors = ["#000000"]*len(self.graph_dict["nodes"])
305+
node_frequency = [0]*len(self.graph_dict["nodes"])
306+
if "frequency" in self.graph_dict["edges"][0]:
307+
edge_rgb = [colorsys.hsv_to_rgb(node["frequency"]/30, 1.0, 1.0) for node in self.graph_dict["edges"]]
308+
edge_colors = ["#%02x%02x%02x" % tuple([round(255*x) for x in rgb]) for rgb in edge_rgb]
309+
edge_frequency = [edge["frequency"] for edge in self.graph_dict["edges"]]
310+
else:
311+
edge_colors = ["#000000"]*len(self.graph_dict["edges"])
312+
edge_frequency = [0]*len(self.graph_dict["edges"])
313+
314+
graph["graph"]["nodes"] = {
315+
str(node["id"]): {
316+
"label": f"{node['id']}:{len(node['utterances'])}",
317+
"metadata": {
318+
"hover": f"frequency: {node_frequency[idx]}\n" + '\n'.join([str(i+1)+": "+ node["utterances"][i] for i in range(len(node["utterances"]))]),
319+
"color": node_colors[idx]
320+
}
321+
} for idx, node in enumerate(self.graph_dict["nodes"])
322+
}
323+
graph["graph"]["edges"] = [{"source": str(e["source"]),
324+
"target": str(e["target"]),
325+
"label": len(e["utterances"]),
326+
"metadata": {
327+
"hover": f"frequency: {edge_frequency[idx]}\n" + '\n'.join([str(i+1)+": "+ e["utterances"][i] for i in range(len(e["utterances"]))]),
328+
"color": edge_colors[idx]
329+
}
330+
} for idx, e in enumerate(self.graph_dict["edges"])]
311331
return gv.vis(
312332
graph, show_node_label=True, show_edge_label=True,
313333
node_label_data_source='label',
314334
edge_label_data_source='label', edge_label_size_factor=1.7,
315-
layout_algorithm="hierarchicalRepulsion"
335+
layout_algorithm="hierarchicalRepulsion",
316336
)
317337

318338

dialog2graph/utils/dg_helper.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,19 @@ def connect_nodes(
2828
"""
2929
edges = []
3030
node_store = NodeStore(nodes, utt_sim)
31+
for idx in range(len(nodes)):
32+
nodes[idx]["frequency"] = 0
3133
for dialog in dialogs:
3234
turns = dialog.to_list()
3335
dialog_store = DialogStore(turns, utt_sim)
3436
for node in nodes:
3537
for utt in node["utterances"]:
36-
ids = dialog_store.search_assistant(utt)
38+
ids = dialog_store.search_store(
39+
dialog_store.assistant_store,
40+
dialog_store.assistant_size,
41+
utt
42+
)
43+
node["frequency"] += len(ids)
3744
if ids:
3845
for id, user_utt in zip(ids, dialog_store.get_user_by_id(ids=ids)):
3946
if len(turns) > 2 * (int(id) + 1):
@@ -66,6 +73,7 @@ def connect_nodes(
6673
"utterances"
6774
]
6875
+ [user_utt],
76+
"frequency": 0
6977
}
7078
)
7179
else:
@@ -74,8 +82,17 @@ def connect_nodes(
7482
"source": node["id"],
7583
"target": target,
7684
"utterances": [user_utt],
85+
"frequency": 0,
7786
}
7887
)
88+
for edge in edges:
89+
for utt in edge["utterances"]:
90+
ids = dialog_store.search_store(
91+
dialog_store.user_store,
92+
dialog_store.user_size,
93+
utt
94+
)
95+
edge["frequency"] += len(ids)
7996
return {"edges": edges, "nodes": nodes}
8097

8198

dialog2graph/utils/vector_stores.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@ class DialogStore:
1717
User and assistant utterances vectorized separately
1818
1919
Attributes:
20-
_assistant_store: store for assistant utterances
21-
_user_store: store for user utterances
22-
_assistant_size: number of assistant utterances
20+
assistant_store: store for assistant utterances
21+
user_store: store for user utterances
22+
assistant_size: number of assistant utterances
23+
user_size: number of user utterances
2324
_score_threshold: similarity threshold
2425
"""
2526

26-
_assistant_store: Chroma
27-
_user_store: Chroma
28-
_assistant_size: int
27+
assistant_store: Chroma
28+
user_store: Chroma
29+
assistant_size: int
30+
user_size: int
2931
_score_threshold: int
3032

3133
def _load_dialog(
@@ -39,10 +41,10 @@ def _load_dialog(
3941
dialog: list of dicts in a form {"participant": "user" or "assistant", "text": text}
4042
embedder: embedding function for vector store
4143
"""
42-
self._assistant_store = Chroma(
44+
self.assistant_store = Chroma(
4345
collection_name=str(uuid.uuid4()), embedding_function=embedder
4446
)
45-
self._user_store = Chroma(
47+
self.user_store = Chroma(
4648
collection_name=str(uuid.uuid4()), embedding_function=embedder
4749
)
4850
assistant_docs = [
@@ -53,11 +55,12 @@ def _load_dialog(
5355
]
5456
user_docs = [
5557
Document(page_content=turn["text"].lower(), id=id, metadata={"id": id})
56-
for id, turn in enumerate(d for d in dialog if d["participant"] == "user")
58+
for id, turn in enumerate([d for d in dialog if d["participant"] == "user"])
5759
]
58-
self._assistant_size = len(assistant_docs)
59-
self._assistant_store.add_documents(documents=assistant_docs)
60-
self._user_store.add_documents(documents=user_docs)
60+
self.assistant_size = len(assistant_docs)
61+
self.user_size = len(user_docs)
62+
self.assistant_store.add_documents(documents=assistant_docs)
63+
self.user_store.add_documents(documents=user_docs)
6164

6265
def __init__(
6366
self,
@@ -75,17 +78,20 @@ def __init__(
7578
self._score_threshold = score_threshold
7679
self._load_dialog(dialog, embedder)
7780

78-
def search_assistant(self, utterance) -> list[str]:
79-
"""Search for utterance over assistant store
81+
def search_store(self, store: Chroma, size: int, utterance: str) -> list[str]:
82+
"""Search for utterance over store
8083
8184
Args:
85+
store: Chroma store
86+
size: size of the store
8287
utterance: utterance to search for
8388
Returns:
84-
list of found documents ids of assistant store
89+
list of found documents ids
8590
"""
86-
docs = self._assistant_store.similarity_search_with_relevance_scores(
91+
92+
docs = store.similarity_search_with_relevance_scores(
8793
utterance.lower(),
88-
k=self._assistant_size,
94+
k=size,
8995
score_threshold=self._score_threshold,
9096
)
9197
res = [d[0].metadata["id"] for d in docs]
@@ -94,6 +100,7 @@ def search_assistant(self, utterance) -> list[str]:
94100

95101
return res
96102

103+
97104
def get_user_by_id(self, ids: list[str]) -> list[str]:
98105
"""Get utterances of user with ids
99106
@@ -102,7 +109,7 @@ def get_user_by_id(self, ids: list[str]) -> list[str]:
102109
Returns:
103110
list of utterances
104111
"""
105-
res = self._user_store.get(ids=ids)["documents"]
112+
res = self.user_store.get(ids=ids)["documents"]
106113
return res
107114

108115

experiments/exp2025_03_20_d2g_pipeline/exp2025_03_20_d2g_pipeline/test_pipeline.ipynb

Lines changed: 3678 additions & 7 deletions
Large diffs are not rendered by default.

scripts/check_metrics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def test_d2g_pipeline(pipeline: BasePipeline) -> bool:
113113
# Parse the raw data
114114
raw_data = PipelineRawDataType(dialogs=dialogs, true_graph=graph)
115115
report = pipeline.invoke(raw_data, enable_evals=True)[1].model_dump()
116+
116117
# Extract the duration and similarity from the report
117118
new_summary.append(
118119
{
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[
2+
{
3+
"graph": "Responding to DMs on Instagram/Facebook.",
4+
"duration": 42.00281095504761,
5+
"similarity": 0.9939578771591187
6+
},
7+
{
8+
"graph": "average",
9+
"duration": 42.00281095504761,
10+
"similarity": 0.9939578771591187
11+
}
12+
]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[
2+
{
3+
"graph": "Responding to DMs on Instagram/Facebook.",
4+
"duration": 19.314972162246704,
5+
"similarity": 0.9247153401374817
6+
},
7+
{
8+
"graph": "average",
9+
"duration": 19.314972162246704,
10+
"similarity": 0.9247153401374817
11+
}
12+
]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[
2+
{
3+
"graph": "Responding to DMs on Instagram/Facebook.",
4+
"duration": 19.48843026161194,
5+
"similarity": 0.9247153401374817
6+
},
7+
{
8+
"graph": "average",
9+
"duration": 19.48843026161194,
10+
"similarity": 0.9247153401374817
11+
}
12+
]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[
2+
{
3+
"graph": "Responding to DMs on Instagram/Facebook.",
4+
"duration": 21.439910411834717,
5+
"similarity": 0.9247153401374817
6+
},
7+
{
8+
"graph": "average",
9+
"duration": 21.439910411834717,
10+
"similarity": 0.9247153401374817
11+
}
12+
]

0 commit comments

Comments
 (0)