
Commit d828754

feat: LLM demo
1 parent b2e8023 commit d828754

4 files changed: +780 −0 lines changed

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
import os
import torch
import platform
import subprocess
import json
from colorama import Fore, Style
from tempfile import NamedTemporaryFile
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig


def init_model():
    print("init model ...")
    model = AutoModelForCausalLM.from_pretrained(
        r"G:\04-model-weights\Baichuan2-7B-Chat-4bits",
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    model.generation_config = GenerationConfig.from_pretrained(
        r"G:\04-model-weights\Baichuan2-7B-Chat-4bits"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        r"G:\04-model-weights\Baichuan2-7B-Chat-4bits",
        use_fast=False,
        trust_remote_code=True
    )
    return model, tokenizer


def clear_screen():
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")
    print(Fore.YELLOW + Style.BRIGHT + "欢迎使用百川大模型,输入进行对话,vim 多行输入,clear 清空历史,CTRL+C 中断生成,stream 开关流式生成,exit 结束。")
    return []


def vim_input():
    # Open an empty vim buffer for multi-line input, then read the saved file back.
    with NamedTemporaryFile() as tempfile:
        tempfile.close()
        subprocess.call(['vim', '+star', tempfile.name])
        text = open(tempfile.name).read()
    return text


def main(stream=True):
    model, tokenizer = init_model()
    messages = clear_screen()

    # Log "conversation length, GPU memory used" pairs after every turn.
    path_log = r"gpu_usage_log.txt"
    f = open(path_log, "w")

    while True:
        prompt = input(Fore.GREEN + Style.BRIGHT + "\n用户:" + Style.NORMAL)
        if prompt.strip() == "exit":
            break
        if prompt.strip() == "clear":
            messages = clear_screen()
            continue
        if prompt.strip() == 'vim':
            prompt = vim_input()
            print(prompt)
        print(Fore.CYAN + Style.BRIGHT + "\nBaichuan 2:" + Style.NORMAL, end='')
        if prompt.strip() == "stream":
            stream = not stream
            print(Fore.YELLOW + "({}流式生成)\n".format("开启" if stream else "关闭"), end='')
            continue
        messages.append({"role": "user", "content": prompt})
        if stream:
            position = 0
            try:
                for response in model.chat(tokenizer, messages, stream=True):
                    print(response[position:], end='', flush=True)
                    position = len(response)
                    if torch.backends.mps.is_available():
                        torch.mps.empty_cache()
            except KeyboardInterrupt:
                pass
            print()
        else:
            response = model.chat(tokenizer, messages)
            print(response)
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})

        # Total number of characters exchanged so far.
        conversation_length = sum([len(content['content']) for content in messages])
        # Query current GPU memory usage via gpustat and append it to the log.
        result = subprocess.run(['gpustat', '--json'], stdout=subprocess.PIPE)
        output = result.stdout.decode()
        data = json.loads(output)
        used_memory = data['gpus'][0]['memory.used']
        f.write("{}, {}\n".format(conversation_length, used_memory))
        f.flush()

    f.close()
    print(Style.RESET_ALL)


if __name__ == "__main__":
    main()
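The per-turn logging above shells out to gpustat for the memory.used field. If gpustat is not installed, the same number can be read from nvidia-smi directly; the following is a minimal sketch under that assumption (query_gpu_memory_mib is a hypothetical helper, not part of this commit):

import subprocess

def query_gpu_memory_mib(gpu_index=0):
    # Hypothetical alternative to gpustat: ask nvidia-smi for used memory in MiB.
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader,nounits"],
        stdout=subprocess.PIPE, check=True,
    ).stdout.decode()
    # nvidia-smi prints one line per GPU; each line is an integer number of MiB.
    return int(out.splitlines()[gpu_index].strip())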
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
import os
import platform
import subprocess
import json
from transformers import AutoTokenizer, AutoModel

# MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
# TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

MODEL_PATH = r"G:\04-model-weights\chatglm\chatglm3-6b"
TOKENIZER_PATH = r"G:\04-model-weights\chatglm\chatglm3-6b"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).quantize(bits=4, device="cuda").cuda().eval()
# model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).cuda().eval()
# add .quantize(bits=4, device="cuda").cuda() before .eval() to use int4 model
# must use cuda to load int4 model

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False

welcome_prompt = "欢迎使用 ChatGLM3-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"


def build_prompt(history):
    prompt = welcome_prompt
    for query, response in history:
        prompt += f"\n\n用户:{query}"
        prompt += f"\n\nChatGLM3-6B:{response}"
    return prompt


def main():
    past_key_values, history = None, []
    global stop_stream
    print(welcome_prompt)

    # Log "conversation length, GPU memory used" pairs after every turn.
    path_log = r"gpu_usage_log.txt"
    f = open(path_log, "w")
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            past_key_values, history = None, []
            os.system(clear_command)
            print(welcome_prompt)
            continue
        print("\nChatGLM:", end="")
        current_length = 0
        for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, top_p=1,
                                                                     temperature=0.01,
                                                                     past_key_values=past_key_values,
                                                                     return_past_key_values=True):
            if stop_stream:
                stop_stream = False
                break
            else:
                print(response[current_length:], end="", flush=True)
                current_length = len(response)

        # Total number of characters exchanged so far.
        conversation_length = sum([len(content['content']) for content in history])
        # Query current GPU memory usage via gpustat and append it to the log.
        result = subprocess.run(['gpustat', '--json'], stdout=subprocess.PIPE)
        output = result.stdout.decode()
        data = json.loads(output)
        used_memory = data['gpus'][0]['memory.used']
        f.write("{}, {}\n".format(conversation_length, used_memory))
        f.flush()
        print("")

    f.close()


if __name__ == "__main__":
    main()
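Both demos write gpu_usage_log.txt as one "conversation_length, used_memory" pair per line. A minimal sketch of how that log could be plotted afterwards (assumes matplotlib is available; not part of this commit):

import matplotlib.pyplot as plt

lengths, memory = [], []
with open("gpu_usage_log.txt") as fin:
    for line in fin:
        if not line.strip():
            continue
        # Each line looks like "1234, 5678": characters exchanged, MiB used.
        length_str, mem_str = line.split(",")
        lengths.append(int(length_str))
        memory.append(int(mem_str))

plt.plot(lengths, memory, marker="o")
plt.xlabel("conversation length (characters)")
plt.ylabel("GPU memory used (MiB)")
plt.savefig("gpu_usage.png")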
