Skip to content

Commit 8ad7711

Browse files
authored
feat: decontextualization returns new objects without mutating inputs (#63)
1 parent e5d7852 commit 8ad7711

File tree

4 files changed

+88
-88
lines changed

4 files changed

+88
-88
lines changed

src/aidial_rag_eval/generation/inference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,10 @@ def segment_hypotheses(
219219
]
220220
if show_progress_bar:
221221
print("Converting hypothesis...")
222-
converter.transform_texts(segmented_hypotheses, show_progress_bar)
223-
return segmented_hypotheses
222+
decontextualized__segmented_hypotheses = converter.transform_texts(
223+
segmented_hypotheses, show_progress_bar
224+
)
225+
return decontextualized__segmented_hypotheses
224226

225227

226228
def extract_statements(

src/aidial_rag_eval/generation/models/converters/base_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ class SegmentConverter(ABC):
1414
@abstractmethod
1515
def transform_texts(
1616
self, segmented_texts: List[SegmentedText], show_progress_bar: bool
17-
):
17+
) -> List[SegmentedText]:
1818
pass

src/aidial_rag_eval/generation/models/converters/llm_decontextualization_converter.py

Lines changed: 60 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,28 @@
55
from langchain_core.exceptions import OutputParserException
66
from langchain_core.language_models import BaseChatModel
77
from langchain_core.messages import AIMessage
8-
from langchain_core.runnables import RunnableSerializable, chain
8+
from langchain_core.runnables import (
9+
RunnableBranch,
10+
RunnablePassthrough,
11+
RunnableSerializable,
12+
chain,
13+
)
914
from langchain_core.utils.json import parse_json_markdown
1015

1116
from aidial_rag_eval.generation.models.converters.base_converter import SegmentConverter
1217
from aidial_rag_eval.generation.models.converters.decontextualization_template import (
1318
decontextualization_prompt,
1419
)
15-
from aidial_rag_eval.generation.types import TextSegment
1620
from aidial_rag_eval.generation.utils.progress_bar import ProgressBarCallback
1721
from aidial_rag_eval.generation.utils.segmented_text import SegmentedText
1822

1923

24+
@chain
def check_if_sentences_less_than_2(input_: Dict) -> bool:
    """
    Tell whether the SegmentedText in the input has fewer than two segments.

    Parameters
    -----------
    input_ : Dict
        Mapping that holds a SegmentedText under the "segmented_text" key.

    Returns
    -------
    bool
        True when the segmented text has zero or one segment.
    """
    assert type(input_) is dict
    segment_count = len(input_["segmented_text"].segments)
    return segment_count < 2
28+
29+
2030
@chain
2131
def json_to_dict_segments(input_: AIMessage) -> List[str]:
2232
"""
@@ -48,21 +58,40 @@ def json_to_dict_segments(input_: AIMessage) -> List[str]:
4858

4959

5060
@chain
def segmented_text_to_json_list(input_: Dict) -> Dict:
    """
    Serialize the segments of the input's SegmentedText into a JSON string.

    Parameters
    -----------
    input_ : Dict
        Mapping that holds a SegmentedText under the "segmented_text" key.

    Returns
    -------
    Dict
        A mapping with a single "sentences_str" key whose value is the
        JSON dump of the segments.
    """
    assert type(input_) is dict
    segments = input_["segmented_text"].segments
    sentences_json = json.dumps(segments)
    return {"sentences_str": sentences_json}
64+
65+
66+
@chain
def return_original_segmented_text(input_: Dict) -> SegmentedText:
    """
    Return the SegmentedText stored in the input without any conversion.

    Parameters
    -----------
    input_ : Dict
        Mapping that holds a SegmentedText under the "segmented_text" key.

    Returns
    -------
    SegmentedText
        The very object found under "segmented_text" (not a copy).
    """
    assert type(input_) is dict
    return input_["segmented_text"]
70+
71+
72+
@chain
def dict_segments_to_segmented_text(llm_outputs_with_inputs: Dict) -> SegmentedText:
    """
    Build a SegmentedText from the LLM output, falling back to the original.

    Parameters
    -----------
    llm_outputs_with_inputs : Dict
        Mapping with the original text under "segmented_text" and the
        converted segments under "decontextualized_segments".

    Returns
    -------
    SegmentedText
        A new SegmentedText made of the converted segments and the original
        delimiters. The original SegmentedText is returned instead when the
        converted segments are missing, are not sized, or differ in count
        from the original segments.
    """
    original_segmented_text: SegmentedText = llm_outputs_with_inputs["segmented_text"]
    try:
        decontextualized_segments = llm_outputs_with_inputs["decontextualized_segments"]
        # Compare explicitly rather than with `assert`: assert statements are
        # removed under `python -O`, which would let a mismatched LLM output
        # through and pair the wrong number of segments with the delimiters.
        if len(decontextualized_segments) != len(original_segmented_text.segments):
            return original_segmented_text
        return SegmentedText(
            decontextualized_segments, original_segmented_text.delimiters
        )
    except (TypeError, KeyError, AssertionError):
        # AssertionError is still caught so that an assertion raised while
        # constructing SegmentedText keeps falling back as it did before.
        return original_segmented_text
5483

5584

5685
class LLMNoPronounsConverter(SegmentConverter):
5786
"""
58-
The LLMNoPronounsBatchConverter is designed to replace pronouns
59-
in text segments using a LLM.
60-
61-
Input is a list of SegmentedText objects.
62-
If a SegmentedText object contains more than one segment,
63-
segments are sent in a prompt to the LLM.
64-
In a prompt, the first segment is used only for context,
65-
and pronoun replacement is performed only in the remaining segments.
87+
Converter that decontextualizes text segments using an LLM.
88+
89+
Takes a list of SegmentedText objects and processes each one:
90+
- If a SegmentedText has fewer than 2 segments, it is returned unchanged.
91+
- Otherwise, all segments are sent to the LLM for decontextualization.
92+
93+
The LLM replaces pronouns and context-dependent references
94+
to make each segment self-contained.
6695
"""
6796

6897
_chain: RunnableSerializable
@@ -79,24 +108,24 @@ def __init__(
79108
model: BaseChatModel,
80109
max_concurrency: int,
81110
):
82-
83-
self._chain = (
84-
sentences_to_json_list
85-
| decontextualization_prompt
86-
| model
87-
| json_to_dict_segments
111+
self._chain = RunnableBranch(
112+
(check_if_sentences_less_than_2, return_original_segmented_text),
113+
RunnablePassthrough.assign(
114+
decontextualized_segments=segmented_text_to_json_list
115+
| decontextualization_prompt
116+
| model
117+
| json_to_dict_segments
118+
)
119+
| dict_segments_to_segmented_text,
88120
)
89121
self.max_concurrency = max_concurrency
90122

91123
def transform_texts(
92124
self, segmented_texts: List[SegmentedText], show_progress_bar: bool
93-
):
125+
) -> List[SegmentedText]:
94126
"""
95127
Method that converts segmented texts by replacing pronouns using an LLM.
96-
The LLM processes segments,
97-
where the additional first segment is not converted
98-
but is provided for context to enable the conversion of the second sentence.
99-
The LLM returns converted segments.
128+
The LLM processes all segments and returns converted segments.
100129
If the invariant of the length of input and output segment batches
101130
is not maintained, the segments of this batch are not replaced.
102131
@@ -107,60 +136,20 @@ def transform_texts(
107136
108137
show_progress_bar : bool
109138
A flag that controls the display of a progress bar.
110-
"""
111-
original_segment_batches: List[List[TextSegment]] = []
112-
segment_ids: List[int] = []
113-
for text_id, segmented_text in enumerate(segmented_texts):
114-
segments = segmented_text.segments
115-
if len(segments) <= 1:
116-
continue
117-
original_segment_batches.append(segments)
118-
segment_ids.append(text_id)
119-
120-
no_pronouns_segment_batches = self._get_no_pronouns_segments(
121-
original_segment_batches, show_progress_bar
122-
)
123-
124-
for text_id, no_pronouns_segment_batch, original_segment_batch in zip(
125-
segment_ids, no_pronouns_segment_batches, original_segment_batches
126-
):
127-
if len(no_pronouns_segment_batch) != len(original_segment_batch):
128-
continue
129-
segmented_texts[text_id].replace_segments(
130-
no_pronouns_segment_batch[1:],
131-
1,
132-
)
133139
134-
def _get_no_pronouns_segments(
135-
self,
136-
original_segment_batches: List[List[TextSegment]],
137-
show_progress_bar: bool,
138-
) -> List[List[TextSegment]]:
139-
"""
140-
Method that calls _chain to replace pronouns.
141-
142-
Parameters
143-
-----------
144-
original_segment_batches : List[List[str]]
145-
Segments of texts.
146-
147-
show_progress_bar : bool
148-
A flag that controls the display of a progress bar.
149140
Returns
150-
------------
151-
List[List[str]]
152-
List of converted segments, divided into batches.
141+
-------
142+
List[SegmentedText]
143+
A list of segmented texts with decontextualized segments.
153144
"""
154-
with ProgressBarCallback(
155-
len(original_segment_batches), show_progress_bar
156-
) as cb:
157-
no_pronouns_segment_batches = self._chain.batch(
145+
with ProgressBarCallback(len(segmented_texts), show_progress_bar) as cb:
146+
decontextualized_segmented_texts = self._chain.batch(
158147
[
159148
{
160-
"sentences": batch,
149+
"segmented_text": segmented_text,
161150
}
162-
for batch in original_segment_batches
151+
for segmented_text in segmented_texts
163152
],
164153
config={"callbacks": [cb], "max_concurrency": self.max_concurrency},
165154
)
166-
return no_pronouns_segment_batches
155+
return decontextualized_segmented_texts

tests/chain_tests/test_decontextualization_chain.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,14 @@ def test_valid_json_response():
1818
segments=["John went to the store.", "He bought milk."], delimiters=[" "]
1919
)
2020

21-
converter.transform_texts([segmented_text], show_progress_bar=False)
21+
decontext_segmented_text = converter.transform_texts(
22+
[segmented_text], show_progress_bar=False
23+
)[0]
2224

23-
assert segmented_text.segments == ["John went to the store.", "John bought milk."]
25+
assert decontext_segmented_text.segments == [
26+
"John went to the store.",
27+
"John bought milk.",
28+
]
2429

2530

2631
def test_invalid_json_response():
@@ -30,11 +35,12 @@ def test_invalid_json_response():
3035
segmented_text = SegmentedText(
3136
segments=["John went to the store.", "He bought milk."], delimiters=[" "]
3237
)
33-
original_segments = segmented_text.segments.copy()
3438

35-
converter.transform_texts([segmented_text], show_progress_bar=False)
39+
decontext_segmented_text = converter.transform_texts(
40+
[segmented_text], show_progress_bar=False
41+
)[0]
3642

37-
assert segmented_text.segments == original_segments
43+
assert decontext_segmented_text.segments == segmented_text.segments
3844

3945

4046
def test_json_missing_segments_key():
@@ -46,11 +52,12 @@ def test_json_missing_segments_key():
4652
segmented_text = SegmentedText(
4753
segments=["John went to the store.", "He bought milk."], delimiters=[" "]
4854
)
49-
original_segments = segmented_text.segments.copy()
5055

51-
converter.transform_texts([segmented_text], show_progress_bar=False)
56+
decontext_segmented_text = converter.transform_texts(
57+
[segmented_text], show_progress_bar=False
58+
)[0]
5259

53-
assert segmented_text.segments == original_segments
60+
assert decontext_segmented_text.segments == segmented_text.segments
5461

5562

5663
def test_segment_count_mismatch():
@@ -60,11 +67,12 @@ def test_segment_count_mismatch():
6067
segmented_text = SegmentedText(
6168
segments=["John went to the store.", "He bought milk."], delimiters=[" "]
6269
)
63-
original_segments = segmented_text.segments.copy()
6470

65-
converter.transform_texts([segmented_text], show_progress_bar=False)
71+
decontext_segmented_text = converter.transform_texts(
72+
[segmented_text], show_progress_bar=False
73+
)[0]
6674

67-
assert segmented_text.segments == original_segments
75+
assert decontext_segmented_text.segments == segmented_text.segments
6876

6977

7078
def test_empty_response():
@@ -74,11 +82,12 @@ def test_empty_response():
7482
segmented_text = SegmentedText(
7583
segments=["John went to the store.", "He bought milk."], delimiters=[" "]
7684
)
77-
original_segments = segmented_text.segments.copy()
7885

79-
converter.transform_texts([segmented_text], show_progress_bar=False)
86+
decontext_segmented_text = converter.transform_texts(
87+
[segmented_text], show_progress_bar=False
88+
)[0]
8089

81-
assert segmented_text.segments == original_segments
90+
assert decontext_segmented_text.segments == segmented_text.segments
8291

8392

8493
def test_invoke_raises_exception():

0 commit comments

Comments
 (0)