Skip to content

Commit 3a99a54

Browse files
authored
Add preprocessor to patch PromptGuard scores for inserted characters (meta-llama#636)
1 parent 1e638d6 commit 3a99a54

File tree

1 file changed

+54
-12
lines changed

1 file changed

+54
-12
lines changed

recipes/responsible_ai/prompt_guard/inference.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,45 @@ def load_model_and_tokenizer(model_name='meta-llama/Prompt-Guard-86M'):
3131
return model, tokenizer
3232

3333

34-
def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu'):
34+
def preprocess_text_for_promptguard(text: str, tokenizer) -> str:
35+
"""
36+
Preprocess the text by removing spaces that break apart larger tokens.
37+
This hotfixes a workaround to PromptGuard, where spaces can be inserted into a string
38+
to allow the string to be classified as benign.
39+
40+
Args:
41+
text (str): The input text to preprocess.
42+
tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
43+
44+
Returns:
45+
str: The preprocessed text.
46+
"""
47+
48+
try:
49+
cleaned_text = ''
50+
index_map = []
51+
for i, char in enumerate(text):
52+
if not char.isspace():
53+
cleaned_text += char
54+
index_map.append(i)
55+
tokens = tokenizer.tokenize(cleaned_text)
56+
result = []
57+
last_end = 0
58+
for token in tokens:
59+
token_str = tokenizer.convert_tokens_to_string([token])
60+
start = cleaned_text.index(token_str, last_end)
61+
end = start + len(token_str)
62+
original_start = index_map[start]
63+
if original_start > 0 and text[original_start - 1].isspace():
64+
result.append(' ')
65+
result.append(token_str)
66+
last_end = end
67+
return ''.join(result)
68+
except Exception:
69+
return text
70+
71+
72+
def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
3573
"""
3674
Evaluate the model on the given text with temperature-adjusted softmax.
3775
Note, as this is a DeBERTa model, the input text should have a maximum length of 512.
@@ -44,6 +82,8 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
4482
Returns:
4583
torch.Tensor: The probability of each class adjusted by the temperature.
4684
"""
85+
if preprocess:
86+
text = preprocess_text_for_promptguard(text, tokenizer)
4787
# Encode the text
4888
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
4989
inputs = inputs.to(device)
@@ -57,7 +97,7 @@ def get_class_probabilities(model, tokenizer, text, temperature=1.0, device='cpu
5797
return probabilities
5898

5999

60-
def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
100+
def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
61101
"""
62102
Evaluate the probability that a given string contains malicious jailbreak or prompt injection.
63103
Appropriate for filtering dialogue between a user and an LLM.
@@ -70,11 +110,11 @@ def get_jailbreak_score(model, tokenizer, text, temperature=1.0, device='cpu'):
70110
Returns:
71111
float: The probability of the text containing malicious content.
72112
"""
73-
probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
113+
probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
74114
return probabilities[0, 2].item()
75115

76116

77-
def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu'):
117+
def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device='cpu', preprocess=True):
78118
"""
79119
Evaluate the probability that a given string contains any embedded instructions (malicious or benign).
80120
Appropriate for filtering third party inputs (e.g. web searches, tool outputs) into an LLM.
@@ -87,11 +127,11 @@ def get_indirect_injection_score(model, tokenizer, text, temperature=1.0, device
87127
Returns:
88128
float: The combined probability of the text containing malicious or embedded instructions.
89129
"""
90-
probabilities = get_class_probabilities(model, tokenizer, text, temperature, device)
130+
probabilities = get_class_probabilities(model, tokenizer, text, temperature, device, preprocess)
91131
return (probabilities[0, 1] + probabilities[0, 2]).item()
92132

93133

94-
def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
134+
def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu', preprocess=True):
95135
"""
96136
Process a batch of texts and return their class probabilities.
97137
Args:
@@ -104,6 +144,8 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
104144
Returns:
105145
torch.Tensor: A tensor containing the class probabilities for each text in the batch.
106146
"""
147+
if preprocess:
148+
texts = [preprocess_text_for_promptguard(text, tokenizer) for text in texts]
107149
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
108150
inputs = inputs.to(device)
109151
with torch.no_grad():
@@ -113,7 +155,7 @@ def process_text_batch(model, tokenizer, texts, temperature=1.0, device='cpu'):
113155
return probabilities
114156

115157

116-
def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16):
158+
def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
117159
"""
118160
Compute scores for a list of texts, handling texts of arbitrary length by breaking them into chunks and processing in parallel.
119161
Args:
@@ -138,15 +180,15 @@ def get_scores_for_texts(model, tokenizer, texts, score_indices, temperature=1.0
138180
for i in range(0, len(all_chunks), max_batch_size):
139181
batch_chunks = all_chunks[i:i+max_batch_size]
140182
batch_indices = text_indices[i:i+max_batch_size]
141-
probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device)
183+
probabilities = process_text_batch(model, tokenizer, batch_chunks, temperature, device, preprocess)
142184
scores = probabilities[:, score_indices].sum(dim=1).tolist()
143185

144186
for idx, score in zip(batch_indices, scores):
145187
all_scores[idx] = max(all_scores[idx], score)
146188
return all_scores
147189

148190

149-
def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
191+
def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
150192
"""
151193
Compute jailbreak scores for a list of texts.
152194
Args:
@@ -160,10 +202,10 @@ def get_jailbreak_scores_for_texts(model, tokenizer, texts, temperature=1.0, dev
160202
Returns:
161203
list[float]: A list of jailbreak scores for each text.
162204
"""
163-
return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size)
205+
return get_scores_for_texts(model, tokenizer, texts, [2], temperature, device, max_batch_size, preprocess)
164206

165207

166-
def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16):
208+
def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature=1.0, device='cpu', max_batch_size=16, preprocess=True):
167209
"""
168210
Compute indirect injection scores for a list of texts.
169211
Args:
@@ -177,4 +219,4 @@ def get_indirect_injection_scores_for_texts(model, tokenizer, texts, temperature
177219
Returns:
178220
list[float]: A list of indirect injection scores for each text.
179221
"""
180-
return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size)
222+
return get_scores_for_texts(model, tokenizer, texts, [1, 2], temperature, device, max_batch_size, preprocess)

0 commit comments

Comments
 (0)