-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrefusal.py
More file actions
99 lines (81 loc) · 2.85 KB
/
refusal.py
File metadata and controls
99 lines (81 loc) · 2.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from typing import List
from langchain_core.language_models import BaseChatModel
from aidial_rag_eval.generation.models.refusal_detectors.llm_refusal_detector import (
LLMRefusalDetector,
)
from aidial_rag_eval.generation.types import RefusalReturn
from aidial_rag_eval.generation.utils.segmented_text import SegmentedText
from aidial_rag_eval.types import Answer
def calculate_batch_refusal(
    answers: List[Answer],
    llm: BaseChatModel,
    max_concurrency: int = 8,
    show_progress_bar: bool = True,
    auto_download_nltk: bool = True,
    max_segments: int = 3,
) -> List[RefusalReturn]:
    """
    Checks whether each answer is a refusal to answer.

    Parameters
    -----------
    answers : List[str]
        The list of the answers.
    llm : BaseChatModel
        The Langchain chat model used for calculating inference.
    max_concurrency : int, default=8
        The maximum number of concurrent requests to the LLM.
    show_progress_bar : bool, default=True
        Whether to display a progress bar during LLM requests.
    auto_download_nltk : bool, default=True
        Whether to automatically download NLTK data needed for
        sentence segmentation.
    max_segments : int, default=3
        The number of leading segments of each answer that are sent
        to the LLM (see heuristic note below).

    Returns
    ------------
    List[RefusalReturn]
        Returns the list of the answer refusals, one per input answer.
    """
    detector = LLMRefusalDetector(llm, max_concurrency)
    answers_split = [
        SegmentedText.from_text(text=answer, auto_download_nltk=auto_download_nltk)
        for answer in answers
    ]
    # As a heuristic, we send only the first `max_segments` segments
    # in the prompt. We believe that if there are that many whole
    # segments with information that is not related to refusal to
    # answer, we will not consider such a response as a refusal to
    # answer in any case.
    first_answers_sentences = [
        segmented.get_joined_segments_by_range(0, max_segments)
        for segmented in answers_split
    ]
    if show_progress_bar:
        print("Getting refusal...")
    return detector.get_refusal(first_answers_sentences, show_progress_bar)
def calculate_refusal(
    answer: Answer,
    llm: BaseChatModel,
    max_concurrency: int = 8,
    show_progress_bar: bool = True,
    auto_download_nltk: bool = True,
) -> RefusalReturn:
    """
    Checks whether a single answer is a refusal to answer.

    Convenience wrapper around :func:`calculate_batch_refusal` for a
    one-element batch.

    Parameters
    -----------
    answer : str
        The text of the answer.
    llm : BaseChatModel
        The Langchain chat model used for calculating inference.
    max_concurrency : int, default=8
        The maximum number of concurrent requests to the LLM.
    show_progress_bar : bool, default=True
        Whether to display a progress bar during LLM requests.

    Returns
    ------------
    RefusalReturn
        Returns the answer refusal.
    """
    # Delegate to the batch implementation and unpack the single result.
    (single_result,) = calculate_batch_refusal(
        answers=[answer],
        llm=llm,
        max_concurrency=max_concurrency,
        show_progress_bar=show_progress_bar,
        auto_download_nltk=auto_download_nltk,
    )
    return single_result