Skip to content

Commit bfe7df6

Browse files
authored
feat: add opt-in to skip NLTK downloads (#79)
1 parent 8ad7711 commit bfe7df6

File tree

6 files changed

+45
-10
lines changed

6 files changed

+45
-10
lines changed

src/aidial_rag_eval/dataframe/metrics.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def calculate_generation_metrics(
114114
metric_binds: List[MetricBind],
115115
max_concurrency: int = 8,
116116
show_progress_bar: bool = True,
117+
auto_download_nltk: bool = True,
117118
) -> pd.DataFrame:
118119
"""
119120
Calculates RAG evaluation generation metrics from df_merged dataframe.
@@ -155,6 +156,7 @@ def calculate_generation_metrics(
155156
llm=llm,
156157
max_concurrency=max_concurrency,
157158
show_progress_bar=show_progress_bar,
159+
auto_download_nltk=auto_download_nltk,
158160
).to_dict(orient="series")
159161
)
160162
df_metrics = pd.DataFrame(data=metric_results)
@@ -178,6 +180,7 @@ def create_generation_metrics_report(
178180
metric_binds: List[MetricBind],
179181
max_concurrency: int = 8,
180182
show_progress_bar: bool = True,
183+
auto_download_nltk: bool = True,
181184
) -> pd.DataFrame:
182185
"""
183186
Calculates RAG evaluation generation metrics from input dataframes.
@@ -220,6 +223,7 @@ def create_generation_metrics_report(
220223
metric_binds,
221224
max_concurrency,
222225
show_progress_bar,
226+
auto_download_nltk,
223227
)
224228
return pd.merge(df_merged, df_metrics, left_index=True, right_index=True)
225229

@@ -232,6 +236,7 @@ def create_rag_eval_metrics_report(
232236
metric_binds: Optional[List[MetricBind]] = None,
233237
max_concurrency: int = 8,
234238
show_progress_bar: bool = True,
239+
auto_download_nltk: bool = True,
235240
) -> pd.DataFrame:
236241
"""
237242
Calculates RAG evaluation metrics from input dataframes.
@@ -282,5 +287,6 @@ def create_rag_eval_metrics_report(
282287
metric_binds,
283288
max_concurrency,
284289
show_progress_bar,
290+
auto_download_nltk,
285291
)
286292
return pd.concat([df_merged, retrieval_metrics, generation_metrics], axis=1)

src/aidial_rag_eval/evaluate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def evaluate(
2121
metric_binds: Optional[List[MetricBind]] = None,
2222
max_concurrency: int = 8,
2323
show_progress_bar: bool = True,
24+
auto_download_nltk: bool = False,
2425
) -> Dataset:
2526
"""
2627
Calculates RAG evaluation metrics from input
@@ -82,6 +83,7 @@ def evaluate(
8283
metric_binds=metric_binds,
8384
max_concurrency=max_concurrency,
8485
show_progress_bar=show_progress_bar,
86+
auto_download_nltk=auto_download_nltk,
8587
)
8688
aggregated_metrics = df_final.mean(numeric_only=True)
8789
assert isinstance(aggregated_metrics, pd.Series)

src/aidial_rag_eval/generation/inference.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def segment_hypotheses(
184184
llm: BaseChatModel,
185185
max_concurrency: int = 8,
186186
show_progress_bar: bool = True,
187+
auto_download_nltk: bool = True,
187188
) -> List[SegmentedText]:
188189
"""
189190
Function that segments hypotheses into hypothesis segments (roughly into
@@ -215,7 +216,8 @@ def segment_hypotheses(
215216
)
216217

217218
segmented_hypotheses = [
218-
SegmentedText.from_text(text=hypothesis) for hypothesis in hypotheses
219+
SegmentedText.from_text(text=hypothesis, auto_download_nltk=auto_download_nltk)
220+
for hypothesis in hypotheses
219221
]
220222
if show_progress_bar:
221223
print("Converting hypothesis...")
@@ -280,6 +282,7 @@ def infer_statements(
280282
list_documents: Optional[List[Documents]] = None,
281283
max_concurrency: int = 8,
282284
show_progress_bar: bool = True,
285+
auto_download_nltk: bool = True,
283286
) -> List[List[Tuple[InferenceInputs, InferenceScore]]]:
284287
"""
285288
Function that infers statements.
@@ -330,7 +333,10 @@ def infer_statements(
330333
document_names = [_join_documents(docs) for docs in list_documents]
331334
if questions is not None:
332335
segmented_questions = [
333-
SegmentedText.from_text(text=question) for question in questions
336+
SegmentedText.from_text(
337+
text=question, auto_download_nltk=auto_download_nltk
338+
)
339+
for question in questions
334340
]
335341
premises = [
336342
question_split.segments[-1] + "\n" + premise
@@ -362,6 +368,7 @@ def calculate_batch_inference(
362368
list_documents: Optional[List[Documents]] = None,
363369
max_concurrency: int = 8,
364370
show_progress_bar: bool = True,
371+
auto_download_nltk: bool = True,
365372
) -> List[InferenceReturn]:
366373
"""
367374
Calculates pairwise the inference of a hypotheses from a premises.
@@ -404,6 +411,7 @@ def calculate_batch_inference(
404411
llm=llm,
405412
max_concurrency=max_concurrency,
406413
show_progress_bar=show_progress_bar,
414+
auto_download_nltk=auto_download_nltk,
407415
)
408416
statements: List[List[List[Statement]]] = extract_statements(
409417
list_of_hypothesis_segments=[
@@ -423,6 +431,7 @@ def calculate_batch_inference(
423431
list_documents=list_documents,
424432
max_concurrency=max_concurrency,
425433
show_progress_bar=show_progress_bar,
434+
auto_download_nltk=auto_download_nltk,
426435
)
427436
)
428437

@@ -459,6 +468,7 @@ def calculate_inference(
459468
documents: Optional[Documents] = None,
460469
max_concurrency: int = 8,
461470
show_progress_bar: bool = True,
471+
auto_download_nltk: bool = True,
462472
) -> InferenceReturn:
463473
"""
464474
Calculates the inference of a hypothesis from a premise.
@@ -504,5 +514,6 @@ def calculate_inference(
504514
list_documents=list_documents,
505515
max_concurrency=max_concurrency,
506516
show_progress_bar=show_progress_bar,
517+
auto_download_nltk=auto_download_nltk,
507518
)
508519
return inference_returns[0]

src/aidial_rag_eval/generation/metric_binds.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def _wrapped_dataframe_inference(
3232
document_column: Optional[str] = None,
3333
max_concurrency: int = 8,
3434
show_progress_bar: bool = True,
35+
auto_download_nltk: bool = True,
3536
) -> pd.DataFrame:
3637
inference_returns = calculate_batch_inference(
3738
premises=_get_column_as_list_str(df_merged, premise_column),
@@ -49,6 +50,7 @@ def _wrapped_dataframe_inference(
4950
),
5051
max_concurrency=max_concurrency,
5152
show_progress_bar=show_progress_bar,
53+
auto_download_nltk=auto_download_nltk,
5254
)
5355
return pd.DataFrame(
5456
[vars(inference_return) for inference_return in inference_returns]
@@ -62,20 +64,22 @@ def _wrapped_dataframe_refusal(
6264
prefix: str,
6365
max_concurrency: int = 8,
6466
show_progress_bar: bool = True,
67+
auto_download_nltk: bool = True,
6568
) -> pd.DataFrame:
6669
refusal_returns = calculate_batch_refusal(
6770
answers=_get_column_as_list_str(df_merged, answer_column),
6871
llm=llm,
6972
max_concurrency=max_concurrency,
7073
show_progress_bar=show_progress_bar,
74+
auto_download_nltk=auto_download_nltk,
7175
)
7276
return pd.DataFrame([vars(refusal) for refusal in refusal_returns]).add_prefix(
7377
prefix
7478
)
7579

7680

7781
def context_to_answer_inference(
78-
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
82+
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
7983
) -> pd.DataFrame:
8084
return _wrapped_dataframe_inference(
8185
df_merged=df_merged,
@@ -93,11 +97,12 @@ def context_to_answer_inference(
9397
document_column=MergedColumns.DOCUMENTS,
9498
max_concurrency=max_concurrency,
9599
show_progress_bar=show_progress_bar,
100+
auto_download_nltk=auto_download_nltk,
96101
)
97102

98103

99104
def answer_to_ground_truth_inference(
100-
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
105+
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
101106
) -> pd.DataFrame:
102107
return _wrapped_dataframe_inference(
103108
df_merged=df_merged,
@@ -109,11 +114,12 @@ def answer_to_ground_truth_inference(
109114
document_column=MergedColumns.DOCUMENTS,
110115
max_concurrency=max_concurrency,
111116
show_progress_bar=show_progress_bar,
117+
auto_download_nltk=auto_download_nltk,
112118
)
113119

114120

115121
def ground_truth_to_answer_inference(
116-
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
122+
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
117123
) -> pd.DataFrame:
118124
return _wrapped_dataframe_inference(
119125
df_merged=df_merged,
@@ -125,11 +131,12 @@ def ground_truth_to_answer_inference(
125131
document_column=MergedColumns.DOCUMENTS,
126132
max_concurrency=max_concurrency,
127133
show_progress_bar=show_progress_bar,
134+
auto_download_nltk=auto_download_nltk,
128135
)
129136

130137

131138
def answer_refusal(
132-
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
139+
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
133140
) -> pd.DataFrame:
134141
return _wrapped_dataframe_refusal(
135142
df_merged=df_merged,
@@ -138,11 +145,12 @@ def answer_refusal(
138145
prefix=ANSWER_REFUSAL_PREFIX,
139146
max_concurrency=max_concurrency,
140147
show_progress_bar=show_progress_bar,
148+
auto_download_nltk=auto_download_nltk,
141149
)
142150

143151

144152
def ground_truth_refusal(
145-
df_merged, llm, max_concurrency, show_progress_bar, **kwargs
153+
df_merged, llm, max_concurrency, show_progress_bar, auto_download_nltk, **kwargs
146154
) -> pd.DataFrame:
147155
return _wrapped_dataframe_refusal(
148156
df_merged=df_merged,
@@ -151,6 +159,7 @@ def ground_truth_refusal(
151159
prefix=GT_ANSWER_REFUSAL_PREFIX,
152160
max_concurrency=max_concurrency,
153161
show_progress_bar=show_progress_bar,
162+
auto_download_nltk=auto_download_nltk,
154163
)
155164

156165

src/aidial_rag_eval/generation/refusal.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def calculate_batch_refusal(
1515
llm: BaseChatModel,
1616
max_concurrency: int = 8,
1717
show_progress_bar: bool = True,
18+
auto_download_nltk: bool = True,
1819
) -> List[RefusalReturn]:
1920
"""
2021
Checks if the answers are answer refusal.
@@ -40,7 +41,10 @@ def calculate_batch_refusal(
4041
Returns the list of the answer refusals.
4142
"""
4243
detector = LLMRefusalDetector(llm, max_concurrency)
43-
answers_split = [SegmentedText.from_text(text=answer) for answer in answers]
44+
answers_split = [
45+
SegmentedText.from_text(text=answer, auto_download_nltk=auto_download_nltk)
46+
for answer in answers
47+
]
4448
# As a heuristic, we send only the first 3 segments in the prompt.
4549
# We believe that if there are 3 whole segments with information
4650
# that is not related to refusal to answer,
@@ -60,6 +64,7 @@ def calculate_refusal(
6064
llm: BaseChatModel,
6165
max_concurrency: int = 8,
6266
show_progress_bar: bool = True,
67+
auto_download_nltk: bool = True,
6368
) -> RefusalReturn:
6469
"""
6570
Checks if the answer is answer refusal.
@@ -89,5 +94,6 @@ def calculate_refusal(
8994
llm=llm,
9095
max_concurrency=max_concurrency,
9196
show_progress_bar=show_progress_bar,
97+
auto_download_nltk=auto_download_nltk,
9298
)
9399
return refusal_returns[0]

src/aidial_rag_eval/generation/utils/segmented_text.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def __init__(
8080
self.delimiters = delimiters.copy()
8181

8282
@classmethod
83-
def from_text(cls, text: Text) -> "SegmentedText":
84-
nltk.download("punkt_tab", quiet=True)
83+
def from_text(cls, text: Text, auto_download_nltk: bool = True) -> "SegmentedText":
84+
if auto_download_nltk:
85+
nltk.download("punkt_tab", quiet=True)
8586

8687
max_len = 500
8788
min_len = 10

0 commit comments

Comments
 (0)