import json
import os
import platform
import subprocess
from tempfile import NamedTemporaryFile

import torch
from colorama import Fore, Style, init
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

# Local path to the 4-bit quantized Baichuan2-7B-Chat checkpoint.
MODEL_PATH = r"G:\04-model-weights\Baichuan2-7B-Chat-4bits"
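# Note: loading the 4-bit checkpoint is assumed to additionally require the
# bitsandbytes package, and the per-turn memory log in main() shells out to the
# external gpustat CLI; neither is imported here and both must be installed.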

def init_model():
    """Load the 4-bit quantized Baichuan2-7B-Chat model and its tokenizer."""
    print("init model ...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    model.generation_config = GenerationConfig.from_pretrained(MODEL_PATH)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        use_fast=False,
        trust_remote_code=True
    )
    return model, tokenizer

def clear_screen():
    """Clear the terminal, print the usage banner, and reset the chat history."""
    if platform.system() == "Windows":
        os.system("cls")
    else:
        os.system("clear")
    print(Fore.YELLOW + Style.BRIGHT +
          "Welcome to the Baichuan large language model. Type to chat, 'vim' for multi-line input, "
          "'clear' to clear the history, CTRL+C to interrupt generation, 'stream' to toggle "
          "streaming, 'exit' to quit.")
    return []

def vim_input():
    """Open vim on a temporary file so a multi-line prompt can be composed."""
    with NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
        temp_path = tmp.name
    subprocess.call(['vim', '+star', temp_path])  # '+star' starts vim in insert mode
    with open(temp_path) as f:
        text = f.read()
    os.remove(temp_path)
    return text

def main(stream=True):
    init()  # enable ANSI colour handling on Windows consoles
    model, tokenizer = init_model()
    messages = clear_screen()

    # Per-turn log of conversation length vs. GPU memory in use.
    path_log = r"gpu_usage_log.txt"
    f = open(path_log, "w")

    while True:
        prompt = input(Fore.GREEN + Style.BRIGHT + "\nUser: " + Style.NORMAL)
        if prompt.strip() == "exit":
            break
        if prompt.strip() == "clear":
            messages = clear_screen()
            continue
        if prompt.strip() == 'vim':
            prompt = vim_input()
            print(prompt)
        print(Fore.CYAN + Style.BRIGHT + "\nBaichuan 2: " + Style.NORMAL, end='')
        if prompt.strip() == "stream":
            stream = not stream
            print(Fore.YELLOW + "(streaming generation {})\n".format("enabled" if stream else "disabled"), end='')
            continue
        messages.append({"role": "user", "content": prompt})
        if stream:
            position = 0
            try:
                # Each streamed chunk is the full response so far; print only the new suffix.
                for response in model.chat(tokenizer, messages, stream=True):
                    print(response[position:], end='', flush=True)
                    position = len(response)
                    if torch.backends.mps.is_available():
                        torch.mps.empty_cache()
            except KeyboardInterrupt:
                pass
            print()
        else:
            response = model.chat(tokenizer, messages)
            print(response)
            if torch.backends.mps.is_available():
                torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})

        # Record the total conversation length (characters) and the GPU memory
        # currently in use, as reported by the external gpustat CLI.
        conversation_length = sum(len(message['content']) for message in messages)
        result = subprocess.run(['gpustat', '--json'], stdout=subprocess.PIPE)
        data = json.loads(result.stdout.decode())
        used_memory = data['gpus'][0]['memory.used']
        f.write("{}, {}\n".format(conversation_length, used_memory))
        f.flush()

    f.close()
    print(Style.RESET_ALL)


if __name__ == "__main__":
    main()