11from typing import Dict , List
22
33from langchain_core .language_models import BaseChatModel
4- from langchain_core .runnables import RunnableSerializable , chain
4+ from langchain_core .runnables import RunnablePassthrough , RunnableSerializable , chain
55
66from aidial_rag_eval .generation .models .lambdas import json_to_list
77from aidial_rag_eval .generation .models .statement_extractor .base_statement_extractor import (
1616
1717@chain
1818def list_to_statements (
19- statements_for_each_hypothesis : List [ Dict [ str , List [ str ]]] ,
19+ llm_outputs_with_inputs : Dict ,
2020) -> List [List [str ]]:
2121 """
2222 Function is part of a chain that extracts segments from a list.
2323
2424 Parameters
2525 -----------
26- statements_for_each_hypothesis : List[ Dict[str, List[str]]]
27- The output list of dicts from the LLM with extracted statements.
26+ llm_outputs_with_inputs : Dict
27+ Passed inputs with output list of dicts from the LLM with extracted statements.
2828
2929 Returns
3030 ------------
31- List[str]
31+ List[List[ str] ]
3232 The extracted statements. if the LLM output is valid;
33- otherwise, an empty list is returned.
33+ otherwise, return an original list of hypothesis_segments,
34+ each hypothesis_segment is wrapped.
3435 """
3536 try :
37+ statements_for_hypothesis_segments = llm_outputs_with_inputs [
38+ "llm_output_statements"
39+ ]
40+ hypothesis_segments = llm_outputs_with_inputs ["hypothesis_segments" ]
41+ assert len (hypothesis_segments ) == len (statements_for_hypothesis_segments )
3642 return [
37- return_dict ["statements" ] for return_dict in statements_for_each_hypothesis
43+ return_dict ["statements" ]
44+ for return_dict in llm_outputs_with_inputs ["llm_output_statements" ]
3845 ]
3946 except (
4047 TypeError ,
4148 KeyError ,
49+ AssertionError ,
4250 ):
43- return []
51+ return [
52+ [hypothesis_segment ]
53+ for hypothesis_segment in llm_outputs_with_inputs ["hypothesis_segments" ]
54+ ]
4455
4556
4657@chain
4758def wrap_hypotheses (input_ : Dict ) -> Dict :
4859 assert type (input_ ) is dict
4960 return {
5061 "hypotheses" : [
51- f"<hypothesis{ index + 1 } > { hypothesis } </hypothesis{ index + 1 } >"
52- for index , hypothesis in enumerate (input_ ["hypotheses " ])
62+ f"<hypothesis{ index + 1 } > { hypothesis_segment } </hypothesis{ index + 1 } >"
63+ for index , hypothesis_segment in enumerate (input_ ["hypothesis_segments " ])
5364 ],
5465 }
5566
@@ -76,17 +87,19 @@ def __init__(
7687 ):
7788
7889 self ._chain = (
79- wrap_hypotheses
80- | statement_prompt
81- | model
82- | json_to_list
90+ RunnablePassthrough .assign (
91+ llm_output_statements = wrap_hypotheses
92+ | statement_prompt
93+ | model
94+ | json_to_list
95+ )
8396 | list_to_statements
8497 )
8598 self .max_concurrency = max_concurrency
8699
87100 def extract (
88101 self ,
89- hypothesis_segments : List [List [HypothesisSegment ]],
102+ list_of_hypothesis_segments : List [List [HypothesisSegment ]],
90103 show_progress_bar : bool ,
91104 ) -> List [List [List [Statement ]]]:
92105 """
@@ -95,7 +108,7 @@ def extract(
95108
96109 Parameters
97110 -----------
98- hypothesis_segments : List[List[HypothesisSegment]]
111+ list_of_hypothesis_segments : List[List[HypothesisSegment]]
99112 A list of hypothesis segments as a sources of statements.
100113
101114 show_progress_bar : bool
@@ -107,13 +120,15 @@ def extract(
107120 Returns the statements for each hypothesis segment.
108121 """
109122
110- with ProgressBarCallback (len (hypothesis_segments ), show_progress_bar ) as cb :
123+ with ProgressBarCallback (
124+ len (list_of_hypothesis_segments ), show_progress_bar
125+ ) as cb :
111126 returns = self ._chain .batch (
112127 [
113128 {
114- "hypotheses " : hypotheses ,
129+ "hypothesis_segments " : hypothesis_segments ,
115130 }
116- for hypotheses in hypothesis_segments
131+ for hypothesis_segments in list_of_hypothesis_segments
117132 ],
118133 config = {"callbacks" : [cb ], "max_concurrency" : self .max_concurrency },
119134 )
0 commit comments