1+ from dataclasses import dataclass
2+ from typing import Optional , Dict , Any
3+ import networkx as nx
4+ from langchain_core .language_models .chat_models import BaseChatModel
5+ from langchain .prompts import PromptTemplate
6+ from chatsky_llm_autoconfig .algorithms .topic_graph_generation import CycleGraphGenerator
7+ from chatsky_llm_autoconfig .algorithms .dialogue_generation import RecursiveDialogueSampler
8+ from chatsky_llm_autoconfig .algorithms .experimental_sampler import get_dialogues , get_full_dialogues
9+ from chatsky_llm_autoconfig .metrics .automatic_metrics import all_utterances_present
10+ from chatsky_llm_autoconfig .metrics .llm_metrics import graph_validation , is_theme_valid
11+ from chatsky_llm_autoconfig .graph import BaseGraph
12+ from chatsky_llm_autoconfig .prompts import cycle_graph_generation_prompt_enhanced , cycle_graph_repair_prompt
13+ from openai import BaseModel
14+
15+ from enum import Enum
16+ from typing import Union
17+
18+ from chatsky_llm_autoconfig .schemas import GraphGenerationResult
19+
20+
class ErrorType(str, Enum):
    """Failure categories attached to a ``GenerationError``.

    The enum mixes in ``str``, so every member compares equal to its raw
    string value.
    """

    # Transition validation still reports invalid transitions.
    INVALID_GRAPH_STRUCTURE = "invalid_graph_structure"
    # NOTE(review): in this file the pipeline returns this member when the
    # graph has FEWER cycles than required; the name suggests the opposite.
    TOO_MANY_CYCLES = "too_many_cycles"
    # Sampled dialogues do not cover every utterance of the graph.
    SAMPLING_FAILED = "sampling_failed"
    # The theme check rejected the graph for the requested topic.
    INVALID_THEME = "invalid_theme"
    # Catch-all for unexpected exceptions during generation.
    GENERATION_FAILED = "generation_failed"
28+
29+
class GenerationError(BaseModel):
    """Error value returned by the pipeline instead of raising.

    NOTE(review): ``BaseModel`` is imported from ``openai`` in this file;
    presumably that is pydantic's ``BaseModel`` re-exported -- confirm, and
    consider importing it from its defining package.
    """

    # Category of the failure.
    error_type: ErrorType
    # Human-readable description of what went wrong.
    message: str
34+
35+
# Outcome of one pipeline run: either a successful generation result or an
# error value describing why generation/validation failed.
PipelineResult = Union[GraphGenerationResult, GenerationError]
37+
38+
@dataclass
class GraphGenerationPipeline:
    """Generate a dialogue graph for a topic and validate it step by step.

    ``generate_and_validate`` generates a graph with ``generation_model``,
    checks its cycle count, samples dialogues, validates the theme with
    ``validation_model`` and finally validates the transitions, asking
    ``generation_model`` to repair the graph when some are invalid.  That
    method reports failures as ``GenerationError`` values instead of raising.

    NOTE(review): the class is decorated with ``@dataclass`` but defines its
    own ``__init__``.  ``dataclass`` does not replace an explicitly defined
    ``__init__``, so instances get the constructor defaults
    (``min_cycles=0``, ``max_fix_attempts=2``), not the class-level field
    defaults below (2 and 3).  Confirm which pair is intended.
    """

    generation_model: BaseChatModel
    validation_model: BaseChatModel
    graph_generator: CycleGraphGenerator
    generation_prompt: PromptTemplate
    repair_prompt: PromptTemplate
    min_cycles: int = 2
    max_fix_attempts: int = 3

    def __init__(
        self,
        generation_model: BaseChatModel,
        validation_model: BaseChatModel,
        generation_prompt: Optional[PromptTemplate] = None,
        repair_prompt: Optional[PromptTemplate] = None,
        min_cycles: int = 0,
        max_fix_attempts: int = 2
    ):
        """Configure the pipeline.

        Args:
            generation_model: Model used to generate and to repair graphs.
            validation_model: Model used for transition and theme validation.
            generation_prompt: Prompt for the initial generation; falls back
                to ``cycle_graph_generation_prompt_enhanced`` when falsy.
            repair_prompt: Prompt for repair attempts; falls back to
                ``cycle_graph_repair_prompt`` when falsy.
            min_cycles: Minimum number of simple cycles a graph must contain.
            max_fix_attempts: Maximum number of repair attempts for invalid
                transitions.
        """
        self.generation_model = generation_model
        self.validation_model = validation_model
        self.graph_generator = CycleGraphGenerator()
        self.dialogue_sampler = get_full_dialogues

        self.generation_prompt = generation_prompt or cycle_graph_generation_prompt_enhanced
        self.repair_prompt = repair_prompt or cycle_graph_repair_prompt

        self.min_cycles = min_cycles
        self.max_fix_attempts = max_fix_attempts

    def validate_graph_cycle_requirement(
        self,
        graph: BaseGraph,
        min_cycles: int = 2
    ) -> Dict[str, Any]:
        """Check that the graph contains at least ``min_cycles`` simple cycles.

        Args:
            graph: Graph whose ``graph`` attribute is handed to
                ``networkx.simple_cycles``.
            min_cycles: Minimum number of cycles required.

        Returns:
            A dict with ``meets_requirements`` (bool), ``cycles`` (list of
            the cycles found) and ``cycles_count`` (int).

        Raises:
            Exception: Any error raised while enumerating or printing the
                cycles is printed and then re-raised unchanged.
        """
        print("\n 🔍 Checking graph requirements...")

        try:
            cycles = list(nx.simple_cycles(graph.graph))
            cycles_count = len(cycles)

            print(f"🔄 Found {cycles_count} cycles in the graph:")
            for i, cycle in enumerate(cycles, 1):
                # Repeat the first node at the end so the printed path closes.
                print(f"Cycle {i}: {' -> '.join(map(str, cycle + [cycle[0]]))}")

            meets_requirements = cycles_count >= min_cycles

            if not meets_requirements:
                print(f"❌ Graph doesn't meet cycle requirements (minimum {min_cycles} cycles needed)")
            else:
                print("✅ Graph meets cycle requirements")

            return {
                "meets_requirements": meets_requirements,
                "cycles": cycles,
                "cycles_count": cycles_count
            }

        except Exception as e:
            print(f"❌ Validation error: {str(e)}")
            raise

    def check_and_fix_transitions(
        self,
        graph: BaseGraph,
        max_attempts: int = 3
    ) -> Dict[str, Any]:
        """Validate the graph's transitions and try to repair invalid ones via the LLM.

        The graph is validated with ``validation_model``.  While it is
        invalid, ``generation_model`` is asked (with ``repair_prompt``) to
        repair it, up to ``max_attempts`` times; each repaired graph is
        validated again.  An exception during an attempt stops the loop and
        the last successfully validated state is returned.

        Args:
            graph: Graph to validate and, if needed, repair.
            max_attempts: Maximum number of repair attempts.

        Returns:
            A dict with ``is_valid`` (bool), ``graph`` (the last graph that
            was validated) and ``validation_details`` containing
            ``invalid_transitions``, ``attempts_made`` (completed attempts;
            an attempt that raised is not counted) and ``fixed_count``
            (initial invalid count minus the remaining invalid count).
        """
        print("Validating initial graph")

        initial_validation = graph_validation(graph, self.validation_model)
        if initial_validation["is_valid"]:
            return {
                "is_valid": True,
                "graph": graph,
                "validation_details": {
                    "invalid_transitions": [],
                    "attempts_made": 0,
                    "fixed_count": 0
                }
            }

        initial_invalid_count = len(initial_validation["invalid_transitions"])
        current_graph = graph
        # Seed with the initial report so the fall-through return below is
        # well defined even if no repair attempt completes (first attempt
        # raises, or max_attempts <= 0).
        validation = initial_validation
        current_attempt = 0

        while current_attempt < max_attempts:
            print(f"\n 🔄 Fix attempt {current_attempt + 1}/{max_attempts}")

            try:
                # Repair with generation_model, describing the problems of
                # the graph actually being repaired (latest report, not the
                # initial one).
                repaired_graph = self.graph_generator.invoke(
                    model=self.generation_model,
                    prompt=self.repair_prompt,
                    invalid_transitions=validation["invalid_transitions"],
                    graph_json=current_graph.graph_dict
                )

                # Re-check the repaired graph with validation_model.
                repaired_validation = graph_validation(repaired_graph, self.validation_model)

                # Commit only once both calls succeeded, so the returned
                # graph and the returned report always describe the same graph.
                current_graph, validation = repaired_graph, repaired_validation

                if validation["is_valid"]:
                    return {
                        "is_valid": True,
                        "graph": current_graph,
                        "validation_details": {
                            "invalid_transitions": [],
                            "attempts_made": current_attempt + 1,
                            "fixed_count": initial_invalid_count
                        }
                    }
                else:
                    print(f"⚠️ Found these {validation['invalid_transitions']} invalid transitions after fix attempt")

            except Exception as e:
                # Best effort: give up on further attempts and report the
                # last validated state.
                print(f"⚠️ Error during fix attempt: {str(e)}")
                break

            current_attempt += 1

        remaining_invalid = len(validation["invalid_transitions"])

        return {
            "is_valid": False,
            "graph": current_graph,
            "validation_details": {
                "invalid_transitions": validation["invalid_transitions"],
                "attempts_made": current_attempt,
                "fixed_count": initial_invalid_count - remaining_invalid
            }
        }

    def generate_and_validate(self, topic: str) -> PipelineResult:
        """Generate a dialogue graph for ``topic`` and run all validations.

        Steps: generate the graph, check the cycle requirement, sample
        dialogues and check utterance coverage, validate the theme, then
        validate (and if necessary repair) the transitions.

        Args:
            topic: Topic the dialogue graph should cover.

        Returns:
            ``GraphGenerationResult`` when every check passes, otherwise a
            ``GenerationError`` describing the first failing step.  Any
            unexpected exception is converted into a ``GenerationError``
            with ``ErrorType.GENERATION_FAILED``.
        """
        try:
            # 1. Generate initial graph
            print("Generating Graph ...")
            graph = self.graph_generator.invoke(
                model=self.generation_model,
                prompt=self.generation_prompt,
                topic=topic
            )

            # 2. Validate cycles
            cycle_validation = self.validate_graph_cycle_requirement(graph, self.min_cycles)
            if not cycle_validation["meets_requirements"]:
                # NOTE(review): the failure here is "too few cycles"; the
                # TOO_MANY_CYCLES member is kept because callers may match on it.
                return GenerationError(
                    error_type=ErrorType.TOO_MANY_CYCLES,
                    message=f"Graph requires minimum {self.min_cycles} cycles, found {cycle_validation['cycles_count']}"
                )

            # 3. Generate and validate dialogues
            print("Sampling dialogues...")
            # NOTE(review): 15 is passed straight to the sampler; presumably
            # an upper bound on sampled dialogues -- confirm against
            # get_full_dialogues.
            sampled_dialogues = self.dialogue_sampler(graph, 15)
            print(f"Sampled {len(sampled_dialogues)} dialogues")
            print(sampled_dialogues)
            if not all_utterances_present(graph, sampled_dialogues):
                return GenerationError(
                    error_type=ErrorType.SAMPLING_FAILED,
                    message="Failed to sample valid dialogues - not all utterances are present"
                )

            # 4. Validate theme
            theme_validation = is_theme_valid(graph, self.validation_model, topic)
            if not theme_validation["value"]:
                return GenerationError(
                    error_type=ErrorType.INVALID_THEME,
                    message=f"Theme validation failed: {theme_validation['description']}"
                )

            # 5. Validate and fix transitions
            print("Validating and fixing transitions...")
            transition_validation = self.check_and_fix_transitions(
                graph=graph,
                max_attempts=self.max_fix_attempts
            )

            if not transition_validation["is_valid"]:
                invalid_transitions = transition_validation["validation_details"]["invalid_transitions"]
                return GenerationError(
                    error_type=ErrorType.INVALID_GRAPH_STRUCTURE,
                    message=f"Found {len(invalid_transitions)} invalid transitions after {transition_validation['validation_details']['attempts_made']} fix attempts"
                )

            # All validations passed - return successful result.
            # NOTE(review): the dialogues were sampled from the graph as
            # generated in step 1, while the returned graph may be a repaired
            # version from step 5 -- confirm this is acceptable.
            return GraphGenerationResult(
                graph=transition_validation["graph"].graph_dict,
                topic=topic,
                dialogues=sampled_dialogues
            )

        except Exception as e:
            # Top-level boundary: callers receive an error value, never an exception.
            return GenerationError(
                error_type=ErrorType.GENERATION_FAILED,
                message=f"Unexpected error during generation: {str(e)}"
            )

    def __call__(self, topic: str) -> PipelineResult:
        """Shorthand for ``generate_and_validate``."""
        return self.generate_and_validate(topic)
0 commit comments