88from dialogue2graph .pipelines .core .algorithms import DialogAugmentation
99from dialogue2graph .pipelines .core .dialogue import Dialogue
1010from dialogue2graph .pipelines .model_storage import ModelStorage
11- from dialogue2graph .metrics .no_llm_metrics .metrics import (
12- is_correct_length , match_roles
13- )
11+ from dialogue2graph .metrics .no_llm_metrics .metrics import is_correct_length , match_roles
1412
1513logging .getLogger ("langchain_core.vectorstores.base" ).setLevel (logging .ERROR )
1614
15+
1716class AugmentedTurn (BaseModel ):
1817 """Dialogue turn to augment"""
18+
1919 participant : str
20- text : list [str ] = Field (..., description = "List of utterance variations for this turn" )
20+ text : list [str ] = Field (
21+ ..., description = "List of utterance variations for this turn"
22+ )
23+
2124
2225class DialogueSequence (BaseModel ):
2326 """Result as dialogue sequence"""
27+
2428 result : list [AugmentedTurn ] = Field (..., description = "Sequence of augmented turns" )
2529
2630
2731class DialogueAugmenter (DialogAugmentation ):
2832 """Class for dialogue augmentation.
29-
33+
3034 Augments dialogues while preserving structure and conversation flow by rephrasing original dialogue lines."""
31-
35+
3236 model_storage : ModelStorage = Field (..., description = "Model storage instance" )
3337 generation_llm : str = Field (..., description = "Key for generation LLM in storage" )
3438 formatting_llm : str = Field (..., description = "Key for formatting LLM in storage" )
@@ -40,33 +44,34 @@ def invoke(
4044 topic : str = "" ,
4145 ) -> Union [list [Dialogue ], str ]:
4246 """Augment dialogue while preserving conversation structure.
43-
47+
4448 Args:
4549 dialogue: Input Dialogue object to augment
4650 prompt: Required augmentation prompt template
4751 topic: Contextual topic for augmentation (default: empty)
48-
52+
4953 Returns:
5054 List of augmented Dialogue objects or error message
5155 """
52- if prompt == '' :
53- return ' Preprocessing failed: prompt should be a valid instruction for LLM'
54-
56+ if prompt == "" :
57+ return " Preprocessing failed: prompt should be a valid instruction for LLM"
58+
5559 try :
5660 message_dicts = [msg .model_dump () for msg in dialogue .messages ]
5761 if message_dicts == []:
58- return ' Preprocessing failed: no messages found in the dialogue'
59-
62+ return " Preprocessing failed: no messages found in the dialogue"
63+
6064 augmentation_prompt = PromptTemplate .from_template (prompt )
6165 parser = JsonOutputParser (pydantic_object = DialogueSequence )
62-
66+
6367 fixed_parser = OutputFixingParser .from_llm (
64- parser = parser ,
65- llm = self ._get_llm (self .formatting_llm )
68+ parser = parser , llm = self ._get_llm (self .formatting_llm )
69+ )
70+
71+ chain = (
72+ augmentation_prompt | self ._get_llm (self .generation_llm ) | fixed_parser
6673 )
6774
68- chain = augmentation_prompt | self ._get_llm (self .generation_llm ) | fixed_parser
69-
7075 for attempt in range (3 ):
7176 try :
7277 result = chain .invoke ({"topic" : topic , "dialogue" : message_dicts })
@@ -76,58 +81,55 @@ def invoke(
7681 except Exception as e :
7782 logging .error (f"Error creating dialogues: { str (e )} " )
7883 return f"Post-processing failed: { str (e )} "
79-
84+
8085 except ValidationError as ve :
81- logging .warning (f"Validation error attempt { attempt + 1 } : { ve } " )
86+ logging .warning (f"Validation error attempt { attempt + 1 } : { ve } " )
8287
8388 except Exception as e :
8489 logging .error (f"Unexpected error: { str (e )} " )
8590 if attempt == 2 :
8691 return f"Augmentation failed: { str (e )} "
87-
92+
8893 return "Augmentation failed after 3 attempts"
89-
94+
9095 except Exception as e :
9196 logging .exception ("Critical error in augmentation pipeline" )
9297 return f"Critical error: { str (e )} "
9398
9499 async def ainvoke (self , * args , ** kwargs ):
95100 """Async version of invoke"""
96101 return self .invoke (* args , ** kwargs )
97-
98- async def evaluate (
99- self ,
100- dialogue : Dialogue ,
101- prompt : str ,
102- topic : str = ""
103- ) -> dict :
102+
103+ async def evaluate (self , dialogue : Dialogue , prompt : str , topic : str = "" ) -> dict :
104104 """Evaluate augmentation quality with dictionary report format."""
105105 result = self .invoke (dialogue , prompt , topic )
106-
106+
107107 if isinstance (result , str ):
108108 return {"error" : result }
109-
110- report = {}
109+
110+ report = {}
111111 for i , augmented_dialogue in enumerate (result ):
112- try :
113- report [f' augmented_dialogue_{ i } ' ] = {
112+ try :
113+ report [f" augmented_dialogue_{ i } " ] = {
114114 "match_roles" : match_roles (dialogue , augmented_dialogue ),
115- "correct_length" : is_correct_length (dialogue , augmented_dialogue )
115+ "correct_length" : is_correct_length (dialogue , augmented_dialogue ),
116116 }
117117 except Exception as e :
118- logging .error (f"Error while calculating metrics: { str (e )} " )
118+ logging .error (f"Error while calculating metrics: { str (e )} " )
119119 return report
120120
121121 def _get_llm (self , llm_key : str ):
122122 """Get model from model storage safely"""
123123 if llm_key not in self .model_storage .storage :
124124 raise ValueError (f"LLM key '{ llm_key } ' not found in model storage" )
125125 return self .model_storage .storage [llm_key ].model
126-
127- def _combine_one_dialogue (self , augmentation_result : DialogueSequence , i : int ) -> dict :
126+
127+ def _combine_one_dialogue (
128+ self , augmentation_result : DialogueSequence , i : int
129+ ) -> dict :
128130 """Combine new augmented dialogues from utterance variations"""
129131 new_augmented_dialogue = {}
130- new_augmented_dialogue [' messages' ] = []
132+ new_augmented_dialogue [" messages" ] = []
131133 roles_to_add = [turn .participant for turn in augmentation_result .result ]
132134 utterances_to_add = [turn .text [i ] for turn in augmentation_result .result ]
133135
@@ -139,13 +141,13 @@ def _combine_one_dialogue(self, augmentation_result: DialogueSequence, i: int) -
139141
140142 return new_augmented_dialogue
141143
142- def _create_dialogues (self , result : dict ) -> list [Dialogue ]:
144+ def _create_dialogues (self , result : dict ) -> list [Dialogue ]:
143145 """Create a list of Dialogue objects"""
144146 try :
145147 augmentation_result = DialogueSequence (result = result )
146148 except Exception as e :
147149 logging .error (f"Wrong type of augmentation result: { str (e )} " )
148- return f"Creating a list of Dialogue objects failed: { str (e )} "
150+ return f"Creating a list of Dialogue objects failed: { str (e )} "
149151
150152 utterances_lists = [turn .text for turn in augmentation_result .result ]
151153 lens = [len (uttr_list ) for uttr_list in utterances_lists ]
@@ -154,5 +156,8 @@ def _create_dialogues(self, result: dict) -> list[Dialogue]:
154156 for i in range (min (lens )):
155157 new_augmented_dialogue = self ._combine_one_dialogue (augmentation_result , i )
156158 augmented_dialogues .append (new_augmented_dialogue )
157-
158- return [Dialogue .from_list (new_augmented_dialogue ['messages' ]) for new_augmented_dialogue in augmented_dialogues ]
159+
160+ return [
161+ Dialogue .from_list (new_augmented_dialogue ["messages" ])
162+ for new_augmented_dialogue in augmented_dialogues
163+ ]
0 commit comments