Skip to content

Commit 9d9a84b

Browse files
Added LoopedGraphGenerator to complex graph datasets
1 parent 437c249 commit 9d9a84b

File tree

5 files changed

+383
-68
lines changed

5 files changed

+383
-68
lines changed
Lines changed: 105 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,22 @@
1+
from enum import Enum
2+
from typing import Optional, Dict, Any, Union
13
from dataclasses import dataclass
2-
from typing import Optional, Dict, Any
4+
5+
from pydantic import BaseModel, Field
36
import networkx as nx
4-
from langchain_core.language_models.chat_models import BaseChatModel
7+
58
from langchain.prompts import PromptTemplate
9+
from langchain_core.output_parsers.pydantic import PydanticOutputParser
10+
from langchain_core.language_models.chat_models import BaseChatModel
611

7-
# from chatsky_llm_autoconfig.algorithms.topic_graph_generation import CycleGraphGenerator
8-
# from chatsky_llm_autoconfig.algorithms.dialogue_generation import RecursiveDialogueSampler
9-
# from chatsky_llm_autoconfig.metrics.automatic_metrics import all_utterances_present
10-
# from chatsky_llm_autoconfig.metrics.llm_metrics import are_triples_valid, is_theme_valid
11-
# from chatsky_llm_autoconfig.graph import BaseGraph
12-
# from chatsky_llm_autoconfig.prompts import cycle_graph_generation_prompt_enhanced, cycle_graph_repair_prompt
13-
from openai import BaseModel
12+
from dialogue2graph.pipelines.core.dialogue_sampling import RecursiveDialogueSampler
13+
from dialogue2graph.metrics.automatic_metrics import all_utterances_present
14+
from dialogue2graph.metrics.llm_metrics import are_triplets_valid, is_theme_valid
15+
from dialogue2graph.pipelines.core.graph import BaseGraph, Graph
16+
from dialogue2graph.pipelines.core.algorithms import TopicGraphGenerator
17+
from dialogue2graph.pipelines.core.schemas import GraphGenerationResult, DialogueGraph
1418

15-
from enum import Enum
16-
from typing import Union
17-
18-
from dialogue2graph.pipelines.core.schemas import GraphGenerationResult
19+
from .prompts import cycle_graph_generation_prompt_enhanced, cycle_graph_repair_prompt
1920

2021

2122
class ErrorType(str, Enum):
@@ -38,70 +39,83 @@ class GenerationError(BaseModel):
3839
PipelineResult = Union[GraphGenerationResult, GenerationError]
3940

4041

42+
class CycleGraphGenerator(TopicGraphGenerator):
    """Topic graph generator that produces cyclic dialogue graphs from an LLM chain."""

    def __init__(self):
        super().__init__()

    def invoke(self, model: BaseChatModel, prompt: PromptTemplate, **kwargs) -> BaseGraph:
        """Run ``prompt`` through ``model`` and wrap the parsed result in a ``Graph``.

        Args:
            model: Chat model that receives the rendered prompt.
            prompt: Prompt template forming the first stage of the chain.
            **kwargs: Passed as the chain input (for example ``topic=...``).

        Returns:
            A ``Graph`` built from the ``DialogueGraph`` parsed out of the model output.
        """
        output_parser = PydanticOutputParser(pydantic_object=DialogueGraph)
        # prompt -> model -> parser, composed with the runnable pipe operator.
        parsed_graph = (prompt | model | output_parser).invoke(kwargs)
        return Graph(parsed_graph)

    async def ainvoke(self, *args, **kwargs):
        """Async counterpart of ``invoke``; not implemented yet, so it returns ``None``."""
        return None

    def evaluate(self, *args, report_type="dict", **kwargs):
        """Placeholder: performs no evaluation and returns ``None``."""
        return None
62+
63+
4164
@dataclass
42-
class GenerationPipeline:
65+
class GenerationPipeline(BaseModel):
4366
generation_model: BaseChatModel
4467
validation_model: BaseChatModel
45-
graph_generator: CycleGraphGenerator
46-
generation_prompt: PromptTemplate
47-
repair_prompt: PromptTemplate
68+
graph_generator: CycleGraphGenerator = Field(default_factory=CycleGraphGenerator)
69+
generation_prompt: PromptTemplate = Field(default_factory=lambda: cycle_graph_generation_prompt_enhanced)
70+
repair_prompt: PromptTemplate = Field(default_factory=lambda: cycle_graph_repair_prompt)
4871
min_cycles: int = 2
4972
max_fix_attempts: int = 3
73+
dialogue_sampler: RecursiveDialogueSampler = Field(default_factory=RecursiveDialogueSampler)
5074

5175
def __init__(
    self,
    generation_model: BaseChatModel,
    validation_model: BaseChatModel,
    generation_prompt: Optional[PromptTemplate] = None,
    repair_prompt: Optional[PromptTemplate] = None,
    min_cycles: int = 0,
    max_fix_attempts: int = 2,
):
    """Initialise the pipeline with its models, prompts and limits.

    Args:
        generation_model: Model used to generate (and repair) graphs.
        validation_model: Model used by the validation helpers.
        generation_prompt: Prompt for initial generation. When ``None`` the
            field's declared default is used instead.
        repair_prompt: Prompt for repairing invalid transitions. When ``None``
            the field's declared default is used instead.
        min_cycles: Minimum number of cycles a generated graph must contain.
        max_fix_attempts: Maximum number of repair attempts.
    """
    # Forward only the prompts that were actually supplied. Passing an explicit
    # None would override the ``default_factory`` declared on the corresponding
    # field rather than trigger it.
    prompt_overrides = {}
    if generation_prompt is not None:
        prompt_overrides["generation_prompt"] = generation_prompt
    if repair_prompt is not None:
        prompt_overrides["repair_prompt"] = repair_prompt

    super().__init__(
        generation_model=generation_model,
        validation_model=validation_model,
        min_cycles=min_cycles,
        max_fix_attempts=max_fix_attempts,
        **prompt_overrides,
    )
7092

7193
def validate_graph_cycle_requirement(self, graph: BaseGraph, min_cycles: int = 2) -> Dict[str, Any]:
    """Count the simple cycles of ``graph`` and compare the count with ``min_cycles``.

    Progress and the list of cycles are printed to stdout.

    Args:
        graph: Graph whose ``graph`` attribute is handed to ``nx.simple_cycles``.
        min_cycles: Minimum number of cycles required.

    Returns:
        Dict with keys ``meets_requirements`` (bool), ``cycles`` (list of the
        cycles found) and ``cycles_count`` (int).

    Raises:
        Exception: Any error raised during the check is printed and re-raised.
    """
    print("\n🔍 Checking graph requirements...")
    try:
        found_cycles = list(nx.simple_cycles(graph.graph))
        total = len(found_cycles)
        print(f"🔄 Found {total} cycles in the graph:")
        for index, nodes in enumerate(found_cycles, 1):
            # Repeat the first node at the end so the printed path is visibly closed.
            closed_path = nodes + [nodes[0]]
            print(f"Cycle {index}: {' -> '.join(map(str, closed_path))}")

        enough_cycles = total >= min_cycles
        if enough_cycles:
            print("✅ Graph meets cycle requirements")
        else:
            print(f"❌ Graph doesn't meet cycle requirements (minimum {min_cycles} cycles needed)")
        return {"meets_requirements": enough_cycles, "cycles": found_cycles, "cycles_count": total}

    except Exception as exc:
        print(f"❌ Validation error: {str(exc)}")
        raise
97114

98115
def check_and_fix_transitions(self, graph: BaseGraph, max_attempts: int = 3) -> Dict[str, Any]:
99-
"""
100-
Проверяет переходы в графе и пытается исправить невалидные через LLM
101-
"""
116+
"""Checks transitions in the graph and attempts to fix invalid ones via LLM"""
102117
print("Validating initial graph")
103-
104-
initial_validation = are_triples_valid(graph, self.validation_model, return_type="detailed")
118+
initial_validation = are_triplets_valid(graph, self.validation_model, return_type="detailed")
105119
if initial_validation["is_valid"]:
106120
return {"is_valid": True, "graph": graph, "validation_details": {"invalid_transitions": [], "attempts_made": 0, "fixed_count": 0}}
107121

@@ -111,18 +125,15 @@ def check_and_fix_transitions(self, graph: BaseGraph, max_attempts: int = 3) ->
111125

112126
while current_attempt < max_attempts:
113127
print(f"\n🔄 Fix attempt {current_attempt + 1}/{max_attempts}")
114-
115128
try:
116-
# Используем generation_model для исправления графа
117129
current_graph = self.graph_generator.invoke(
118130
model=self.generation_model,
119131
prompt=self.repair_prompt,
120132
invalid_transitions=initial_validation["invalid_transitions"],
121133
graph_json=current_graph.graph_dict,
122134
)
123135

124-
# Проверяем исправленный граф используя validation_model
125-
validation = are_triples_valid(current_graph, self.validation_model, return_type="detailed")
136+
validation = are_triplets_valid(current_graph, self.validation_model, return_type="detailed")
126137
if validation["is_valid"]:
127138
return {
128139
"is_valid": True,
@@ -139,7 +150,6 @@ def check_and_fix_transitions(self, graph: BaseGraph, max_attempts: int = 3) ->
139150
current_attempt += 1
140151

141152
remaining_invalid = len(validation["invalid_transitions"])
142-
143153
return {
144154
"is_valid": False,
145155
"graph": current_graph,
@@ -151,38 +161,30 @@ def check_and_fix_transitions(self, graph: BaseGraph, max_attempts: int = 3) ->
151161
}
152162

153163
def generate_and_validate(self, topic: str) -> PipelineResult:
154-
"""
155-
Generates and validates a dialogue graph for given topic
156-
"""
164+
"""Generates and validates a dialogue graph for given topic"""
157165
try:
158-
# 1. Generate initial graph
159166
print("Generating Graph ...")
160167
graph = self.graph_generator.invoke(model=self.generation_model, prompt=self.generation_prompt, topic=topic)
161168

162-
# 2. Validate cycles
163169
cycle_validation = self.validate_graph_cycle_requirement(graph, self.min_cycles)
164170
if not cycle_validation["meets_requirements"]:
165171
return GenerationError(
166172
error_type=ErrorType.TOO_MANY_CYCLES,
167173
message=f"Graph requires minimum {self.min_cycles} cycles, found {cycle_validation['cycles_count']}",
168174
)
169175

170-
# 3. Generate and validate dialogues
171176
print("Sampling dialogues...")
172177
sampled_dialogues = self.dialogue_sampler.invoke(graph, 15)
173178
print(f"Sampled {len(sampled_dialogues)} dialogues")
174-
print(sampled_dialogues)
175179
if not all_utterances_present(graph, sampled_dialogues):
176180
return GenerationError(
177181
error_type=ErrorType.SAMPLING_FAILED, message="Failed to sample valid dialogues - not all utterances are present"
178182
)
179183

180-
# 4. Validate theme
181184
theme_validation = is_theme_valid(graph, self.validation_model, topic)
182185
if not theme_validation["value"]:
183186
return GenerationError(error_type=ErrorType.INVALID_THEME, message=f"Theme validation failed: {theme_validation['description']}")
184187

185-
# 5. Validate and fix transitions
186188
print("Validating and fixing transitions...")
187189
transition_validation = self.check_and_fix_transitions(graph=graph, max_attempts=self.max_fix_attempts)
188190

@@ -193,7 +195,6 @@ def generate_and_validate(self, topic: str) -> PipelineResult:
193195
message=f"Found {len(invalid_transitions)} invalid transitions after {transition_validation['validation_details']['attempts_made']} fix attempts",
194196
)
195197

196-
# All validations passed - return successful result
197198
return GraphGenerationResult(graph=transition_validation["graph"].graph_dict, topic=topic, dialogues=sampled_dialogues)
198199

199200
except Exception as e:
@@ -202,3 +203,47 @@ def generate_and_validate(self, topic: str) -> PipelineResult:
202203
def __call__(self, topic: str) -> PipelineResult:
    """Make the pipeline callable: delegates to ``generate_and_validate(topic)``."""
    outcome = self.generate_and_validate(topic)
    return outcome
206+
207+
208+
class LoopedGraphGenerator(TopicGraphGenerator):
    """Topic graph generator that delegates all work to a ``GenerationPipeline``."""

    # Model handed to the pipeline for graph generation.
    generation_model: BaseChatModel
    # Model handed to the pipeline for validation.
    validation_model: BaseChatModel
    # Pipeline built in ``__init__`` from the two models and the cycle-graph prompts.
    pipeline: GenerationPipeline

    def __init__(self, generation_model: BaseChatModel, validation_model: BaseChatModel):
        """Build the inner pipeline from the given models and the cycle-graph prompts."""
        inner_pipeline = GenerationPipeline(
            generation_model=generation_model,
            validation_model=validation_model,
            generation_prompt=cycle_graph_generation_prompt_enhanced,
            repair_prompt=cycle_graph_repair_prompt,
        )
        super().__init__(
            generation_model=generation_model,
            validation_model=validation_model,
            pipeline=inner_pipeline,
        )

    def invoke(self, topic) -> list[dict]:
        """Run the pipeline for ``topic`` and return the successful generations.

        Returns:
            A list holding one dict (``graph``, ``topic``, ``dialogues``) when the
            pipeline returns a ``GraphGenerationResult``; an empty list when it
            returns an error result or raises. Failures are printed, not raised.
        """
        print(f"\n{'='*50}")
        print(f"Generating graph for topic: {topic}")
        print(f"{'='*50}")
        collected = []
        try:
            outcome = self.pipeline(topic)

            if not isinstance(outcome, GraphGenerationResult):
                print(f"❌ Failed to generate graph for {topic}")
                print(f"Error type: {outcome.error_type}")
                print(f"Error message: {outcome.message}")
            else:
                print(f"✅ Successfully generated graph for {topic}")
                record = {
                    "graph": outcome.graph.model_dump(),
                    "topic": outcome.topic,
                    "dialogues": [dialogue.model_dump() for dialogue in outcome.dialogues],
                }
                collected.append(record)

        except Exception as err:
            # Best effort: report the failure and fall through to return what was collected.
            print(f"❌ Unexpected error processing topic '{topic}': {str(err)}")

        return collected

    def evaluate(self, *args, report_type="dict", **kwargs):
        """Forward evaluation unchanged to the parent class implementation."""
        return super().evaluate(*args, report_type=report_type, **kwargs)

0 commit comments

Comments
 (0)