Skip to content

Commit f44b928

Browse files
author
Yuriy Peshkichev
committed
more graphs with dialogues generated
1 parent a367421 commit f44b928

24 files changed

+306383
-67404
lines changed

dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/cycle_graph_generation_pipeline.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -236,10 +236,14 @@ def generate_and_validate(self, topic: str) -> PipelineResult:
236236
if not graph.edges_match_nodes():
237237
return GenerationError(
238238
error_type=ErrorType.INVALID_GRAPH_STRUCTURE,
239-
message="Genrated graph is wrong: edges don't match nodes"
239+
message="Generated graph is wrong: edges don't match nodes"
240240
)
241241
graph = graph.remove_duplicated_nodes()
242-
242+
if graph is None:
243+
return GenerationError(
244+
error_type=ErrorType.INVALID_GRAPH_STRUCTURE,
245+
message="Generated graph is wrong: utterances in nodes doubled"
246+
)
243247
# 2. Validate cycles
244248
cycle_validation = self.validate_graph_cycle_requirement(graph, self.min_cycles)
245249
if not cycle_validation["meets_requirements"]:
@@ -281,10 +285,22 @@ def generate_and_validate(self, topic: str) -> PipelineResult:
281285
error_type=ErrorType.INVALID_GRAPH_STRUCTURE,
282286
message=f"Found {len(invalid_transitions)} invalid transitions after {transition_validation['validation_details']['attempts_made']} fix attempts"
283287
)
288+
289+
graph = transition_validation["graph"]
290+
print("Sampling dialogues...")
291+
sampled_dialogues = self.dialogue_sampler.invoke(graph, 15)
292+
print(f"Sampled {len(sampled_dialogues)} dialogues")
293+
for s in sampled_dialogues:
294+
print(s)
295+
if all_utterances_present(graph, sampled_dialogues) != True:
296+
return GenerationError(
297+
error_type=ErrorType.SAMPLING_FAILED,
298+
message="Failed to sample valid dialogues - not all utterances are present"
299+
)
284300

285301
# All validations passed - return successful result
286302
return GraphGenerationResult(
287-
graph=transition_validation["graph"].graph_dict,
303+
graph=graph.graph_dict,
288304
topic=topic,
289305
dialogues=sampled_dialogues
290306
)

dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/dialogue_generation.py

Lines changed: 20 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -14,9 +14,6 @@
1414

1515
env_settings = EnvSettings()
1616

17-
def list_in(a, b):
18-
return any(map(lambda x: b[x:x + len(a)] == a, range(len(b) - len(a) + 1)))
19-
2017
def len_in(a, b):
    """Count how many times sequence ``a`` occurs as a contiguous run inside ``b``.

    Every start offset of ``b`` is compared against ``a``, so overlapping
    occurrences are each counted. When ``a`` is longer than ``b`` there is no
    offset to try and the result is 0.

    Args:
        a: The sub-sequence to look for.
        b: The sequence to search in; slices of it are compared to ``a`` with ``==``.

    Returns:
        The number of offsets ``i`` for which ``b[i:i + len(a)] == a``.
    """
    width = len(a)
    # Summing booleans counts the matches; the generator avoids a temporary list.
    return sum(b[i:i + width] == a for i in range(len(b) - width + 1))
2219

@@ -31,17 +28,17 @@ def mix_ends(graph: BaseGraph, ends: list[int], cycles: list[int]):
3128
visited.append(c)
3229
return [e for e in cycles if e not in visited] + ends
3330

34-
def all_paths(graph: BaseGraph, start: int, visited: list, repeats: int):
35-
global visited_list
31+
# def all_paths(graph: BaseGraph, start: int, visited: list, repeats: int):
32+
# global visited_list
3633

37-
# if len(visited) < 1 or len_in([visited[-1],start],visited) < repeats:
38-
if len(visited) < repeats or not list_in(visited[-repeats:]+[start],visited):
39-
# print("LEN: ", len(visited))
40-
visited.append(start)
41-
for edge in graph.edge_by_source(start):
42-
# print("TARGET: ", edge['target'])
43-
all_paths(graph, edge['target'], visited.copy(), repeats)
44-
visited_list.append(visited)
34+
# # if len(visited) < 1 or len_in([visited[-1],start],visited) < repeats:
35+
# if len(visited) < repeats or not list_in(visited[-repeats:]+[start],visited):
36+
# # print("LEN: ", len(visited))
37+
# visited.append(start)
38+
# for edge in graph.edge_by_source(start):
39+
# # print("TARGET: ", edge['target'])
40+
# all_paths(graph, edge['target'], visited.copy(), repeats)
41+
# visited_list.append(visited)
4542

4643
def all_combinations(path: list, start: dict, next: int, visited: list):
4744
global visited_list
@@ -90,10 +87,12 @@ def get_utts(seq: list[list[dict]]) -> set[tuple[str]]:
9087
return set(res)
9188

9289
def dialogue_edges(seq: list[list[dict]]) -> set[tuple[str, str, str]]:
    """Collect the (assistant, user, next assistant) utterance triples of dialogues.

    Args:
        seq: Dialogues, each a list of message dicts that carry a ``'text'``
            and a ``'participant'`` key.

    Returns:
        The set of triples ``(a1, u, a2)`` where ``a1`` and ``a2`` are
        consecutive assistant texts of one dialogue and ``u`` is the user text
        at the same position as ``a1`` in that dialogue's user texts.
    """
    edges = set()
    for dialogue in seq:
        assist_texts = [d['text'] for d in dialogue if d['participant'] == 'assistant']
        user_texts = [d['text'] for d in dialogue if d['participant'] == 'user']
        # Texts are paired purely by position within each participant's list.
        # NOTE(review): this presumes turns alternate assistant/user starting
        # with the assistant; nothing here verifies that.
        # zip() stops at its shortest input, and assist_texts[1:] is one shorter
        # than assist_texts, so no explicit trimming of the other two is needed.
        edges.update(zip(assist_texts, user_texts, assist_texts[1:]))
    return edges
9998

@@ -110,7 +109,7 @@ def get_dialogues(graph: BaseGraph, repeats: int, ends: list[int]) -> list[Dialo
110109
starts = [n for n in graph.graph_dict.get("nodes") if n["is_start"]]
111110
for s in starts:
112111
visited_list = [[]]
113-
all_paths(graph, s['id'], [], repeats)
112+
graph.all_paths(s['id'], [], repeats)
114113
paths.extend(visited_list)
115114

116115
paths.sort()
@@ -148,16 +147,16 @@ def get_dialogues(graph: BaseGraph, repeats: int, ends: list[int]) -> list[Dialo
148147
dialogue = [el[1:] for el in visited_list if len(el)==len(f)+1]
149148
dialogues.extend(dialogue)
150149

151-
# for d in dialogues:
152-
# print("DGS: ", d)
153-
# print("\n")
150+
for d in dialogues:
151+
print("DGS: ", d)
152+
print("\n")
154153
final = list(k for k,_ in itertools.groupby(dialogues))
155154
# print("BEFORE: ", len(final))
156155
final = remove_duplicated_utts(final)
157156
# print("AFTER: ", len(final))
158-
# for f in final:
159-
# print("FINAL: ", f)
160-
# print("\n")
157+
for f in final:
158+
print("FINAL: ", f)
159+
print("\n")
161160
result = [Dialogue().from_list(seq) for seq in final]
162161
return result
163162

@@ -274,9 +273,6 @@ async def ainvoke(self, *args, **kwargs):
274273

275274
# @AlgorithmRegistry.register(input_type=BaseGraph, output_type=Dialogue)
276275
class RecursiveDialogueSampler(DialogueGenerator):
277-
def _list_in(self, a: list, b: list) -> bool:
278-
"""Check if sequence a exists within sequence b."""
279-
return any(map(lambda x: b[x : x + len(a)] == a, range(len(b) - len(a) + 1)))
280276

281277
def invoke(self, graph: BaseGraph, upper_limit: int) -> list[Dialogue]:
282278
global visited_list

dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/three_stages_graph_generation.py

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,22 @@
1+
import pandas as pd
12
from langchain.prompts import PromptTemplate
23
from langchain_openai import ChatOpenAI
34
from langchain.output_parsers import PydanticOutputParser
45
from langchain_community.embeddings import HuggingFaceEmbeddings
56

67
from chatsky_llm_autoconfig.algorithms.base import GraphGenerator
7-
from chatsky_llm_autoconfig.graph import BaseGraph, Graph
8+
from chatsky_llm_autoconfig.metrics.automatic_metrics import (
9+
is_same_structure,
10+
compare_graphs
11+
)
812
from chatsky_llm_autoconfig.metrics.embedder import nodes2groups
913
from chatsky_llm_autoconfig.schemas import DialogueGraph
1014
from chatsky_llm_autoconfig.dialogue import Dialogue
1115
from chatsky_llm_autoconfig.autometrics.registry import AlgorithmRegistry
1216
from chatsky_llm_autoconfig.utils import call_llm_api, nodes2graph, dialogues2list
1317
from chatsky_llm_autoconfig.settings import EnvSettings
14-
1518
from chatsky_llm_autoconfig.missing_edges_prompt import three_1, three_2
16-
19+
from chatsky_llm_autoconfig.graph import BaseGraph, Graph
1720
env_settings = EnvSettings()
1821

1922
embeddings = HuggingFaceEmbeddings(model_name=env_settings.EMBEDDER_MODEL, model_kwargs={"device": env_settings.EMBEDDER_DEVICE})
@@ -89,3 +92,16 @@ def invoke(self, dialogues: list[Dialogue] = None, graph: DialogueGraph = None,
8992

9093
async def ainvoke(self, *args, **kwargs):
9194
return self.invoke(*args, **kwargs)
95+
96+
async def evaluate(self, dialogues, target_graph, report_type="dict"):
    """Generate a graph from ``dialogues`` and compare it with ``target_graph``.

    Args:
        dialogues: Dialogues handed unchanged to ``self.invoke``.
        target_graph: Reference graph the generated graph is compared against.
        report_type: ``"dict"`` (default) to get the report as a plain dict,
            or ``"dataframe"`` to get it as a one-row ``pandas.DataFrame``.

    Returns:
        The report holding ``"is_same_structure"`` and ``"graph_match"``,
        in the container selected by ``report_type``.

    Raises:
        ValueError: If ``report_type`` is neither ``"dict"`` nor ``"dataframe"``.
    """
    # Reject an unsupported format up front so that no generation or metric
    # work is spent on a call that can only end in an error.
    if report_type not in ("dict", "dataframe"):
        raise ValueError(f"Invalid report_type: {report_type}")

    graph = self.invoke(dialogues)
    report = {
        "is_same_structure": is_same_structure(graph, target_graph),
        "graph_match": compare_graphs(graph, target_graph),
    }
    if report_type == "dataframe":
        # Previously this branch built the frame but fell through without
        # returning it, so callers received None.
        return pd.DataFrame(report, index=[0])
    return report

0 commit comments

Comments
 (0)