-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredictor.py
More file actions
397 lines (321 loc) · 17.9 KB
/
predictor.py
File metadata and controls
397 lines (321 loc) · 17.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
import os
import time
import torch
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer
class Predictor:
    """Ultra-lightweight LLM-based predictive text system with performance optimizations.

    Wraps a small causal language model (default: Qwen/Qwen2-0.5B) behind a
    next-word suggestion API. Features: a two-level prediction/context cache
    with background maintenance, optional 8-bit dynamic quantization on CPU,
    optional float16 weights and flash-attention on CUDA, background
    precomputation for common phrases, and GPU warm-up to avoid cold starts.
    """

    # NOTE: Use newer Qwen/Qwen2.5-0.5B for higher confidence but higher latency
    # NOTE: For even lower latency, use distilgpt2 for worse results but lower latency
    # NOTE: Test microsoft/Phi-4-mini-instruct (might be better on faster machines) vs qwen
    # Default to Qwen/Qwen2-0.5B for lower latency
    def __init__(self, model_name="Qwen/Qwen2-0.5B", cache_dir="model_cache", quantize=True,
                 precompute_common=True, use_half_precision=True):
        """Initialize the lightweight LLM model and tokenizer with optimizations.

        Args:
            model_name: Hugging Face model identifier to load.
            cache_dir: Local directory for model/tokenizer downloads.
            quantize: Apply 8-bit dynamic quantization when running on CPU.
            precompute_common: Warm the prediction cache for common phrases
                in background threads.
            use_half_precision: Load weights in float16 on CUDA devices.
        """
        self.cache_dir = cache_dir
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Configure max context length for Qwen2 models (they support 32K tokens)
        self.max_context_length = 32768 if "qwen2" in model_name.lower() else 1024
        # Create cache directory if it doesn't exist
        os.makedirs(cache_dir, exist_ok=True)
        # Load the tokenizer with padding configuration to fix the warning
        print(f"Loading tokenizer for {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            padding_side='left',     # Consistent padding for faster processing
            use_fast=True,           # Use the fast tokenizer implementation
            trust_remote_code=True   # Required for Qwen models
        )
        # Ensure padding token is set correctly (many causal LMs ship without one)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Load the model with optimizations
        print(f"Loading model {model_name} on {self.device}...")
        start_time = time.time()
        flash_attn_available = False
        try:
            import flash_attn  # noqa: F401 -- presence check only
            flash_attn_available = True
            print("Flash attention available!")
        except ImportError:
            pass
        # Select appropriate attention implementation
        attn_implementation = None
        if self.device == "cuda" and flash_attn_available:
            attn_implementation = "flash_attention_2"
        # Add optimization flags
        model_kwargs = {
            "cache_dir": cache_dir,
            "low_cpu_mem_usage": True,
            "torch_dtype": torch.float16 if use_half_precision and self.device == "cuda" else None,
            "trust_remote_code": True,
            "device_map": "auto"  # This will handle device placement automatically
        }
        # Only add attn_implementation if it's set
        if attn_implementation:
            model_kwargs["attn_implementation"] = attn_implementation
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
        # Apply quantization to reduce memory usage if requested
        if quantize and self.device == "cpu":
            print("Applying 8-bit quantization to reduce memory usage...")
            self.model = torch.quantization.quantize_dynamic(
                self.model, {torch.nn.Linear}, dtype=torch.qint8
            )
        # FIX: removed the explicit `self.model.to(self.device)` call here.
        # device_map="auto" already dispatched the weights, and calling .to()
        # on a dispatched (or dynamically-quantized) model is redundant and
        # can raise with newer accelerate versions.
        # Enable evaluation mode (disables dropout etc.)
        self.model.eval()
        print(f"Model loaded in {time.time() - start_time:.2f} seconds")
        print(f"Model size: {self._get_model_size_mb():.1f} MB")
        # Improved multi-level cache for predictions
        # Level 1: Exact match cache -- maps "text|top_k" -> (predictions, time)
        self.prediction_cache = {}
        # Level 2: Prefix cache -- stores known good context prefixes
        self.context_cache = {}
        # Max cache sizes to prevent memory issues
        self.max_cache_size = 200
        self.max_context_cache_size = 50
        # Precompute common starting predictions
        if precompute_common:
            self._precompute_common_predictions()
        # Start a background thread to periodically clear old cache entries
        self._start_cache_maintenance()
        # Pre-allocate buffers for inference to reduce memory allocations
        self.inference_buffer = None
        if self.device == "cuda":
            # Pre-warm the GPU to avoid cold start latency
            self._warm_up_model()

    def _get_model_size_mb(self):
        """Return the approximate in-memory size of the model weights in MB."""
        model_size = sum(p.numel() * p.element_size() for p in self.model.parameters())
        return model_size / (1024 * 1024)

    def _precompute_common_predictions(self):
        """Warm the prediction cache for common starting phrases in the background.

        Each phrase is computed on its own daemon thread so initialization is
        not blocked and the threads cannot delay interpreter shutdown.
        """
        # FIX: the empty phrase "" was removed -- predict_next() short-circuits
        # on empty input without caching, so precomputing it was wasted work.
        common_phrases = [
            "I", "The", "A", "What", "How", "When", "My name is",
            "I am", "Can you", "Please", "Thank you", "Hello", "Hi"
        ]
        print("Precomputing predictions for common phrases...")
        for phrase in common_phrases:
            # Compute in the background to not block initialization
            threading.Thread(
                target=self.predict_next,
                args=(phrase,),
                kwargs={"preload_only": True},
                daemon=True,  # consistent with the cache-maintenance thread
            ).start()

    def _warm_up_model(self):
        """Run one dummy forward pass so CUDA kernels are compiled before first use."""
        print("Warming up model...")
        dummy_input = self.tokenizer("Hello world", return_tensors="pt").to(self.device)
        with torch.no_grad():
            _ = self.model(**dummy_input)

    def _start_cache_maintenance(self):
        """Start a daemon thread that periodically prunes oversized caches."""
        def cache_maintenance():
            while True:
                time.sleep(60)  # Run every minute
                try:
                    self._clean_cache()
                except Exception as e:
                    print(f"Error in cache maintenance: {e}")
        # Daemon thread: never blocks interpreter exit
        thread = threading.Thread(target=cache_maintenance, daemon=True)
        thread.start()

    def _clean_cache(self):
        """Evict the oldest half of each cache once it exceeds its size limit.

        Relies on dict insertion order (Python 3.7+) so the first keys are the
        oldest entries.
        """
        if len(self.prediction_cache) > self.max_cache_size:
            keys = list(self.prediction_cache.keys())
            for key in keys[:len(keys) // 2]:  # Remove the older half
                del self.prediction_cache[key]
        if len(self.context_cache) > self.max_context_cache_size:
            keys = list(self.context_cache.keys())
            for key in keys[:len(keys) // 2]:
                del self.context_cache[key]

    def _get_cached_context(self, text):
        """Return the longest cached context that prefixes *text* (case-insensitive).

        Returns None when no cached context matches.
        """
        matching_contexts = []
        for context in self.context_cache:
            if text.lower().startswith(context.lower()):
                matching_contexts.append(context)
        if not matching_contexts:
            return None
        # Return the longest matching context
        return max(matching_contexts, key=len)

    def predict_next(self, text, top_k=3, max_length=None, preload_only=False):
        """Predict the most likely next words given the input text.

        Args:
            text: Input text to predict from.
            top_k: Number of suggestions to return.
            max_length: Maximum number of words to consider from input.
            preload_only: If True, only compute and cache the result without
                returning it (used by the precompute threads).

        Returns:
            Tuple of (predictions, elapsed_seconds) where predictions is a
            list of (word, probability) pairs; None when preload_only is True.
        """
        # Handle empty input with fixed predictions to avoid unnecessary computation
        if not text.strip():
            default_predictions = [("I", 0.25), ("The", 0.2), ("Hello", 0.15)]
            return (default_predictions, 0) if not preload_only else None
        # Check exact cache match first
        cache_key = f"{text.lower()}|{top_k}"
        if cache_key in self.prediction_cache:
            cached_result = self.prediction_cache[cache_key]
            return cached_result if not preload_only else None
        # Set max length to ensure we don't process too much text
        if max_length is None:
            # Only process the last few words for efficiency
            max_length = min(10, len(text.split()))
        start_time = time.time()
        try:
            # Get the context (last few words)
            tokens = text.strip().split()
            context = " ".join(tokens[-max_length:])
            # For Qwen2 models, we can leverage more context if available
            if len(tokens) > max_length and hasattr(self, 'max_context_length') and self.max_context_length > 1024:
                # Take advantage of longer context while keeping efficiency
                extended_length = min(len(tokens), 50)  # Use up to 50 words
                if extended_length > max_length:
                    extended_context = " ".join(tokens[-extended_length:])
                    # Check if the extended context is reasonable in size
                    tokenized = self.tokenizer(extended_context, return_tensors="pt")
                    if tokenized["input_ids"].shape[1] <= 512:  # Still keep it reasonable
                        context = extended_context
            # Use the tokenizer with explicit attention mask to avoid warnings
            inputs = self.tokenizer(context, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # Generate predictions more efficiently
            with torch.no_grad():
                # Use inference mode for additional speedup
                with torch.inference_mode():
                    # Get model's probability distribution over next tokens
                    outputs = self.model(**inputs)
                    logits = outputs.logits[0, -1, :]
                    # Convert to probabilities efficiently
                    probs = torch.nn.functional.softmax(logits, dim=0)
                    # Request more tokens than needed to filter out special tokens
                    multiplier = 2 if top_k <= 5 else 1.5  # Adaptive multiplier
                    candidates_k = min(int(top_k * multiplier), len(probs) - 1)
                    top_tokens = torch.topk(probs, candidates_k)
                    # Process predictions efficiently
                    predictions = []
                    for i, token_id in enumerate(top_tokens.indices):
                        # Skip special tokens and single character tokens (except common ones)
                        token = self.tokenizer.decode(token_id).strip()
                        if (token and
                                not token.startswith("<") and
                                not token.endswith(">") and
                                (len(token) > 1 or token in "aIiAoO")):
                            predictions.append((token, float(top_tokens.values[i])))
                        # Stop when we have enough predictions
                        if len(predictions) >= top_k:
                            break
                    # Only generate more tokens if we don't have enough predictions
                    if len(predictions) < top_k and len(context) > 1:
                        # FIX: Use only max_new_tokens parameter to avoid the warning about max_length
                        gen_outputs = self.model.generate(
                            inputs["input_ids"],
                            max_new_tokens=1,  # Just one more token
                            do_sample=True,
                            top_k=30,
                            top_p=0.95,
                            num_return_sequences=min(top_k - len(predictions), 2),
                            pad_token_id=self.tokenizer.pad_token_id,
                            attention_mask=inputs["attention_mask"]
                        )
                        for output in gen_outputs:
                            # Extract only the newly generated token
                            new_token = self.tokenizer.decode(
                                output[inputs["input_ids"].shape[1]:].squeeze()
                            ).strip()
                            if new_token and not new_token.startswith("<") and not new_token.endswith(">"):
                                # FIX: dedup on the token text, not the (token, prob)
                                # tuple -- the old check let tokens already present
                                # with a different probability through as duplicates.
                                if not any(new_token == p[0] for p in predictions):
                                    # Lower probability to differentiate from top predictions
                                    predictions.append((new_token, 0.5))
            # Ensure we have the requested number of predictions
            predictions = predictions[:top_k]
            # Normalize probabilities for consistency
            if predictions:
                total = sum(prob for _, prob in predictions)
                if total > 0:
                    predictions = [(word, prob / total) for word, prob in predictions]
            # Pad with common English words if the model produced too few usable tokens.
            # FIX: the old loop broke out of the while whenever predictions was
            # still short -- even right after appending a word -- so at most ONE
            # default word was ever added. The for/else now stops only when no
            # new default word could be added.
            default_words = ["the", "and", "to", "a", "of", "is", "in", "for", "that"]
            while len(predictions) < top_k:
                for word in default_words:
                    if not any(word == p[0].lower() for p in predictions):
                        predictions.append((word, 0.1))
                        break
                else:
                    # Every default word is already present; nothing more to add
                    break
            prediction_time = time.time() - start_time
            # Cache the result
            result = (predictions, prediction_time)
            self.prediction_cache[cache_key] = result
            # Also cache a simplified form of the context
            simple_context = " ".join(context.strip().lower().split()[:5])
            if simple_context and len(simple_context) > 2:
                self.context_cache[simple_context] = True
            return result if not preload_only else None
        except Exception as e:
            print(f"Error in prediction: {e}")
            # Provide fallback predictions
            return ([("the", 0.3), ("a", 0.2), ("is", 0.1)], 0) if not preload_only else None

    def batch_predict(self, texts, top_k=3):
        """Predict next words for multiple inputs in one batched forward pass.

        Args:
            texts: List of input strings.
            top_k: Number of suggestions per input.

        Returns:
            List of (predictions, elapsed_seconds) tuples, one per input, in
            the same order as *texts* (batched entries report 0 elapsed time).
        """
        # Check cache first for all inputs
        cached_results = []
        texts_to_process = []
        for text in texts:
            cache_key = f"{text.lower()}|{top_k}"
            if cache_key in self.prediction_cache:
                cached_results.append((True, self.prediction_cache[cache_key]))
                texts_to_process.append(None)  # placeholder keeps indices aligned
            else:
                cached_results.append((False, None))
                texts_to_process.append(text)
        # Process all non-cached texts in a batch
        if any(text is not None for text in texts_to_process):
            valid_texts = [t for t in texts_to_process if t is not None]
            # Tokenize all inputs at once
            batch_inputs = self.tokenizer(valid_texts, return_tensors="pt", padding=True)
            batch_inputs = {k: v.to(self.device) for k, v in batch_inputs.items()}
            with torch.no_grad():
                with torch.inference_mode():
                    # Get model outputs for the batch
                    batch_outputs = self.model(**batch_inputs)
                    # Process each output individually
                    batch_results = []
                    for i, logits in enumerate(batch_outputs.logits):
                        # Get probabilities for the last token
                        next_token_logits = logits[-1, :]
                        next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=0)
                        # Get top tokens (over-fetch to allow filtering)
                        top_tokens = torch.topk(next_token_probs, min(top_k * 2, len(next_token_probs)))
                        # Convert to words, skipping special/single-char tokens
                        predictions = []
                        for j, token_id in enumerate(top_tokens.indices):
                            token = self.tokenizer.decode(token_id).strip()
                            if token and not token.startswith("<") and not token.endswith(">") and len(token) > 1:
                                predictions.append((token, float(top_tokens.values[j])))
                            if len(predictions) >= top_k:
                                break
                        # Normalize probabilities
                        # FIX: guard against a zero total before dividing, for
                        # consistency with predict_next()
                        if predictions:
                            total = sum(prob for _, prob in predictions)
                            if total > 0:
                                predictions = [(word, prob / total) for word, prob in predictions]
                        batch_results.append(predictions)
            # Map batch results back to the original texts
            idx = 0
            for i, (is_cached, _) in enumerate(cached_results):
                if not is_cached:
                    result = batch_results[idx]
                    idx += 1
                    # Cache the new result
                    cache_key = f"{texts[i].lower()}|{top_k}"
                    self.prediction_cache[cache_key] = (result, 0)  # No timing info for batch
                    cached_results[i] = (True, (result, 0))
        # Return all results
        return [result for _, result in cached_results]