55from langchain_core .exceptions import OutputParserException
66from langchain_core .language_models import BaseChatModel
77from langchain_core .messages import AIMessage
8- from langchain_core .runnables import RunnableSerializable , chain
8+ from langchain_core .runnables import (
9+ RunnableBranch ,
10+ RunnablePassthrough ,
11+ RunnableSerializable ,
12+ chain ,
13+ )
914from langchain_core .utils .json import parse_json_markdown
1015
1116from aidial_rag_eval .generation .models .converters .base_converter import SegmentConverter
1217from aidial_rag_eval .generation .models .converters .decontextualization_template import (
1318 decontextualization_prompt ,
1419)
15- from aidial_rag_eval .generation .types import TextSegment
1620from aidial_rag_eval .generation .utils .progress_bar import ProgressBarCallback
1721from aidial_rag_eval .generation .utils .segmented_text import SegmentedText
1822
1923
@chain
def check_if_sentences_less_than_2(input_: Dict) -> bool:
    """Branch predicate: True when the text has fewer than two segments.

    Used by the RunnableBranch in LLMNoPronounsConverter to skip the LLM
    call for texts that have nothing to decontextualize.

    Parameters
    ----------
    input_ : Dict
        Chain input; must contain a "segmented_text" key holding a
        SegmentedText whose ``segments`` is a sized collection.

    Returns
    -------
    bool
        True if the SegmentedText has fewer than two segments.
    """
    # isinstance (not `type(...) is dict`) so dict subclasses are accepted.
    # NOTE(review): sanity check only — asserts are stripped under `python -O`.
    assert isinstance(input_, dict)
    return len(input_["segmented_text"].segments) < 2
2030@chain
2131def json_to_dict_segments (input_ : AIMessage ) -> List [str ]:
2232 """
@@ -48,21 +58,40 @@ def json_to_dict_segments(input_: AIMessage) -> List[str]:
4858
4959
@chain
def segmented_text_to_json_list(input_: Dict) -> Dict:
    """Serialize the segments of the input SegmentedText to a JSON string.

    Produces the "sentences_str" prompt variable consumed by
    ``decontextualization_prompt`` further down the chain.

    Parameters
    ----------
    input_ : Dict
        Chain input; must contain a "segmented_text" key holding a
        SegmentedText whose ``segments`` is JSON-serializable.

    Returns
    -------
    Dict
        ``{"sentences_str": <JSON array of segments>}``.
    """
    # isinstance (not `type(...) is dict`) so dict subclasses are accepted.
    assert isinstance(input_, dict)
    return {"sentences_str": json.dumps(input_["segmented_text"].segments)}
@chain
def return_original_segmented_text(input_: Dict) -> SegmentedText:
    """Pass-through branch: return the untouched SegmentedText.

    Used by the RunnableBranch when a text has fewer than two segments
    and therefore no LLM call is made.

    Parameters
    ----------
    input_ : Dict
        Chain input; must contain a "segmented_text" key.

    Returns
    -------
    SegmentedText
        The original SegmentedText, unchanged.
    """
    # isinstance (not `type(...) is dict`) so dict subclasses are accepted.
    # Return annotation fixed: the value under "segmented_text" is a
    # SegmentedText, not a Dict.
    assert isinstance(input_, dict)
    return input_["segmented_text"]
@chain
def dict_segments_to_segmented_text(llm_outputs_with_inputs: Dict) -> SegmentedText:
    """Rebuild a SegmentedText from the LLM's decontextualized segments.

    Falls back to the original SegmentedText whenever the LLM output is
    missing, malformed, or does not preserve the number of segments —
    keeping the delimiters/segments invariant intact.

    Parameters
    ----------
    llm_outputs_with_inputs : Dict
        Must contain "segmented_text" (the original SegmentedText) and
        "decontextualized_segments" (the parsed LLM output, or None).

    Returns
    -------
    SegmentedText
        Either the rebuilt SegmentedText or the original on any failure.
    """
    original_segmented_text: SegmentedText = llm_outputs_with_inputs["segmented_text"]
    try:
        decontextualized_segments = llm_outputs_with_inputs["decontextualized_segments"]
        # Explicit length check instead of `assert`: asserts are stripped
        # under `python -O`, which would let a mismatched SegmentedText be
        # constructed silently.
        if decontextualized_segments is None or len(decontextualized_segments) != len(
            original_segmented_text.segments
        ):
            return original_segmented_text
        return SegmentedText(
            decontextualized_segments, original_segmented_text.delimiters
        )
    except (TypeError, KeyError):
        # Missing key or un-sized LLM output: best-effort fallback.
        return original_segmented_text
5584
class LLMNoPronounsConverter(SegmentConverter):
    """
    Converter that decontextualizes text segments using an LLM.

    Takes a list of SegmentedText objects and processes each one:
    - If a SegmentedText has fewer than 2 segments, it is returned unchanged.
    - Otherwise, all segments are sent to the LLM for decontextualization.

    The LLM replaces pronouns and context-dependent references
    to make each segment self-contained.
    """

    _chain: RunnableSerializable

    def __init__(
        self,
        model: BaseChatModel,
        max_concurrency: int,
    ):
        """Build the branching conversion chain.

        Parameters
        ----------
        model : BaseChatModel
            The chat model that performs the decontextualization.
        max_concurrency : int
            Maximum number of concurrent chain invocations in ``batch``.
        """
        # Sub-chain that asks the LLM to rewrite the segments.
        llm_pipeline = (
            segmented_text_to_json_list
            | decontextualization_prompt
            | model
            | json_to_dict_segments
        )
        # Attach the LLM output next to the original input, then merge them
        # back into a SegmentedText (falling back to the original on failure).
        decontextualize = (
            RunnablePassthrough.assign(decontextualized_segments=llm_pipeline)
            | dict_segments_to_segmented_text
        )
        # Short-circuit: texts with fewer than two segments skip the LLM.
        self._chain = RunnableBranch(
            (check_if_sentences_less_than_2, return_original_segmented_text),
            decontextualize,
        )
        self.max_concurrency = max_concurrency

    def transform_texts(
        self, segmented_texts: List[SegmentedText], show_progress_bar: bool
    ) -> List[SegmentedText]:
        """
        Convert segmented texts by replacing pronouns using an LLM.

        Each text is routed through the branch chain: texts with fewer than
        two segments come back untouched. If the invariant of the length of
        input and output segment batches is not maintained, the segments of
        that batch are not replaced.

        Parameters
        ----------
        segmented_texts : List[SegmentedText]
            The texts to decontextualize.

        show_progress_bar : bool
            A flag that controls the display of a progress bar.

        Returns
        -------
        List[SegmentedText]
            A list of segmented texts with decontextualized segments.
        """
        chain_inputs = [{"segmented_text": text} for text in segmented_texts]
        with ProgressBarCallback(len(segmented_texts), show_progress_bar) as cb:
            run_config = {
                "callbacks": [cb],
                "max_concurrency": self.max_concurrency,
            }
            converted = self._chain.batch(chain_inputs, config=run_config)
        return converted
0 commit comments