Skip to content

Commit e0f56f0

Browse files
committed
reformatted code
1 parent 1e103f3 commit e0f56f0

File tree

3 files changed

+21
-24
lines changed

3 files changed

+21
-24
lines changed

fastchat/serve/monitor/classify/category.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
import ast
1212
import re
1313

14-
from utils import (
15-
HuggingFaceRefusalClassifier
16-
)
14+
from utils import HuggingFaceRefusalClassifier
1715

1816

1917
class Category:
@@ -192,6 +190,9 @@ def __init__(self):
192190
def pre_process(self, conversation):
193191
conv = []
194192
for i in range(0, len(conversation), 2):
195-
args = {"QUERY": conversation[i]["content"], "RESPONSE": conversation[i+1]["content"]}
193+
args = {
194+
"QUERY": conversation[i]["content"],
195+
"RESPONSE": conversation[i + 1]["content"],
196+
}
196197
conv.append(self.prompt_template.format(**args))
197198
return conv

fastchat/serve/monitor/classify/label.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111

1212
from category import Category
1313

14-
from utils import (
15-
api_config,
16-
chat_completion_openai
17-
)
14+
from utils import api_config, chat_completion_openai
1815

1916
LOCK = threading.RLock()
2017

@@ -58,7 +55,6 @@ def get_answer(
5855
output_log = {}
5956

6057
for category in categories:
61-
6258
if category.name_tag == "refusal_v0.2":
6359
refusal_classifier = category.classifier
6460

@@ -69,10 +65,10 @@ def get_answer(
6965
batch_size = 16
7066
refusal_results = []
7167
for i in range(0, len(refusal_prompts), batch_size):
72-
batch_prompts = refusal_prompts[i:i + batch_size]
68+
batch_prompts = refusal_prompts[i : i + batch_size]
7369
batch_results = refusal_classifier.classify_batch(batch_prompts)
7470
refusal_results.extend(batch_results)
75-
71+
7672
# If any query/resp classified as refusal, entire conversation is refusal
7773
output = any(refusal_results)
7874

fastchat/serve/monitor/classify/utils.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,24 +66,24 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No
6666
class HuggingFaceRefusalClassifier:
6767
def __init__(self):
6868
print("Loading model and tokenizer...")
69-
self.tokenizer = AutoTokenizer.from_pretrained("derixu/refusal_classifier-mlm_then_classifier_v3") #TODO: Migrate to LMSYS account and change path
70-
self.model = AutoModelForSequenceClassification.from_pretrained("derixu/refusal_classifier-mlm_then_classifier_v3")
71-
self.model.eval()
69+
self.tokenizer = AutoTokenizer.from_pretrained(
70+
"derixu/refusal_classifier-mlm_then_classifier_v3"
71+
) # TODO: Migrate to LMSYS account and change path
72+
self.model = AutoModelForSequenceClassification.from_pretrained(
73+
"derixu/refusal_classifier-mlm_then_classifier_v3"
74+
)
75+
self.model.eval()
7276

7377
def classify_batch(self, input_texts):
7478
inputs = self.tokenizer(
75-
input_texts,
76-
truncation=True,
77-
max_length=512,
78-
return_tensors="pt",
79-
padding=True
79+
input_texts,
80+
truncation=True,
81+
max_length=512,
82+
return_tensors="pt",
83+
padding=True,
8084
)
81-
with torch.no_grad():
85+
with torch.no_grad():
8286
outputs = self.model(**inputs)
8387
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
8488
pred_classes = torch.argmax(probabilities, dim=-1).tolist()
8589
return [bool(pred) for pred in pred_classes]
86-
87-
88-
89-

0 commit comments

Comments
 (0)