Skip to content

Commit 192cbbd

Browse files
Experiments and updates on recursive sampler and LoopedGraphGenerator
1 parent a4b537f commit 192cbbd

File tree

20 files changed

+3128055
-125
lines changed

20 files changed

+3128055
-125
lines changed
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
from dataclasses import dataclass
2+
from typing import Optional, Dict, Any
3+
import networkx as nx
4+
from langchain_core.language_models.chat_models import BaseChatModel
5+
from langchain.prompts import PromptTemplate
6+
from chatsky_llm_autoconfig.algorithms.topic_graph_generation import CycleGraphGenerator
7+
from chatsky_llm_autoconfig.algorithms.dialogue_generation import RecursiveDialogueSampler
8+
from chatsky_llm_autoconfig.algorithms.experimental_sampler import get_dialogues, get_full_dialogues
9+
from chatsky_llm_autoconfig.metrics.automatic_metrics import all_utterances_present
10+
from chatsky_llm_autoconfig.metrics.llm_metrics import graph_validation, is_theme_valid
11+
from chatsky_llm_autoconfig.graph import BaseGraph
12+
from chatsky_llm_autoconfig.prompts import cycle_graph_generation_prompt_enhanced, cycle_graph_repair_prompt
13+
from openai import BaseModel
14+
15+
from enum import Enum
16+
from typing import Union
17+
18+
from chatsky_llm_autoconfig.schemas import GraphGenerationResult
19+
20+
21+
class ErrorType(str, Enum):
    """Failure categories reported by the graph generation pipeline.

    Members are plain strings as well (``str`` mixin), so they compare equal
    to their values and can be looked up by value.
    """

    # Transitions are still invalid after the repair attempts were used up.
    INVALID_GRAPH_STRUCTURE = "invalid_graph_structure"

    # Reported by the pipeline when the cycle-count requirement is not met.
    TOO_MANY_CYCLES = "too_many_cycles"

    # The sampled dialogues do not cover every utterance of the graph.
    SAMPLING_FAILED = "sampling_failed"

    # The generated graph does not match the requested topic.
    INVALID_THEME = "invalid_theme"

    # Any unexpected exception raised while the pipeline was running.
    GENERATION_FAILED = "generation_failed"
28+
29+
30+
class GenerationError(BaseModel):
    """Structured description of a failed pipeline run.

    Returned (not raised) in place of a successful result, so callers can
    branch on ``error_type`` and show ``message`` to the user.
    """

    # Which category of failure occurred.
    error_type: ErrorType

    # Human-readable explanation of the failure.
    message: str
34+
35+
36+
# Outcome of a pipeline run: a GraphGenerationResult on success,
# or a GenerationError describing why the run failed.
PipelineResult = Union[GraphGenerationResult, GenerationError]
37+
38+
39+
@dataclass
class GraphGenerationPipeline:
    """Generate a dialogue graph for a topic and validate it step by step.

    The pipeline generates a graph with ``graph_generator``, then checks the
    number of cycles, samples dialogues and checks utterance coverage,
    validates the theme with an LLM, and finally validates the transitions,
    asking the generation model to repair the graph when some are invalid.

    Every failure is reported as a ``GenerationError`` value instead of an
    exception; see ``generate_and_validate``.
    """

    generation_model: BaseChatModel
    validation_model: BaseChatModel
    graph_generator: CycleGraphGenerator
    generation_prompt: PromptTemplate
    repair_prompt: PromptTemplate
    # NOTE(review): these class-level defaults (2 and 3) differ from the
    # defaults of the explicit __init__ below (0 and 2). @dataclass does not
    # replace an explicitly defined __init__, so instances get the __init__
    # defaults. Kept as-is for compatibility; confirm which pair is intended.
    min_cycles: int = 2
    max_fix_attempts: int = 3

    def __init__(
        self,
        generation_model: BaseChatModel,
        validation_model: BaseChatModel,
        generation_prompt: Optional[PromptTemplate] = None,
        repair_prompt: Optional[PromptTemplate] = None,
        min_cycles: int = 0,
        max_fix_attempts: int = 2
    ):
        """Set up the pipeline.

        Args:
            generation_model: Model used to generate and to repair graphs.
            validation_model: Model used for transition and theme validation.
            generation_prompt: Prompt for the initial generation; falls back
                to ``cycle_graph_generation_prompt_enhanced`` when omitted.
            repair_prompt: Prompt for repair attempts; falls back to
                ``cycle_graph_repair_prompt`` when omitted.
            min_cycles: Minimum number of simple cycles the graph must have.
            max_fix_attempts: Maximum number of repair attempts.
        """
        self.generation_model = generation_model
        self.validation_model = validation_model
        self.graph_generator = CycleGraphGenerator()
        self.dialogue_sampler = get_full_dialogues

        self.generation_prompt = generation_prompt or cycle_graph_generation_prompt_enhanced
        self.repair_prompt = repair_prompt or cycle_graph_repair_prompt

        self.min_cycles = min_cycles
        self.max_fix_attempts = max_fix_attempts

    def validate_graph_cycle_requirement(
        self,
        graph: BaseGraph,
        min_cycles: int = 2
    ) -> Dict[str, Any]:
        """Check that the graph contains at least ``min_cycles`` simple cycles.

        Args:
            graph: Graph to inspect; its ``graph`` attribute is passed to
                ``networkx.simple_cycles``.
            min_cycles: Minimum number of cycles required.

        Returns:
            A dict with ``meets_requirements`` (bool), ``cycles`` (the list of
            cycles found) and ``cycles_count`` (int).

        Raises:
            Exception: Any error raised while enumerating the cycles is
                printed and re-raised unchanged.
        """
        print("\n🔍 Checking graph requirements...")

        try:
            cycles = list(nx.simple_cycles(graph.graph))
            cycles_count = len(cycles)

            print(f"🔄 Found {cycles_count} cycles in the graph:")
            for i, cycle in enumerate(cycles, 1):
                # Repeat the first node at the end so the printed path is closed.
                print(f"Cycle {i}: {' -> '.join(map(str, cycle + [cycle[0]]))}")

            meets_requirements = cycles_count >= min_cycles

            if not meets_requirements:
                print(f"❌ Graph doesn't meet cycle requirements (minimum {min_cycles} cycles needed)")
            else:
                print("✅ Graph meets cycle requirements")

            return {
                "meets_requirements": meets_requirements,
                "cycles": cycles,
                "cycles_count": cycles_count
            }

        except Exception as e:
            print(f"❌ Validation error: {str(e)}")
            raise

    def check_and_fix_transitions(
        self,
        graph: BaseGraph,
        max_attempts: int = 3
    ) -> Dict[str, Any]:
        """Validate the graph transitions and try to repair invalid ones via the LLM.

        The graph is validated with ``validation_model``. While it is invalid
        and attempts remain, ``generation_model`` is asked (with
        ``repair_prompt``) to repair the most recently validated graph, and
        the result is validated again.

        Args:
            graph: Graph to validate.
            max_attempts: Maximum number of repair attempts. With a value of
                zero or less no repair is attempted.

        Returns:
            A dict with ``is_valid`` (bool), ``graph`` (the last graph that
            was validated) and ``validation_details`` holding
            ``invalid_transitions``, ``attempts_made`` and ``fixed_count``.
        """
        print("Validating initial graph")

        initial_validation = graph_validation(graph, self.validation_model)
        if initial_validation["is_valid"]:
            return {
                "is_valid": True,
                "graph": graph,
                "validation_details": {
                    "invalid_transitions": [],
                    "attempts_made": 0,
                    "fixed_count": 0
                }
            }

        initial_invalid_count = len(initial_validation["invalid_transitions"])
        current_graph = graph
        # Always holds the validation report that belongs to current_graph, so
        # the final summary works even if no repair attempt ever completes.
        validation = initial_validation
        current_attempt = 0

        while current_attempt < max_attempts:
            print(f"\n🔄 Fix attempt {current_attempt + 1}/{max_attempts}")

            try:
                # Use generation_model to repair the graph. The transitions
                # passed in are those of the graph being repaired, not the
                # ones found in the very first validation.
                repaired_graph = self.graph_generator.invoke(
                    model=self.generation_model,
                    prompt=self.repair_prompt,
                    invalid_transitions=validation["invalid_transitions"],
                    graph_json=current_graph.graph_dict
                )

                # Check the repaired graph using validation_model.
                repaired_validation = graph_validation(repaired_graph, self.validation_model)
            except Exception as e:
                # Best effort: stop repairing and report the last state that
                # was successfully validated.
                print(f"⚠️ Error during fix attempt: {str(e)}")
                break

            # Commit graph and report together so they never get out of sync.
            current_graph = repaired_graph
            validation = repaired_validation

            if validation["is_valid"]:
                return {
                    "is_valid": True,
                    "graph": current_graph,
                    "validation_details": {
                        "invalid_transitions": [],
                        "attempts_made": current_attempt + 1,
                        "fixed_count": initial_invalid_count
                    }
                }

            print(f"⚠️ Found these {validation['invalid_transitions']} invalid transitions after fix attempt")
            current_attempt += 1

        remaining_invalid = len(validation["invalid_transitions"])

        return {
            "is_valid": False,
            "graph": current_graph,
            "validation_details": {
                "invalid_transitions": validation["invalid_transitions"],
                "attempts_made": current_attempt,
                "fixed_count": initial_invalid_count - remaining_invalid
            }
        }

    def _sample_dialogues(self, graph: BaseGraph):
        """Sample dialogues from ``graph`` and print a short report."""
        print("Sampling dialogues...")
        # NOTE(review): 15 is the second positional argument of the sampler;
        # presumably a depth or count limit - confirm against get_full_dialogues.
        sampled_dialogues = self.dialogue_sampler(graph, 15)
        print(f"Sampled {len(sampled_dialogues)} dialogues")
        print(sampled_dialogues)
        return sampled_dialogues

    def generate_and_validate(self, topic: str) -> PipelineResult:
        """Generate a dialogue graph for ``topic`` and validate it.

        Args:
            topic: Topic the generated dialogue graph should cover.

        Returns:
            A ``GraphGenerationResult`` with the final graph, the topic and
            dialogues sampled from that graph when every check passes;
            otherwise a ``GenerationError`` describing the first failed check.
            Unexpected exceptions are converted to a ``GenerationError`` with
            ``ErrorType.GENERATION_FAILED``.
        """
        try:
            # 1. Generate initial graph
            print("Generating Graph ...")
            graph = self.graph_generator.invoke(
                model=self.generation_model,
                prompt=self.generation_prompt,
                topic=topic
            )

            # 2. Validate cycles
            cycle_validation = self.validate_graph_cycle_requirement(graph, self.min_cycles)
            if not cycle_validation["meets_requirements"]:
                # NOTE(review): this branch means too FEW cycles, but ErrorType
                # has no such member; TOO_MANY_CYCLES is kept because callers
                # may already match on it.
                return GenerationError(
                    error_type=ErrorType.TOO_MANY_CYCLES,
                    message=f"Graph requires minimum {self.min_cycles} cycles, found {cycle_validation['cycles_count']}"
                )

            # 3. Generate and validate dialogues
            sampled_dialogues = self._sample_dialogues(graph)
            if not all_utterances_present(graph, sampled_dialogues):
                return GenerationError(
                    error_type=ErrorType.SAMPLING_FAILED,
                    message="Failed to sample valid dialogues - not all utterances are present"
                )

            # 4. Validate theme
            theme_validation = is_theme_valid(graph, self.validation_model, topic)
            if not theme_validation["value"]:
                return GenerationError(
                    error_type=ErrorType.INVALID_THEME,
                    message=f"Theme validation failed: {theme_validation['description']}"
                )

            # 5. Validate and fix transitions
            print("Validating and fixing transitions...")
            transition_validation = self.check_and_fix_transitions(
                graph=graph,
                max_attempts=self.max_fix_attempts
            )

            if not transition_validation["is_valid"]:
                invalid_transitions = transition_validation["validation_details"]["invalid_transitions"]
                return GenerationError(
                    error_type=ErrorType.INVALID_GRAPH_STRUCTURE,
                    message=f"Found {len(invalid_transitions)} invalid transitions after {transition_validation['validation_details']['attempts_made']} fix attempts"
                )

            final_graph = transition_validation["graph"]
            if final_graph is not graph:
                # The graph was repaired, so the dialogues sampled in step 3
                # describe a graph that is no longer returned. Sample again
                # from the final graph to keep the result self-consistent.
                sampled_dialogues = self._sample_dialogues(final_graph)
                if not all_utterances_present(final_graph, sampled_dialogues):
                    return GenerationError(
                        error_type=ErrorType.SAMPLING_FAILED,
                        message="Failed to sample valid dialogues - not all utterances are present"
                    )

            # All validations passed - return successful result
            return GraphGenerationResult(
                graph=final_graph.graph_dict,
                topic=topic,
                dialogues=sampled_dialogues
            )

        except Exception as e:
            # Top-level boundary: callers get an error value, never an exception.
            return GenerationError(
                error_type=ErrorType.GENERATION_FAILED,
                message=f"Unexpected error during generation: {str(e)}"
            )

    def __call__(self, topic: str) -> PipelineResult:
        """Shorthand for generate_and_validate"""
        return self.generate_and_validate(topic)

dev_packages/chatsky_llm_autoconfig/chatsky_llm_autoconfig/algorithms/data_generation.py

Whitespace-only changes.

0 commit comments

Comments
 (0)