Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/aidial_rag_eval/dataframe/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def calculate_generation_metrics(
metric_binds: List[MetricBind],
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> pd.DataFrame:
"""
Calculates RAG evaluation generation metrics from df_merged dataframe.
Expand Down Expand Up @@ -155,6 +156,7 @@ def calculate_generation_metrics(
llm=llm,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
).to_dict(orient="series")
)
df_metrics = pd.DataFrame(data=metric_results)
Expand All @@ -178,6 +180,7 @@ def create_generation_metrics_report(
metric_binds: List[MetricBind],
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> pd.DataFrame:
"""
Calculates RAG evaluation generation metrics from input dataframes.
Expand Down Expand Up @@ -220,6 +223,7 @@ def create_generation_metrics_report(
metric_binds,
max_concurrency,
show_progress_bar,
auto_download_nltk,
)
return pd.merge(df_merged, df_metrics, left_index=True, right_index=True)

Expand All @@ -232,6 +236,7 @@ def create_rag_eval_metrics_report(
metric_binds: Optional[List[MetricBind]] = None,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> pd.DataFrame:
"""
Calculates RAG evaluation metrics from input dataframes.
Expand Down Expand Up @@ -282,5 +287,6 @@ def create_rag_eval_metrics_report(
metric_binds,
max_concurrency,
show_progress_bar,
auto_download_nltk,
)
return pd.concat([df_merged, retrieval_metrics, generation_metrics], axis=1)
2 changes: 2 additions & 0 deletions src/aidial_rag_eval/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def evaluate(
metric_binds: Optional[List[MetricBind]] = None,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = False,
) -> Dataset:
"""
Calculates RAG evaluation metrics from input
Expand Down Expand Up @@ -82,6 +83,7 @@ def evaluate(
metric_binds=metric_binds,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)
aggregated_metrics = df_final.mean(numeric_only=True)
assert isinstance(aggregated_metrics, pd.Series)
Expand Down
15 changes: 13 additions & 2 deletions src/aidial_rag_eval/generation/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def segment_hypotheses(
llm: BaseChatModel,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> List[SegmentedText]:
"""
Function that segments hypotheses into hypothesis segments(roughly into
Expand Down Expand Up @@ -215,7 +216,8 @@ def segment_hypotheses(
)

segmented_hypotheses = [
SegmentedText.from_text(text=hypothesis) for hypothesis in hypotheses
SegmentedText.from_text(text=hypothesis, auto_download_nltk=auto_download_nltk)
for hypothesis in hypotheses
]
if show_progress_bar:
print("Converting hypothesis...")
Expand Down Expand Up @@ -280,6 +282,7 @@ def infer_statements(
list_documents: Optional[List[Documents]] = None,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> List[List[Tuple[InferenceInputs, InferenceScore]]]:
"""
Function that infers statements.
Expand Down Expand Up @@ -330,7 +333,10 @@ def infer_statements(
document_names = [_join_documents(docs) for docs in list_documents]
if questions is not None:
segmented_questions = [
SegmentedText.from_text(text=question) for question in questions
SegmentedText.from_text(
text=question, auto_download_nltk=auto_download_nltk
)
for question in questions
]
premises = [
question_split.segments[-1] + "\n" + premise
Expand Down Expand Up @@ -362,6 +368,7 @@ def calculate_batch_inference(
list_documents: Optional[List[Documents]] = None,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> List[InferenceReturn]:
"""
Calculates pairwise the inference of a hypotheses from a premises.
Expand Down Expand Up @@ -404,6 +411,7 @@ def calculate_batch_inference(
llm=llm,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)
statements: List[List[List[Statement]]] = extract_statements(
list_of_hypothesis_segments=[
Expand All @@ -423,6 +431,7 @@ def calculate_batch_inference(
list_documents=list_documents,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)
)

Expand Down Expand Up @@ -459,6 +468,7 @@ def calculate_inference(
documents: Optional[Documents] = None,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> InferenceReturn:
"""
Calculates the inference of a hypothesis from a premise.
Expand Down Expand Up @@ -504,5 +514,6 @@ def calculate_inference(
list_documents=list_documents,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)
return inference_returns[0]
19 changes: 14 additions & 5 deletions src/aidial_rag_eval/generation/metric_binds.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _wrapped_dataframe_inference(
document_column: Optional[str] = None,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> pd.DataFrame:
inference_returns = calculate_batch_inference(
premises=_get_column_as_list_str(df_merged, premise_column),
Expand All @@ -49,6 +50,7 @@ def _wrapped_dataframe_inference(
),
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)
return pd.DataFrame(
[vars(inference_return) for inference_return in inference_returns]
Expand All @@ -62,20 +64,22 @@ def _wrapped_dataframe_refusal(
prefix: str,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> pd.DataFrame:
refusal_returns = calculate_batch_refusal(
answers=_get_column_as_list_str(df_merged, answer_column),
llm=llm,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)
return pd.DataFrame([vars(refusal) for refusal in refusal_returns]).add_prefix(
prefix
)


def context_to_answer_inference(
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
) -> pd.DataFrame:
return _wrapped_dataframe_inference(
df_merged=df_merged,
Expand All @@ -93,11 +97,12 @@ def context_to_answer_inference(
document_column=MergedColumns.DOCUMENTS,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)


def answer_to_ground_truth_inference(
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
) -> pd.DataFrame:
return _wrapped_dataframe_inference(
df_merged=df_merged,
Expand All @@ -109,11 +114,12 @@ def answer_to_ground_truth_inference(
document_column=MergedColumns.DOCUMENTS,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)


def ground_truth_to_answer_inference(
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
) -> pd.DataFrame:
return _wrapped_dataframe_inference(
df_merged=df_merged,
Expand All @@ -125,11 +131,12 @@ def ground_truth_to_answer_inference(
document_column=MergedColumns.DOCUMENTS,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)


def answer_refusal(
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
) -> pd.DataFrame:
return _wrapped_dataframe_refusal(
df_merged=df_merged,
Expand All @@ -138,11 +145,12 @@ def answer_refusal(
prefix=ANSWER_REFUSAL_PREFIX,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)


def ground_truth_refusal(
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
) -> pd.DataFrame:
return _wrapped_dataframe_refusal(
df_merged=df_merged,
Expand All @@ -151,6 +159,7 @@ def ground_truth_refusal(
prefix=GT_ANSWER_REFUSAL_PREFIX,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)


Expand Down
8 changes: 7 additions & 1 deletion src/aidial_rag_eval/generation/refusal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def calculate_batch_refusal(
llm: BaseChatModel,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> List[RefusalReturn]:
"""
Checks if the answers are answer refusal.
Expand All @@ -40,7 +41,10 @@ def calculate_batch_refusal(
Returns the list of the answer refusals.
"""
detector = LLMRefusalDetector(llm, max_concurrency)
answers_split = [SegmentedText.from_text(text=answer) for answer in answers]
answers_split = [
SegmentedText.from_text(text=answer, auto_download_nltk=auto_download_nltk)
for answer in answers
]
# As a heuristic, we send only the first 3 segments in the prompt.
# We believe that if there are 3 whole segments with information
# that is not related to refusal to answer,
Expand All @@ -60,6 +64,7 @@ def calculate_refusal(
llm: BaseChatModel,
max_concurrency: int = 8,
show_progress_bar: bool = True,
auto_download_nltk: bool = True,
) -> RefusalReturn:
"""
Checks if the answer is answer refusal.
Expand Down Expand Up @@ -89,5 +94,6 @@ def calculate_refusal(
llm=llm,
max_concurrency=max_concurrency,
show_progress_bar=show_progress_bar,
auto_download_nltk=auto_download_nltk,
)
return refusal_returns[0]
5 changes: 3 additions & 2 deletions src/aidial_rag_eval/generation/utils/segmented_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ def __init__(
self.delimiters = delimiters.copy()

@classmethod
def from_text(cls, text: Text) -> "SegmentedText":
nltk.download("punkt_tab", quiet=True)
def from_text(cls, text: Text, auto_download_nltk: bool = True) -> "SegmentedText":
if auto_download_nltk:
nltk.download("punkt_tab", quiet=True)

max_len = 500
min_len = 10
Expand Down
Loading