__author__ = "guangzhi"
'''
Adapted from https://github.com/Teddy-XiongGZ/MedRAG/blob/main/src/medrag.py
'''

import os
import re
import json
import tqdm
import torch
import time
import argparse
import sys
import openai
import tiktoken
import transformers
from transformers import AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList
from huggingface_hub import login

# Authenticate with the Hugging Face Hub and OpenAI using environment variables
login(token=os.getenv("HUGGINGFACE_TOKEN"))

openai.api_key = os.getenv("OPENAI_API_KEY")


class LLMInference:

    def __init__(self, llm_name="OpenAI/gpt-3.5-turbo", cache_dir="../../huggingface/hub"):
        self.llm_name = llm_name
        self.cache_dir = cache_dir
        if self.llm_name.split('/')[0].lower() == "openai":
            # OpenAI chat models are accessed through the API; only the model name and
            # a tiktoken encoding for length accounting are needed.
            self.model = self.llm_name.split('/')[-1]
            if "gpt-3.5" in self.model or "gpt-35" in self.model:
                self.max_length = 4096
            elif "gpt-4" in self.model:
                self.max_length = 8192
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        else:
            # Open-weight models are loaded locally through a transformers pipeline.
            self.type = torch.bfloat16
            self.tokenizer = AutoTokenizer.from_pretrained(self.llm_name, cache_dir=self.cache_dir, legacy=False)
            if "mixtral" in llm_name.lower() or "mistral" in llm_name.lower():
                self.tokenizer.chat_template = open('../templates/mistral-instruct.jinja').read().replace('    ', '').replace('\n', '')
                self.max_length = 32768
            elif "llama-2" in llm_name.lower():
                self.max_length = 4096
                self.type = torch.float16
            elif "llama-3" in llm_name.lower():
                self.max_length = 8192
            elif "meditron-70b" in llm_name.lower():
                self.tokenizer.chat_template = open('../templates/meditron.jinja').read().replace('    ', '').replace('\n', '')
                self.max_length = 4096
            elif "pmc_llama" in llm_name.lower():
                self.tokenizer.chat_template = open('../templates/pmc_llama.jinja').read().replace('    ', '').replace('\n', '')
                self.max_length = 2048
            self.model = transformers.pipeline(
                "text-generation",
                model=self.llm_name,
                torch_dtype=self.type,
                device_map="auto",
                model_kwargs={"cache_dir": self.cache_dir},
            )

    def answer(self, messages):
        '''
        Generate an answer for the given chat messages and collapse extra whitespace.
        '''
        ans = self.generate(messages)
        ans = re.sub(r"\s+", " ", ans)
        return ans

    def custom_stop(self, stop_str, input_len=0):
        # Build a stopping-criteria list that halts generation once any stop string appears
        stopping_criteria = StoppingCriteriaList([CustomStoppingCriteria(stop_str, self.tokenizer, input_len)])
        return stopping_criteria

    def generate(self, messages, prompt=None):
        '''
        Generate a response given a list of chat messages.
        '''
        if "openai" in self.llm_name.lower():
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages
            )
            ans = response.choices[0].message.content
        else:
            stopping_criteria = None
            if prompt is None:
                prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            if "meditron" in self.llm_name.lower():
                # Meditron tends to keep generating past the answer, so stop on these markers
                stopping_criteria = self.custom_stop(["###", "User:", "\n\n\n"], input_len=len(self.tokenizer.encode(prompt, add_special_tokens=True)))
            if "llama-3" in self.llm_name.lower():
                # Llama-3 uses <|eot_id|> as an additional end-of-turn token
                response = self.model(
                    prompt,
                    do_sample=False,
                    eos_token_id=[self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")],
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_length=min(self.max_length, len(self.tokenizer.encode(prompt, add_special_tokens=True)) + 4096),
                    truncation=True,
                    stopping_criteria=stopping_criteria,
                    temperature=0.0
                )
            else:
                response = self.model(
                    prompt,
                    do_sample=False,
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.eos_token_id,
                    max_length=min(self.max_length, len(self.tokenizer.encode(prompt, add_special_tokens=True)) + 4096),
                    truncation=True,
                    stopping_criteria=stopping_criteria,
                    temperature=0.0
                )
            ans = response[0]["generated_text"]
        return ans


class CustomStoppingCriteria(StoppingCriteria):
    def __init__(self, stop_words, tokenizer, input_len=0):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_words = stop_words
        self.input_len = input_len

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Decode only the newly generated tokens and stop once any stop word appears in them
        tokens = self.tokenizer.decode(input_ids[0][self.input_len:])
        return any(stop in tokens for stop in self.stop_words)
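

# Minimal usage sketch (not part of the original script): the model name and the example
# messages below are illustrative assumptions, and OPENAI_API_KEY (plus HUGGINGFACE_TOKEN
# for local models) must be set in the environment before running.
if __name__ == "__main__":
    llm = LLMInference(llm_name="OpenAI/gpt-3.5-turbo")
    messages = [
        {"role": "system", "content": "You are a helpful medical assistant."},
        {"role": "user", "content": "What is the first-line treatment for uncomplicated hypertension?"},
    ]
    print(llm.answer(messages))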