-
Notifications
You must be signed in to change notification settings - Fork 56
Expand file tree
/
Copy pathchat.py
More file actions
118 lines (112 loc) · 3.53 KB
/
chat.py
File metadata and controls
118 lines (112 loc) · 3.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from transformers import (
TextStreamer,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from argparse import ArgumentParser
import torch
# Command-line options for the interactive chat loop.  The two
# bitsandbytes flags live in a mutually exclusive group so a user
# cannot request 4-bit and 8-bit quantization at the same time.
parser = ArgumentParser()
parser.add_argument("--model", "-m", type=str, required=True,
                    help="Path to model directory")
parser.add_argument("--precision", "-p", type=str, default="fp16",
                    choices=["fp16", "bf16", "fp32"],
                    help="Precision of model")
parser.add_argument(
    "--device",
    "-d",
    type=str,
    choices=["auto", "cuda", "cpu"],
    default="auto",
    help="Target device to process abliteration. Warning, bitsandbytes quantization DOES NOT support CPU",
)
parser.add_argument("--max-new-tokens", "-n", type=int, default=256,
                    help="Max new tokens to generate")
parser.add_argument("--system-prompt", "-s", type=str, default=None,
                    help="System prompt")
quant_group = parser.add_mutually_exclusive_group()
quant_group.add_argument("--load-in-4bit", action="store_true", default=False,
                         help="Load model in 4-bit precision using bitsandbytes")
quant_group.add_argument("--load-in-8bit", action="store_true", default=False,
                         help="Load model in 8-bit precision using bitsandbytes")
parser.add_argument("--flash-attn", action="store_true", default=False,
                    help="Use flash attention 2")
if __name__ == "__main__":
    # Parse CLI args only when run as a script, so importing this module
    # has no side effects.  (Previously parse_args() ran at import time,
    # which aborted any tool that imported the file without --model.)
    args = parser.parse_args()

    # Map the CLI precision flag to a torch dtype.  argparse `choices`
    # guarantees exactly one of the three values.
    if args.precision == "fp16":
        precision = torch.float16
    elif args.precision == "bf16":
        precision = torch.bfloat16
    else:  # "fp32"
        precision = torch.float32

    # Optional bitsandbytes quantization config; the two flags are
    # mutually exclusive at the argparse level, so at most one applies.
    if args.load_in_4bit:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=precision,
            bnb_4bit_use_double_quant=True,
        )
    elif args.load_in_8bit:
        quant_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_enable_fp32_cpu_offload=True,
            llm_int8_has_fp16_weight=True,
        )
    else:
        quant_config = None

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        trust_remote_code=True,
        dtype=precision,
        low_cpu_mem_usage=True,
        device_map=args.device,
        quantization_config=quant_config,
        attn_implementation="flash_attention_2" if args.flash_attn else None,
    )
    # NOTE: `device_map` was dropped here -- tokenizers are plain Python
    # objects with no device placement; AutoTokenizer does not use it.
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    # Running conversation in chat-template form; optionally seeded with
    # a system message.
    conversation = []
    if args.system_prompt is not None:
        print(f"System Prompt: {args.system_prompt}")
        conversation.append({"role": "system", "content": args.system_prompt})

    streamer = TextStreamer(tokenizer)
    print("Type /clear to clear history, /exit to quit.")
    while True:
        try:
            prompt = input("User> ")
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of a traceback.
            print()
            break
        if prompt == "/clear":
            # Re-seed the system prompt after clearing; previously /clear
            # silently discarded it for the rest of the session.
            conversation = []
            if args.system_prompt is not None:
                conversation.append(
                    {"role": "system", "content": args.system_prompt}
                )
            print("! History cleared.")
            continue
        elif prompt == "/exit":
            break
        elif prompt == "":
            print("! Please type a message.")
            continue
        conversation.append({"role": "user", "content": prompt})
        toks = tokenizer.apply_chat_template(
            conversation=conversation, add_generation_prompt=True, return_tensors="pt"
        )
        gen = model.generate(
            toks.to(model.device), streamer=streamer, max_new_tokens=args.max_new_tokens
        )
        # Decode the completion as ONE sequence.  The previous code used
        # batch_decode on a 1-D tensor, which decodes each token id
        # individually and joins the pieces -- that drops the leading-space
        # markers of SentencePiece/BPE tokenizers and corrupts the text
        # stored in the history.
        reply = tokenizer.decode(gen[0][len(toks[0]):], skip_special_tokens=True)
        conversation.append({"role": "assistant", "content": reply})