@@ -32,6 +32,7 @@ def _wrapped_dataframe_inference(
3232 document_column : Optional [str ] = None ,
3333 max_concurrency : int = 8 ,
3434 show_progress_bar : bool = True ,
35+ auto_download_nltk : bool = True ,
3536) -> pd .DataFrame :
3637 inference_returns = calculate_batch_inference (
3738 premises = _get_column_as_list_str (df_merged , premise_column ),
@@ -49,6 +50,7 @@ def _wrapped_dataframe_inference(
4950 ),
5051 max_concurrency = max_concurrency ,
5152 show_progress_bar = show_progress_bar ,
53+ auto_download_nltk = auto_download_nltk ,
5254 )
5355 return pd .DataFrame (
5456 [vars (inference_return ) for inference_return in inference_returns ]
@@ -62,20 +64,22 @@ def _wrapped_dataframe_refusal(
6264 prefix : str ,
6365 max_concurrency : int = 8 ,
6466 show_progress_bar : bool = True ,
67+ auto_download_nltk : bool = True ,
6568) -> pd .DataFrame :
6669 refusal_returns = calculate_batch_refusal (
6770 answers = _get_column_as_list_str (df_merged , answer_column ),
6871 llm = llm ,
6972 max_concurrency = max_concurrency ,
7073 show_progress_bar = show_progress_bar ,
74+ auto_download_nltk = auto_download_nltk ,
7175 )
7276 return pd .DataFrame ([vars (refusal ) for refusal in refusal_returns ]).add_prefix (
7377 prefix
7478 )
7579
7680
7781def context_to_answer_inference (
78- df_merged , llm , max_concurrency , show_progress_bar , ** kwargs
82+ df_merged , llm , max_concurrency , show_progress_bar , auto_download_nltk , ** kwargs
7983) -> pd .DataFrame :
8084 return _wrapped_dataframe_inference (
8185 df_merged = df_merged ,
@@ -93,11 +97,12 @@ def context_to_answer_inference(
9397 document_column = MergedColumns .DOCUMENTS ,
9498 max_concurrency = max_concurrency ,
9599 show_progress_bar = show_progress_bar ,
100+ auto_download_nltk = auto_download_nltk ,
96101 )
97102
98103
99104def answer_to_ground_truth_inference (
100- df_merged , llm , max_concurrency , show_progress_bar , ** kwargs
105+ df_merged , llm , max_concurrency , show_progress_bar , auto_download_nltk , ** kwargs
101106) -> pd .DataFrame :
102107 return _wrapped_dataframe_inference (
103108 df_merged = df_merged ,
@@ -109,11 +114,12 @@ def answer_to_ground_truth_inference(
109114 document_column = MergedColumns .DOCUMENTS ,
110115 max_concurrency = max_concurrency ,
111116 show_progress_bar = show_progress_bar ,
117+ auto_download_nltk = auto_download_nltk ,
112118 )
113119
114120
115121def ground_truth_to_answer_inference (
116- df_merged , llm , max_concurrency , show_progress_bar , ** kwargs
122+ df_merged , llm , max_concurrency , show_progress_bar , auto_download_nltk , ** kwargs
117123) -> pd .DataFrame :
118124 return _wrapped_dataframe_inference (
119125 df_merged = df_merged ,
@@ -125,11 +131,12 @@ def ground_truth_to_answer_inference(
125131 document_column = MergedColumns .DOCUMENTS ,
126132 max_concurrency = max_concurrency ,
127133 show_progress_bar = show_progress_bar ,
134+ auto_download_nltk = auto_download_nltk ,
128135 )
129136
130137
131138def answer_refusal (
132- df_merged , llm , max_concurrency , show_progress_bar , ** kwargs
139+ df_merged , llm , max_concurrency , show_progress_bar , auto_download_nltk , ** kwargs
133140) -> pd .DataFrame :
134141 return _wrapped_dataframe_refusal (
135142 df_merged = df_merged ,
@@ -138,11 +145,12 @@ def answer_refusal(
138145 prefix = ANSWER_REFUSAL_PREFIX ,
139146 max_concurrency = max_concurrency ,
140147 show_progress_bar = show_progress_bar ,
148+ auto_download_nltk = auto_download_nltk ,
141149 )
142150
143151
144152def ground_truth_refusal (
145- df_merged , llm , max_concurrency , show_progress_bar , ** kwargs
153+ df_merged , llm , max_concurrency , show_progress_bar , auto_download_nltk , ** kwargs
146154) -> pd .DataFrame :
147155 return _wrapped_dataframe_refusal (
148156 df_merged = df_merged ,
@@ -151,6 +159,7 @@ def ground_truth_refusal(
151159 prefix = GT_ANSWER_REFUSAL_PREFIX ,
152160 max_concurrency = max_concurrency ,
153161 show_progress_bar = show_progress_bar ,
162+ auto_download_nltk = auto_download_nltk ,
154163 )
155164
156165
0 commit comments