forked from lmstudio-ai/mlx-engine
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdemo.py
More file actions
276 lines (245 loc) · 8.61 KB
/
demo.py
File metadata and controls
276 lines (245 loc) · 8.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
import argparse
import base64
import time
import os
import sys
from mlx_engine.generate import load_model, load_draft_model, create_generator, tokenize
from mlx_engine.utils.token import Token
from mlx_engine.utils.kv_cache_quantization import VALID_KV_BITS, VALID_KV_GROUP_SIZE
from mlx_engine.utils.prompt_progress_reporter import LoggerReporter
from transformers import AutoTokenizer, AutoProcessor
# Fallback values used when the corresponding CLI flags are not supplied.
DEFAULT_PROMPT = "Explain the rules of chess in one sentence"
DEFAULT_TEMP = 0.8
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
def setup_arg_parser():
    """Build and return the CLI argument parser for this demo script."""
    parser = argparse.ArgumentParser(
        description="LM Studio mlx-engine inference script"
    )
    # Bind the method once; each call below registers one CLI flag.
    add = parser.add_argument
    add(
        "--model",
        required=True,
        type=str,
        help="The file system path to the model",
    )
    add(
        "--prompt",
        type=str,
        default=DEFAULT_PROMPT,
        help="Message to be processed by the model. Use '-' to read from stdin",
    )
    add(
        "--system",
        type=str,
        default=DEFAULT_SYSTEM_PROMPT,
        help="System prompt for the model",
    )
    add("--no-system", action="store_true", help="Disable the system prompt")
    add("--images", type=str, nargs="+", help="Path of the images to process")
    add("--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature")
    add(
        "--stop-strings",
        type=str,
        nargs="+",
        help="Strings that will stop the generation",
    )
    add(
        "--top-logprobs",
        type=int,
        default=0,
        help="Number of top logprobs to return",
    )
    add("--max-kv-size", type=int, help="Max context size of the model")
    add(
        "--kv-bits",
        type=int,
        choices=VALID_KV_BITS,
        help="Number of bits for KV cache quantization. Must be between 3 and 8 (inclusive)",
    )
    add(
        "--kv-group-size",
        type=int,
        choices=VALID_KV_GROUP_SIZE,
        help="Group size for KV cache quantization",
    )
    add(
        "--quantized-kv-start",
        type=int,
        help="When --kv-bits is set, start quantizing the KV cache from this step onwards",
    )
    add(
        "--draft-model",
        type=str,
        help="The file system path to the draft model for speculative decoding.",
    )
    add(
        "--num-draft-tokens",
        type=int,
        help="Number of tokens to draft when using speculative decoding.",
    )
    add(
        "--print-prompt-progress",
        action="store_true",
        help="Enable printed prompt processing progress callback",
    )
    add("--max-img-size", type=int, help="Downscale images to this side length (px)")
    return parser
def image_to_base64(image_path):
    """Read the file at *image_path* and return its contents base64-encoded as str."""
    with open(image_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")
class GenerationStatsCollector:
    """Accumulates token counts and timing during a generation run.

    Timing starts at construction, so instantiate this immediately before
    kicking off generation.
    """

    def __init__(self):
        # Wall-clock start of the whole run (includes prompt processing).
        self.start_time = time.time()
        # Set on the first add_tokens call; None until then.
        self.first_token_time = None
        self.total_tokens = 0
        # None until any batch arrives, then a running count of tokens that
        # came from the draft model (speculative decoding).
        self.num_accepted_draft_tokens: int | None = None

    def add_tokens(self, tokens: list[Token]):
        """Record new tokens and their timing."""
        if self.first_token_time is None:
            self.first_token_time = time.time()
        draft_tokens = sum(1 for token in tokens if token.from_draft)
        if self.num_accepted_draft_tokens is None:
            self.num_accepted_draft_tokens = 0
        self.num_accepted_draft_tokens += draft_tokens
        self.total_tokens += len(tokens)

    def print_stats(self):
        """Print generation statistics.

        Safe to call even when no tokens were generated: in that case
        time-to-first-token is reported as the total elapsed time and the
        token rate as inf.
        """
        end_time = time.time()
        total_time = end_time - self.start_time
        # Fix: first_token_time is None if add_tokens was never called;
        # previously this crashed with a TypeError on the subtraction.
        if self.first_token_time is None:
            time_to_first_token = total_time
        else:
            time_to_first_token = self.first_token_time - self.start_time
        effective_time = total_time - time_to_first_token
        tokens_per_second = (
            self.total_tokens / effective_time if effective_time > 0 else float("inf")
        )
        print("\n\nGeneration stats:")
        print(f"  - Tokens per second: {tokens_per_second:.2f}")
        if self.num_accepted_draft_tokens is not None:
            print(
                f"  - Number of accepted draft tokens: {self.num_accepted_draft_tokens}"
            )
        print(f"  - Time to first token: {time_to_first_token:.2f}s")
        print(f"  - Total tokens generated: {self.total_tokens}")
        print(f"  - Total time: {total_time:.2f}s")
def resolve_model_path(model_arg):
    """Resolve *model_arg* to an existing filesystem path.

    An existing path is returned unchanged; otherwise the well-known
    LM Studio model directories are searched. Raises ValueError when the
    model cannot be located anywhere.
    """
    # Absolute/relative paths that already exist need no resolution.
    if os.path.exists(model_arg):
        return model_arg
    # Well-known LM Studio model roots, checked in order.
    search_roots = (
        os.path.expanduser("~/.lmstudio/models"),
        os.path.expanduser("~/.cache/lm-studio/models"),
    )
    for root in search_roots:
        candidate = os.path.join(root, model_arg)
        if os.path.exists(candidate):
            return candidate
    raise ValueError(f"Could not find model '{model_arg}' in local directories")
if __name__ == "__main__":
    # Parse arguments
    parser = setup_arg_parser()
    args = parser.parse_args()
    # Defensive normalization: --images uses nargs="+" so argparse already
    # yields a list, but guard against a single string slipping through.
    if isinstance(args.images, str):
        args.images = [args.images]
    # Load the model
    model_path = resolve_model_path(args.model)
    print("Loading model...", end="\n", flush=True)
    model_kit = load_model(
        str(model_path),
        max_kv_size=args.max_kv_size,
        trust_remote_code=False,
        kv_bits=args.kv_bits,
        kv_group_size=args.kv_group_size,
        quantized_kv_start=args.quantized_kv_start,
    )
    print("\rModel load complete ✓", end="\n", flush=True)
    # Load draft model if requested
    if args.draft_model:
        load_draft_model(model_kit=model_kit, path=resolve_model_path(args.draft_model))
    # Tokenize the prompt
    prompt = args.prompt
    # "-" means: read the prompt text from stdin instead of the CLI flag.
    if prompt == "-":
        stdin_prompt = sys.stdin.read()
        prompt = stdin_prompt
    # Build conversation with optional system prompt
    conversation = []
    if not args.no_system:
        conversation.append({"role": "system", "content": args.system})
    # Handle the prompt according to the input type
    # If images are provided, add them to the prompt
    images_base64 = []
    if args.images:
        # Vision path: AutoProcessor supplies the multimodal chat template;
        # images are embedded as base64 entries alongside the text content.
        tf_tokenizer = AutoProcessor.from_pretrained(model_path)
        images_base64 = [image_to_base64(img_path) for img_path in args.images]
        conversation.append(
            {
                "role": "user",
                "content": [
                    *[
                        {"type": "image", "base64": image_b64}
                        for image_b64 in images_base64
                    ],
                    {"type": "text", "text": prompt},
                ],
            }
        )
    else:
        # Text-only path: a plain tokenizer and a plain string content field.
        tf_tokenizer = AutoTokenizer.from_pretrained(model_path)
        conversation.append({"role": "user", "content": prompt})
    # Render the conversation into a single prompt string via the model's
    # chat template, then tokenize it with the engine's own tokenizer.
    prompt = tf_tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    prompt_tokens = tokenize(model_kit, prompt)
    # Record top logprobs
    logprobs_list = []
    # Initialize generation stats collector (timing starts here)
    stats_collector = GenerationStatsCollector()
    # Clamp image size
    max_img_size = (args.max_img_size, args.max_img_size) if args.max_img_size else None
    # Generate the response
    generator = create_generator(
        model_kit,
        prompt_tokens,
        images_b64=images_base64,
        max_image_size=max_img_size,
        stop_strings=args.stop_strings,
        max_tokens=1024,
        top_logprobs=args.top_logprobs,
        prompt_progress_reporter=LoggerReporter()
        if args.print_prompt_progress
        else None,
        num_draft_tokens=args.num_draft_tokens,
        temp=args.temp,
    )
    # Stream generated text to stdout as it arrives, accumulating stats and
    # any requested top-logprobs; stats are printed when a stop condition fires.
    for generation_result in generator:
        print(generation_result.text, end="", flush=True)
        stats_collector.add_tokens(generation_result.tokens)
        logprobs_list.extend(generation_result.top_logprobs)
        if generation_result.stop_condition:
            stats_collector.print_stats()
            print(
                f"\nStopped generation due to: {generation_result.stop_condition.stop_reason}"
            )
            if generation_result.stop_condition.stop_string:
                print(f"Stop string: {generation_result.stop_condition.stop_string}")
    if args.top_logprobs:
        # NOTE(review): list comprehension used purely for the print side
        # effect; a plain for-loop would be more idiomatic.
        [print(x) for x in logprobs_list]