Commit 073589d

Adjusted huggingface login code
1 parent 5cc13b5 commit 073589d

File tree

6 files changed: +994 −3 lines changed
Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
__author__ = "guangzhi"
'''
Adapted from https://github.com/Teddy-XiongGZ/MedRAG/blob/main/src/medrag.py
'''

import os
import re
import json
import tqdm
import torch
import time
import argparse
import transformers
from transformers import AutoTokenizer
import openai
from transformers import StoppingCriteria, StoppingCriteriaList
import tiktoken
import sys
from huggingface_hub import login

# log in to the Hugging Face Hub with the token from the environment
login(token=os.getenv("HUGGINGFACE_TOKEN"))
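
# Note (descriptive comment, not part of this commit): os.getenv("HUGGINGFACE_TOKEN")
# returns None when the variable is unset, and huggingface_hub's login() then falls back
# to prompting for a token; a guarded variant would be:
#   hf_token = os.getenv("HUGGINGFACE_TOKEN")
#   if hf_token:
#       login(token=hf_token)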

openai.api_key = os.getenv("OPENAI_API_KEY")


class LLMInference:

    def __init__(self, llm_name="OpenAI/gpt-3.5-turbo", cache_dir="../../huggingface/hub"):
        self.llm_name = llm_name
        self.cache_dir = cache_dir
        if self.llm_name.split('/')[0].lower() == "openai":
            # OpenAI chat models: served via the API, tokenized locally with tiktoken
            self.model = self.llm_name.split('/')[-1]
            if "gpt-3.5" in self.model or "gpt-35" in self.model:
                self.max_length = 4096
            elif "gpt-4" in self.model:
                self.max_length = 8192
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        else:
            # Local Hugging Face models: load the tokenizer and set per-model context limits
            self.type = torch.bfloat16
            self.tokenizer = AutoTokenizer.from_pretrained(self.llm_name, cache_dir=self.cache_dir, legacy=False)
            if "mixtral" in llm_name.lower() or "mistral" in llm_name.lower():
                self.tokenizer.chat_template = open('../templates/mistral-instruct.jinja').read().replace('    ', '').replace('\n', '')
                self.max_length = 32768
            elif "llama-2" in llm_name.lower():
                self.max_length = 4096
                self.type = torch.float16
            elif "llama-3" in llm_name.lower():
                self.max_length = 8192
            elif "meditron-70b" in llm_name.lower():
                self.tokenizer.chat_template = open('../templates/meditron.jinja').read().replace('    ', '').replace('\n', '')
                self.max_length = 4096
            elif "pmc_llama" in llm_name.lower():
                self.tokenizer.chat_template = open('../templates/pmc_llama.jinja').read().replace('    ', '').replace('\n', '')
                self.max_length = 2048
            # device_map="auto" lets accelerate spread the model across available devices
            self.model = transformers.pipeline(
                "text-generation",
                model=self.llm_name,
                torch_dtype=self.type,
                device_map="auto",
                model_kwargs={"cache_dir": self.cache_dir},
            )

    def answer(self, messages):
        # generate answers

        ans = self.generate(messages)
        ans = re.sub(r"\s+", " ", ans)

        return ans

    def custom_stop(self, stop_str, input_len=0):
        stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(stop_str, self.tokenizer, input_len)])
        return stopping_criteria

    def generate(self, messages, prompt=None):
        '''
        generate response given messages
        '''
        if "openai" in self.llm_name.lower():
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages
            )

            ans = response.choices[0].message.content

        else:
            stopping_criteria = None
            if prompt is None:
                prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            if "meditron" in self.llm_name.lower():
                stopping_criteria = self.custom_stop(["###", "User:", "\n\n\n"], input_len=len(self.tokenizer.encode(prompt, add_special_tokens=True)))
            if "llama-3" in self.llm_name.lower():
                # Llama-3 also stops on its end-of-turn token <|eot_id|>
                response = self.model(
                    prompt,
                    do_sample=False,
                    eos_token_id=[self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")],
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_length=min(self.max_length, len(self.tokenizer.encode(prompt, add_special_tokens=True)) + 4096),
                    truncation=True,
                    stopping_criteria=stopping_criteria,
                    temperature=0.0
                )
            else:
                response = self.model(
                    prompt,
                    do_sample=False,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_length=min(self.max_length, len(self.tokenizer.encode(prompt, add_special_tokens=True)) + 4096),
                    truncation=True,
                    stopping_criteria=stopping_criteria,
                    temperature=0.0
                )
            ans = response[0]["generated_text"]
        return ans
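
# Note (descriptive comment, not part of the original file): the criteria class below
# decodes only the tokens generated after the prompt (input_ids[0][self.input_len:]) and
# halts generation as soon as any configured stop string appears in that text; the
# custom_stop() helper above wires it into the Meditron prompt format.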

class CustomStoppingCriteria(StoppingCriteria):
    def __init__(self, stop_words, tokenizer, input_len=0):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_words = stop_words
        self.input_len = input_len

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        tokens = self.tokenizer.decode(input_ids[0][self.input_len:])
        return any(stop in tokens for stop in self.stop_words)
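
A minimal usage sketch (not part of this commit; the module name llm_inference and the example messages are assumptions for illustration):

# assumes the file in this diff is saved/importable as llm_inference.py
from llm_inference import LLMInference

llm = LLMInference(llm_name="OpenAI/gpt-3.5-turbo")
messages = [
    {"role": "system", "content": "You are a helpful medical assistant."},
    {"role": "user", "content": "List first-line treatments for hypertension."},
]
print(llm.answer(messages))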
