1+ from enum import Enum
2+ from typing import Optional , Dict , Any , Union
13from dataclasses import dataclass
2- from typing import Optional , Dict , Any
4+
5+ from pydantic import BaseModel , Field
36import networkx as nx
4- from langchain_core . language_models . chat_models import BaseChatModel
7+
58from langchain .prompts import PromptTemplate
9+ from langchain_core .output_parsers .pydantic import PydanticOutputParser
10+ from langchain_core .language_models .chat_models import BaseChatModel
611
7- # from chatsky_llm_autoconfig.algorithms.topic_graph_generation import CycleGraphGenerator
8- # from chatsky_llm_autoconfig.algorithms.dialogue_generation import RecursiveDialogueSampler
9- # from chatsky_llm_autoconfig.metrics.automatic_metrics import all_utterances_present
10- # from chatsky_llm_autoconfig.metrics.llm_metrics import are_triples_valid, is_theme_valid
11- # from chatsky_llm_autoconfig.graph import BaseGraph
12- # from chatsky_llm_autoconfig.prompts import cycle_graph_generation_prompt_enhanced, cycle_graph_repair_prompt
13- from openai import BaseModel
12+ from dialogue2graph .pipelines .core .dialogue_sampling import RecursiveDialogueSampler
13+ from dialogue2graph .metrics .automatic_metrics import all_utterances_present
14+ from dialogue2graph .metrics .llm_metrics import are_triplets_valid , is_theme_valid
15+ from dialogue2graph .pipelines .core .graph import BaseGraph , Graph
16+ from dialogue2graph .pipelines .core .algorithms import TopicGraphGenerator
17+ from dialogue2graph .pipelines .core .schemas import GraphGenerationResult , DialogueGraph
1418
15- from enum import Enum
16- from typing import Union
17-
18- from dialogue2graph .pipelines .core .schemas import GraphGenerationResult
19+ from .prompts import cycle_graph_generation_prompt_enhanced , cycle_graph_repair_prompt
1920
2021
2122class ErrorType (str , Enum ):
@@ -38,70 +39,83 @@ class GenerationError(BaseModel):
3839PipelineResult = Union [GraphGenerationResult , GenerationError ]
3940
4041
class CycleGraphGenerator(TopicGraphGenerator):
    """Generator specifically for topic-based cyclic graphs.

    The prompt is piped into the chat model and the reply is parsed into a
    ``DialogueGraph``, which is then wrapped in ``Graph``.
    """

    def __init__(self):
        super().__init__()

    def _build_chain(self, model: BaseChatModel, prompt: PromptTemplate):
        """Compose ``prompt | model | parser`` using a ``DialogueGraph`` output parser."""
        parser = PydanticOutputParser(pydantic_object=DialogueGraph)
        return prompt | model | parser

    def invoke(self, model: BaseChatModel, prompt: PromptTemplate, **kwargs) -> BaseGraph:
        """Generate a cyclic dialogue graph based on the topic input.

        Args:
            model: Chat model that produces the graph.
            prompt: Prompt template fed to the model.
            **kwargs: Passed as the chain input (the prompt's variables).

        Returns:
            The parsed model output wrapped in ``Graph``.
        """
        chain = self._build_chain(model, prompt)
        return Graph(chain.invoke(kwargs))

    async def ainvoke(self, model: BaseChatModel, prompt: PromptTemplate, **kwargs) -> BaseGraph:
        """Async version of :meth:`invoke`.

        Builds the same chain as :meth:`invoke` and awaits its ``ainvoke``
        instead of silently returning ``None``.
        """
        chain = self._build_chain(model, prompt)
        return Graph(await chain.ainvoke(kwargs))

    def evaluate(self, *args, report_type="dict", **kwargs):
        """Placeholder: evaluation is not implemented here and returns ``None``."""
        pass
62+
63+
class GenerationPipeline(BaseModel):
    """Pipeline that generates a dialogue graph for a topic and validates it.

    This is a pydantic model, so the stdlib ``@dataclass`` decorator is
    intentionally not applied: pydantic already handles field collection
    and validation.
    """

    # Model used to generate and repair graphs.
    generation_model: BaseChatModel
    # Model used by the validation checks.
    validation_model: BaseChatModel
    graph_generator: CycleGraphGenerator = Field(default_factory=CycleGraphGenerator)
    generation_prompt: PromptTemplate = Field(default_factory=lambda: cycle_graph_generation_prompt_enhanced)
    repair_prompt: PromptTemplate = Field(default_factory=lambda: cycle_graph_repair_prompt)
    min_cycles: int = 2
    max_fix_attempts: int = 3
    dialogue_sampler: RecursiveDialogueSampler = Field(default_factory=RecursiveDialogueSampler)

    def __init__(
        self,
        generation_model: BaseChatModel,
        validation_model: BaseChatModel,
        generation_prompt: Optional[PromptTemplate] = None,
        repair_prompt: Optional[PromptTemplate] = None,
        min_cycles: int = 0,
        max_fix_attempts: int = 2,
    ):
        """Create the pipeline.

        Args:
            generation_model: Chat model used for graph generation and repair.
            validation_model: Chat model used for validation.
            generation_prompt: Prompt for graph generation; when ``None`` the
                field's default factory supplies
                ``cycle_graph_generation_prompt_enhanced``.
            repair_prompt: Prompt for graph repair; when ``None`` the field's
                default factory supplies ``cycle_graph_repair_prompt``.
            min_cycles: Minimum number of cycles the graph must contain.
            max_fix_attempts: Maximum number of repair attempts.
        """
        # NOTE(review): the defaults here (0 and 2) differ from the field
        # defaults above (2 and 3) and always win because they are forwarded
        # explicitly. Kept as-is to preserve behavior; confirm which is intended.
        #
        # The prompt fields are typed as non-optional PromptTemplate, so a None
        # must not be forwarded: leaving the key out lets default_factory apply.
        candidate_prompts = {
            "generation_prompt": generation_prompt,
            "repair_prompt": repair_prompt,
        }
        prompt_overrides = {name: value for name, value in candidate_prompts.items() if value is not None}
        super().__init__(
            generation_model=generation_model,
            validation_model=validation_model,
            min_cycles=min_cycles,
            max_fix_attempts=max_fix_attempts,
            **prompt_overrides,
        )
7092
def validate_graph_cycle_requirement(self, graph: BaseGraph, min_cycles: int = 2) -> Dict[str, Any]:
    """Count the simple cycles of ``graph.graph`` and compare against ``min_cycles``.

    Prints every cycle found and whether the requirement is met.

    Returns:
        A dict with ``meets_requirements`` (bool), ``cycles`` (the list of
        cycles from ``nx.simple_cycles``) and ``cycles_count`` (int).

    Raises:
        Exception: any error raised during the check is printed and re-raised.
    """
    print("\n🔍 Checking graph requirements...")
    try:
        found_cycles = list(nx.simple_cycles(graph.graph))
        total = len(found_cycles)
        print(f"🔄 Found {total} cycles in the graph:")
        for number, nodes in enumerate(found_cycles, start=1):
            # Repeat the first node at the end so the printed path reads as a closed loop.
            closed_path = [str(node) for node in nodes + [nodes[0]]]
            print(f"Cycle {number}: {' -> '.join(closed_path)}")

        enough_cycles = total >= min_cycles
        if enough_cycles:
            print("✅ Graph meets cycle requirements")
        else:
            print(f"❌ Graph doesn't meet cycle requirements (minimum {min_cycles} cycles needed)")
        return {
            "meets_requirements": enough_cycles,
            "cycles": found_cycles,
            "cycles_count": total,
        }

    except Exception as exc:
        print(f"❌ Validation error: {str(exc)}")
        raise
97114
98115 def check_and_fix_transitions (self , graph : BaseGraph , max_attempts : int = 3 ) -> Dict [str , Any ]:
99- """
100- Проверяет переходы в графе и пытается исправить невалидные через LLM
101- """
116+ """Checks transitions in the graph and attempts to fix invalid ones via LLM"""
102117 print ("Validating initial graph" )
103-
104- initial_validation = are_triples_valid (graph , self .validation_model , return_type = "detailed" )
118+ initial_validation = are_triplets_valid (graph , self .validation_model , return_type = "detailed" )
105119 if initial_validation ["is_valid" ]:
106120 return {"is_valid" : True , "graph" : graph , "validation_details" : {"invalid_transitions" : [], "attempts_made" : 0 , "fixed_count" : 0 }}
107121
@@ -111,18 +125,15 @@ def check_and_fix_transitions(self, graph: BaseGraph, max_attempts: int = 3) ->
111125
112126 while current_attempt < max_attempts :
113127 print (f"\n 🔄 Fix attempt { current_attempt + 1 } /{ max_attempts } " )
114-
115128 try :
116- # Используем generation_model для исправления графа
117129 current_graph = self .graph_generator .invoke (
118130 model = self .generation_model ,
119131 prompt = self .repair_prompt ,
120132 invalid_transitions = initial_validation ["invalid_transitions" ],
121133 graph_json = current_graph .graph_dict ,
122134 )
123135
124- # Проверяем исправленный граф используя validation_model
125- validation = are_triples_valid (current_graph , self .validation_model , return_type = "detailed" )
136+ validation = are_triplets_valid (current_graph , self .validation_model , return_type = "detailed" )
126137 if validation ["is_valid" ]:
127138 return {
128139 "is_valid" : True ,
@@ -139,7 +150,6 @@ def check_and_fix_transitions(self, graph: BaseGraph, max_attempts: int = 3) ->
139150 current_attempt += 1
140151
141152 remaining_invalid = len (validation ["invalid_transitions" ])
142-
143153 return {
144154 "is_valid" : False ,
145155 "graph" : current_graph ,
@@ -151,38 +161,30 @@ def check_and_fix_transitions(self, graph: BaseGraph, max_attempts: int = 3) ->
151161 }
152162
153163 def generate_and_validate (self , topic : str ) -> PipelineResult :
154- """
155- Generates and validates a dialogue graph for given topic
156- """
164+ """Generates and validates a dialogue graph for given topic"""
157165 try :
158- # 1. Generate initial graph
159166 print ("Generating Graph ..." )
160167 graph = self .graph_generator .invoke (model = self .generation_model , prompt = self .generation_prompt , topic = topic )
161168
162- # 2. Validate cycles
163169 cycle_validation = self .validate_graph_cycle_requirement (graph , self .min_cycles )
164170 if not cycle_validation ["meets_requirements" ]:
165171 return GenerationError (
166172 error_type = ErrorType .TOO_MANY_CYCLES ,
167173 message = f"Graph requires minimum { self .min_cycles } cycles, found { cycle_validation ['cycles_count' ]} " ,
168174 )
169175
170- # 3. Generate and validate dialogues
171176 print ("Sampling dialogues..." )
172177 sampled_dialogues = self .dialogue_sampler .invoke (graph , 15 )
173178 print (f"Sampled { len (sampled_dialogues )} dialogues" )
174- print (sampled_dialogues )
175179 if not all_utterances_present (graph , sampled_dialogues ):
176180 return GenerationError (
177181 error_type = ErrorType .SAMPLING_FAILED , message = "Failed to sample valid dialogues - not all utterances are present"
178182 )
179183
180- # 4. Validate theme
181184 theme_validation = is_theme_valid (graph , self .validation_model , topic )
182185 if not theme_validation ["value" ]:
183186 return GenerationError (error_type = ErrorType .INVALID_THEME , message = f"Theme validation failed: { theme_validation ['description' ]} " )
184187
185- # 5. Validate and fix transitions
186188 print ("Validating and fixing transitions..." )
187189 transition_validation = self .check_and_fix_transitions (graph = graph , max_attempts = self .max_fix_attempts )
188190
@@ -193,7 +195,6 @@ def generate_and_validate(self, topic: str) -> PipelineResult:
193195 message = f"Found { len (invalid_transitions )} invalid transitions after { transition_validation ['validation_details' ]['attempts_made' ]} fix attempts" ,
194196 )
195197
196- # All validations passed - return successful result
197198 return GraphGenerationResult (graph = transition_validation ["graph" ].graph_dict , topic = topic , dialogues = sampled_dialogues )
198199
199200 except Exception as e :
@@ -202,3 +203,47 @@ def generate_and_validate(self, topic: str) -> PipelineResult:
def __call__(self, topic: str) -> PipelineResult:
    """Make the pipeline callable: delegates to ``generate_and_validate(topic)``."""
    outcome = self.generate_and_validate(topic)
    return outcome
206+
207+
class LoopedGraphGenerator(TopicGraphGenerator):
    """Topic graph generator that runs a ``GenerationPipeline`` for a topic.

    The pipeline is built in ``__init__`` from the two models and the
    cycle-graph generation and repair prompts.
    """

    generation_model: BaseChatModel
    validation_model: BaseChatModel
    pipeline: GenerationPipeline

    def __init__(self, generation_model: BaseChatModel, validation_model: BaseChatModel):
        inner_pipeline = GenerationPipeline(
            generation_model=generation_model,
            validation_model=validation_model,
            generation_prompt=cycle_graph_generation_prompt_enhanced,
            repair_prompt=cycle_graph_repair_prompt,
        )
        super().__init__(
            generation_model=generation_model,
            validation_model=validation_model,
            pipeline=inner_pipeline,
        )

    def invoke(self, topic) -> list[dict]:
        """Run the pipeline for ``topic`` and collect the result.

        Returns:
            A list holding one dict (``graph``, ``topic``, ``dialogues``) when
            the pipeline returns a ``GraphGenerationResult``; an empty list
            when it returns an error or raises. Failures are only printed.
        """
        print(f"\n{'=' * 50}")
        print(f"Generating graph for topic: {topic}")
        print(f"{'=' * 50}")
        collected = []
        try:
            outcome = self.pipeline(topic)

            if not isinstance(outcome, GraphGenerationResult):
                print(f"❌ Failed to generate graph for {topic}")
                print(f"Error type: {outcome.error_type}")
                print(f"Error message: {outcome.message}")
            else:
                print(f"✅ Successfully generated graph for {topic}")
                record = {
                    "graph": outcome.graph.model_dump(),
                    "topic": outcome.topic,
                    "dialogues": [dialogue.model_dump() for dialogue in outcome.dialogues],
                }
                collected.append(record)

        except Exception as exc:
            print(f"❌ Unexpected error processing topic '{topic}': {str(exc)}")

        return collected

    def evaluate(self, *args, report_type="dict", **kwargs):
        """Delegate evaluation to the parent class unchanged."""
        return super().evaluate(*args, report_type=report_type, **kwargs)
0 commit comments