
Commit b2e8023

feat: upload Qwen code

1 parent 65b0e54 commit b2e8023

File tree

6 files changed: +3976 −0 lines changed


code/chapter-10/cli_demo_qwen.py

Lines changed: 225 additions & 0 deletions
@@ -0,0 +1,225 @@
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""A simple command-line interactive chat demo."""

import argparse
import json
import os
import platform
import shutil
import subprocess
from copy import deepcopy

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
from transformers.trainer_utils import set_seed

# DEFAULT_CKPT_PATH = r"G:\04-model-weights\qwen\Qwen-1_8B-Chat"
DEFAULT_CKPT_PATH = r"G:\04-model-weights\qwen\Qwen-7B-Chat-Int4"

_WELCOME_MSG = '''\
Welcome to use Qwen-Chat model, type text to start chat, type :h to show command help.
(欢迎使用 Qwen-Chat 模型,输入内容即可进行对话,:h 显示命令帮助。)

Note: This demo is governed by the original license of Qwen.
We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc.
(注:本演示受Qwen的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
'''

_HELP_MSG = '''\
Commands:
    :help / :h              Show this help message              显示帮助信息
    :exit / :quit / :q      Exit the demo                       退出Demo
    :clear / :cl            Clear screen                        清屏
    :clear-history / :clh   Clear history                       清除对话历史
    :history / :his         Show history                        显示对话历史
    :seed                   Show current random seed            显示当前随机种子
    :seed <N>               Set random seed to <N>              设置随机种子
    :conf                   Show current generation config      显示生成配置
    :conf <key>=<value>     Change generation config            修改生成配置
    :reset-conf             Reset generation config             重置生成配置
'''

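# Example runtime tweaks via :conf (top_p, temperature, and max_new_tokens are
# standard transformers GenerationConfig fields):
#   :conf top_p=0.8
#   :conf temperature=1.0 max_new_tokens=512
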
def _load_model_tokenizer(args):
    """Load the Qwen tokenizer, model, and generation config from a checkpoint."""
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    return model, tokenizer, config

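# Note: device_map="auto" lets accelerate place weights across the available
# GPUs (spilling to CPU if necessary). Loading a GPTQ Int4 checkpoint such as
# Qwen-7B-Chat-Int4 typically also requires the optimum and auto-gptq packages.
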
def _gc():
    """Run Python garbage collection and release cached GPU memory."""
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

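# empty_cache() returns cached allocator blocks to the CUDA driver, so tools
# such as gpustat/nvidia-smi report lower usage afterwards. The per-turn
# logging in main() runs without this call, so its numbers include PyTorch's
# allocator cache.
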
def _clear_screen():
    """Clear the terminal using the platform-appropriate command."""
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")

def _print_history(history):
    """Pretty-print the (query, response) history across the terminal width."""
    terminal_width = shutil.get_terminal_size()[0]
    print(f'History ({len(history)})'.center(terminal_width, '='))
    for index, (query, response) in enumerate(history):
        print(f'User[{index}]: {query}')
        print(f'QWen[{index}]: {response}')
    print('=' * terminal_width)

def _get_input() -> str:
    """Prompt until the user enters a non-empty query."""
    while True:
        try:
            message = input('User> ').strip()
        except UnicodeDecodeError:
            print('[ERROR] Encoding error in input')
            continue
        except KeyboardInterrupt:
            exit(1)
        if message:
            return message
        print('[ERROR] Query is empty')

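# main() appends one "conversation_length, used_memory" line to
# gpu_usage_log.txt per chat turn, making it easy to study how GPU memory
# (largely the growing KV cache) scales with conversation length.
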
def main():
    # Log "conversation length, GPU memory" pairs, one line per chat turn.
    path_log = r"gpu_usage_log.txt"
    f = open(path_log, "w")

    parser = argparse.ArgumentParser(
        description='QWen-Chat command-line interactive chat demo.')
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path, default to %(default)r")
    parser.add_argument("-s", "--seed", type=int, default=1234, help="Random seed")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
    args = parser.parse_args()

    history, response = [], ''

    model, tokenizer, config = _load_model_tokenizer(args)
    orig_gen_config = deepcopy(model.generation_config)

    _clear_screen()
    print(_WELCOME_MSG)

    seed = args.seed

    while True:
        query = _get_input()

        # Process commands.
        if query.startswith(':'):
            command_words = query[1:].strip().split()
            if not command_words:
                command = ''
            else:
                command = command_words[0]

            if command in ['exit', 'quit', 'q']:
                break
            elif command in ['clear', 'cl']:
                _clear_screen()
                print(_WELCOME_MSG)
                _gc()
                continue
            elif command in ['clear-history', 'clh']:
                print(f'[INFO] All {len(history)} history cleared')
                history.clear()
                _gc()
                continue
            elif command in ['help', 'h']:
                print(_HELP_MSG)
                continue
            elif command in ['history', 'his']:
                _print_history(history)
                continue
            elif command in ['seed']:
                if len(command_words) == 1:
                    print(f'[INFO] Current random seed: {seed}')
                    continue
                else:
                    new_seed_s = command_words[1]
                    try:
                        new_seed = int(new_seed_s)
                    except ValueError:
                        print(f'[WARNING] Fail to change random seed: {new_seed_s!r} is not a valid number')
                    else:
                        print(f'[INFO] Random seed changed to {new_seed}')
                        seed = new_seed
                    continue
            elif command in ['conf']:
                if len(command_words) == 1:
                    print(model.generation_config)
                else:
                    for key_value_pairs_str in command_words[1:]:
                        eq_idx = key_value_pairs_str.find('=')
                        if eq_idx == -1:
                            print('[WARNING] format: <key>=<value>')
                            continue
                        conf_key, conf_value_str = key_value_pairs_str[:eq_idx], key_value_pairs_str[eq_idx + 1:]
                        try:
                            conf_value = eval(conf_value_str)
                        except Exception as e:
                            print(e)
                            continue
                        else:
                            print(f'[INFO] Change config: model.generation_config.{conf_key} = {conf_value}')
                            setattr(model.generation_config, conf_key, conf_value)
                continue
            elif command in ['reset-conf']:
                print('[INFO] Reset generation config')
                model.generation_config = deepcopy(orig_gen_config)
                print(model.generation_config)
                continue
            else:
                # Not a recognized command; fall through and treat it as a query.
                pass

        # Run chat. Note that :conf edits model.generation_config, while
        # chat_stream is passed the separately loaded `config`, so runtime
        # :conf changes do not affect this call as written.
        set_seed(seed)
        try:
            for response in model.chat_stream(tokenizer, query, history=history, generation_config=config):
                _clear_screen()
                print(f"\nUser: {query}")
                print(f"\nQwen-Chat: {response}")
        except KeyboardInterrupt:
            print('[WARNING] Generation interrupted')
            continue

        history.append((query, response))

        # Total characters across all queries and responses so far.
        conversation_length = sum(len(text) for content in history for text in content)
        # Query GPU memory via gpustat and log it against conversation length.
        result = subprocess.run(['gpustat', '--json'], stdout=subprocess.PIPE)
        output = result.stdout.decode()
        data = json.loads(output)
        used_memory = data['gpus'][0]['memory.used']
        f.write("{}, {}\n".format(conversation_length, used_memory))
        f.flush()

    f.close()

if __name__ == "__main__":
    main()
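
For reference, the demo is launched with, e.g., `python cli_demo_qwen.py -c <checkpoint-path>` (add `--cpu-only` to skip the GPU). Below is a minimal sketch of how the resulting gpu_usage_log.txt might be analyzed afterwards; the script name analyze_gpu_log.py is hypothetical, and only the "length, memory" line format written above is assumed:

# analyze_gpu_log.py -- a minimal sketch (hypothetical helper, not part of the
# commit); assumes the "conversation_length, used_memory" lines written by
# cli_demo_qwen.py, with memory in MiB as reported by gpustat.
import csv

lengths, memories = [], []
with open("gpu_usage_log.txt") as f:
    for row in csv.reader(f):
        if len(row) != 2:
            continue  # skip malformed lines
        lengths.append(int(row[0].strip()))
        memories.append(int(row[1].strip()))

if lengths:
    # Summarize how GPU memory grew from the first to the last logged turn.
    print(f"turns logged:       {len(lengths)}")
    print(f"final conversation: {lengths[-1]} chars")
    print(f"memory.used:        {memories[0]} -> {memories[-1]} MiB")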
