Skip to content

Commit 8d6a9ea

Browse files
authored
feat: changed error handling in the statement extractor (#61)
1 parent 411410e commit 8d6a9ea

File tree

3 files changed

+39
-24
lines changed

3 files changed

+39
-24
lines changed

src/aidial_rag_eval/generation/inference.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def segment_hypotheses(
224224

225225

226226
def extract_statements(
227-
hypotheses_segments: List[List[HypothesisSegment]],
227+
list_of_hypothesis_segments: List[List[HypothesisSegment]],
228228
llm: BaseChatModel,
229229
max_concurrency: int = 8,
230230
show_progress_bar: bool = True,
@@ -237,7 +237,7 @@ def extract_statements(
237237
Parameters
238238
-----------
239239
240-
hypotheses_segments : List[List[HypothesisSegment]]
240+
list_of_hypothesis_segments : List[List[HypothesisSegment]]
241241
Nested list of hypothesis segments.
242242
243243
llm : BaseChatModel
@@ -264,7 +264,7 @@ def extract_statements(
264264
if show_progress_bar:
265265
print("Extracting statements...")
266266
statements = extractor.extract(
267-
hypotheses_segments,
267+
list_of_hypothesis_segments,
268268
show_progress_bar,
269269
)
270270
return statements
@@ -404,7 +404,7 @@ def calculate_batch_inference(
404404
show_progress_bar=show_progress_bar,
405405
)
406406
statements: List[List[List[Statement]]] = extract_statements(
407-
hypotheses_segments=[
407+
list_of_hypothesis_segments=[
408408
segmented_hypothesis.segments
409409
for segmented_hypothesis in segmented_hypotheses
410410
],

src/aidial_rag_eval/generation/models/statement_extractor/base_statement_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class StatementExtractor(ABC):
1414
@abstractmethod
def extract(
    self,
    list_of_hypothesis_segments: List[List[HypothesisSegment]],
    show_progress_bar: bool,
) -> List[List[List[Statement]]]:
    """
    Extract statements from each hypothesis segment.

    Parameters
    -----------
    list_of_hypothesis_segments : List[List[HypothesisSegment]]
        Nested list of hypothesis segments to extract statements from.

    show_progress_bar : bool
        Whether to display a progress bar during extraction.

    Returns
    ------------
    List[List[List[Statement]]]
        For each hypothesis segment of each inner list, the list of
        extracted statements.
    """
    pass

src/aidial_rag_eval/generation/models/statement_extractor/llm_statement_extractor.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Dict, List
22

33
from langchain_core.language_models import BaseChatModel
4-
from langchain_core.runnables import RunnableSerializable, chain
4+
from langchain_core.runnables import RunnablePassthrough, RunnableSerializable, chain
55

66
from aidial_rag_eval.generation.models.lambdas import json_to_list
77
from aidial_rag_eval.generation.models.statement_extractor.base_statement_extractor import (
@@ -16,40 +16,51 @@
1616

1717
@chain
def list_to_statements(
    llm_outputs_with_inputs: Dict,
) -> List[List[str]]:
    """
    Function is part of a chain that extracts statements from the LLM output.

    Parameters
    -----------
    llm_outputs_with_inputs : Dict
        Passed inputs with the output list of dicts from the LLM with
        extracted statements. Expected keys:
        "llm_output_statements" (LLM output) and
        "hypothesis_segments" (original inputs, used as fallback).

    Returns
    ------------
    List[List[str]]
        The extracted statements if the LLM output is valid;
        otherwise, the original list of hypothesis_segments is returned,
        with each hypothesis_segment wrapped in a single-element list.
    """
    # Bind the fallback input outside the try so the except-branch cannot
    # itself raise KeyError when "llm_output_statements" is the missing key.
    hypothesis_segments = llm_outputs_with_inputs["hypothesis_segments"]
    try:
        statements_for_hypothesis_segments = llm_outputs_with_inputs[
            "llm_output_statements"
        ]
        # Explicit raise (not `assert`) so the length check survives
        # `python -O`, which strips assert statements; AssertionError is
        # caught below and triggers the fallback.
        if len(hypothesis_segments) != len(statements_for_hypothesis_segments):
            raise AssertionError(
                "Number of extracted statement lists does not match "
                "the number of hypothesis segments"
            )
        return [
            return_dict["statements"]
            for return_dict in statements_for_hypothesis_segments
        ]
    except (
        TypeError,
        KeyError,
        AssertionError,
    ):
        # Malformed LLM output: fall back to treating each original
        # hypothesis segment as its own single statement.
        return [
            [hypothesis_segment] for hypothesis_segment in hypothesis_segments
        ]
4455

4556

4657
@chain
def wrap_hypotheses(input_: Dict) -> Dict:
    """
    Function is part of a chain; wraps each hypothesis segment in
    numbered <hypothesisN> ... </hypothesisN> tags for the prompt.

    Parameters
    -----------
    input_ : Dict
        Expects the key "hypothesis_segments" with an iterable of segments.

    Returns
    ------------
    Dict
        A dict with the key "hypotheses" holding the tagged strings.
    """
    # isinstance instead of an exact type() check: dict subclasses
    # (e.g. OrderedDict) are equally valid inputs here.
    assert isinstance(input_, dict)
    return {
        "hypotheses": [
            f"<hypothesis{index + 1}> {hypothesis_segment} </hypothesis{index + 1}>"
            for index, hypothesis_segment in enumerate(input_["hypothesis_segments"])
        ],
    }
5566

@@ -76,17 +87,19 @@ def __init__(
7687
):
7788

7889
self._chain = (
79-
wrap_hypotheses
80-
| statement_prompt
81-
| model
82-
| json_to_list
90+
RunnablePassthrough.assign(
91+
llm_output_statements=wrap_hypotheses
92+
| statement_prompt
93+
| model
94+
| json_to_list
95+
)
8396
| list_to_statements
8497
)
8598
self.max_concurrency = max_concurrency
8699

87100
def extract(
88101
self,
89-
hypothesis_segments: List[List[HypothesisSegment]],
102+
list_of_hypothesis_segments: List[List[HypothesisSegment]],
90103
show_progress_bar: bool,
91104
) -> List[List[List[Statement]]]:
92105
"""
@@ -95,7 +108,7 @@ def extract(
95108
96109
Parameters
97110
-----------
98-
hypothesis_segments : List[List[HypothesisSegment]]
111+
list_of_hypothesis_segments : List[List[HypothesisSegment]]
99112
A list of hypothesis segments as a sources of statements.
100113
101114
show_progress_bar : bool
@@ -107,13 +120,15 @@ def extract(
107120
Returns the statements for each hypothesis segment.
108121
"""
109122

110-
with ProgressBarCallback(len(hypothesis_segments), show_progress_bar) as cb:
123+
with ProgressBarCallback(
124+
len(list_of_hypothesis_segments), show_progress_bar
125+
) as cb:
111126
returns = self._chain.batch(
112127
[
113128
{
114-
"hypotheses": hypotheses,
129+
"hypothesis_segments": hypothesis_segments,
115130
}
116-
for hypotheses in hypothesis_segments
131+
for hypothesis_segments in list_of_hypothesis_segments
117132
],
118133
config={"callbacks": [cb], "max_concurrency": self.max_concurrency},
119134
)

0 commit comments

Comments
 (0)