Commit c50394e

1 parent 34b57c9 commit c50394e

3 files changed: 253 additions, 10 deletions

optillm/inference.py

Lines changed: 248 additions & 9 deletions
@@ -16,6 +16,8 @@
 import time
 import threading
 import traceback
+import platform
+import sys
 
 from optillm.cot_decoding import cot_decode
 from optillm.entropy_decoding import entropy_decode
@@ -26,6 +28,17 @@
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# MLX Support for Apple Silicon
+try:
+    import mlx.core as mx
+    from mlx_lm import load as mlx_load, generate as mlx_generate
+    from mlx_lm.tokenizer_utils import TokenizerWrapper
+    MLX_AVAILABLE = True
+    logger.info("MLX framework available")
+except ImportError:
+    MLX_AVAILABLE = False
+    logger.debug("MLX framework not available - falling back to PyTorch")
+
 @dataclass
 class ModelConfig:
     base_model_id: str
@@ -162,6 +175,223 @@ def calculate_logprobs(
         bytes_per_token=all_bytes
     )
 
+# MLX Support Functions and Classes
+
+def is_apple_silicon() -> bool:
+    """Check if running on Apple Silicon"""
+    return platform.system() == "Darwin" and platform.machine() == "arm64"
+
+def should_use_mlx(model_id: str) -> bool:
+    """Determine if a model should use MLX instead of PyTorch"""
+    if not MLX_AVAILABLE or not is_apple_silicon():
+        return False
+
+    # Models that should use MLX
+    mlx_patterns = [
+        "mlx-community/",
+        "mlx-"
+    ]
+
+    # Known problematic models that should prefer MLX on Apple Silicon
+    problematic_models = [
+        "Qwen/Qwen3-",
+        "google/gemma-3-",
+        "google/gemma3-"
+    ]
+
+    model_lower = model_id.lower()
+
+    # Direct MLX model detection
+    for pattern in mlx_patterns:
+        if pattern.lower() in model_lower:
+            return True
+
+    # Problematic model detection
+    for pattern in problematic_models:
+        if pattern.lower() in model_lower:
+            logger.warning(f"Model {model_id} detected as potentially problematic with MPS backend")
+            suggested_mlx = suggest_mlx_alternative(model_id)
+            logger.warning(f"Consider using MLX model: {suggested_mlx}")
+            # Don't auto-switch, but recommend
+            return False
+
+    return False
+
+def suggest_mlx_alternative(model_id: str) -> str:
+    """Suggest MLX alternative for a given model"""
+    mlx_alternatives = {
+        # Qwen3 models
+        "Qwen/Qwen3-0.6B": "mlx-community/Qwen3-0.6B-4bit",
+        "Qwen/Qwen3-1.7B": "mlx-community/Qwen3-1.7B-4bit",
+        "Qwen/Qwen3-4B": "mlx-community/Qwen3-4B-4bit",
+        "Qwen/Qwen3-8B": "mlx-community/Qwen3-8B-4bit",
+        "Qwen/Qwen3-14B": "mlx-community/Qwen3-14B-4bit",
+        "Qwen/Qwen3-32B": "mlx-community/Qwen3-32B-4bit",
+
+        # Gemma 3 models
+        "google/gemma-3-1b-it": "mlx-community/gemma-3-1b-it-4bit",
+        "google/gemma-3-4b-it": "mlx-community/gemma-3-4b-it-4bit",
+        "google/gemma-3-12b-it": "mlx-community/gemma-3-12b-it-4bit",
+        "google/gemma-3-27b-it": "mlx-community/gemma-3-27b-it-4bit",
+    }
+
+    return mlx_alternatives.get(model_id, f"mlx-community/{model_id.split('/')[-1]}-4bit")
+
+@dataclass
+class MLXModelConfig:
+    """Configuration for MLX models"""
+    model_id: str
+    max_new_tokens: int = 4096
+    temperature: float = 0.7
+    top_p: float = 0.9
+    repetition_penalty: float = 1.0
+    enable_prompt_caching: bool = True
+
+class MLXInferencePipeline:
+    """MLX-based inference pipeline that mirrors PyTorch pipeline interface"""
+
+    def __init__(self, model_config: MLXModelConfig, cache_manager):
+        self.model_config = model_config
+        self.cache_manager = cache_manager
+        self.last_used = time.time()
+
+        if not MLX_AVAILABLE:
+            raise RuntimeError("MLX framework not available. Install with: pip install mlx-lm")
+
+        if not is_apple_silicon():
+            raise RuntimeError("MLX framework is only supported on Apple Silicon")
+
+        try:
+            logger.info(f"Loading MLX model: {model_config.model_id}")
+            self.model, self.tokenizer = self._load_mlx_model(model_config.model_id)
+            logger.info("MLX model loaded successfully")
+        except Exception as e:
+            logger.error(f"Failed to load MLX model: {str(e)}")
+            raise
+
+    def _load_mlx_model(self, model_id: str):
+        """Load MLX model and tokenizer with caching"""
+        def _load_model():
+            start_time = time.time()
+            logger.info(f"Loading MLX model: {model_id}")
+
+            try:
+                model, tokenizer = mlx_load(model_id)
+                load_time = time.time() - start_time
+                logger.info(f"MLX model loaded in {load_time:.2f}s")
+                return model, tokenizer
+            except Exception as e:
+                logger.error(f"Error loading MLX model {model_id}: {str(e)}")
+                raise
+
+        return self.cache_manager.get_or_load_model(f"mlx_{model_id}", _load_model)
+
+    def generate(
+        self,
+        prompt: str,
+        generation_params: Optional[Dict[str, Any]] = None
+    ) -> Tuple[List[str], List[int], List[Optional[Dict]]]:
+        """Generate text using MLX"""
+        start_time = time.time()
+
+        if generation_params is None:
+            generation_params = {}
+
+        # Extract parameters with defaults
+        max_tokens = generation_params.get("max_new_tokens", self.model_config.max_new_tokens)
+        temperature = generation_params.get("temperature", self.model_config.temperature)
+        top_p = generation_params.get("top_p", self.model_config.top_p)
+        repetition_penalty = generation_params.get("repetition_penalty", self.model_config.repetition_penalty)
+        num_return_sequences = generation_params.get("num_return_sequences", 1)
+
+        # Handle seed
+        if generation_params.get("seed") is not None:
+            mx.random.seed(generation_params["seed"])
+
+        responses = []
+        token_counts = []
+        logprobs_results = []
+
+        # Generate multiple sequences if requested
+        for _ in range(num_return_sequences):
+            try:
+                logger.debug(f"Generating with MLX: max_tokens={max_tokens}, temp={temperature}")
+
+                # Use MLX generate function
+                response = mlx_generate(
+                    model=self.model,
+                    tokenizer=self.tokenizer,
+                    prompt=prompt,
+                    max_tokens=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    repetition_penalty=repetition_penalty,
+                    verbose=False
+                )
+
+                responses.append(response)
+
+                # Count tokens (approximate)
+                token_count = len(self.tokenizer.encode(response))
+                token_counts.append(token_count)
+
+                # MLX doesn't provide logprobs by default
+                logprobs_results.append(None)
+
+            except Exception as e:
+                logger.error(f"Error during MLX generation: {str(e)}")
+                responses.append("")
+                token_counts.append(0)
+                logprobs_results.append(None)
+
+        generation_time = time.time() - start_time
+        logger.info(f"MLX generation completed in {generation_time:.2f}s")
+
+        return responses, token_counts, logprobs_results
+
+    def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
+        """Format the prompt according to model's chat template"""
+        if hasattr(self.tokenizer, 'apply_chat_template'):
+            messages = [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt}
+            ]
+            try:
+                return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+            except Exception as e:
+                logger.warning(f"Failed to apply chat template: {e}, using fallback")
+                return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"
+        else:
+            return f"System: {system_prompt}\n\nUser: {user_prompt}\n\nAssistant:"
+
+class MLXManager:
+    """Manager for MLX models and operations"""
+
+    def __init__(self, cache_manager):
+        self.cache_manager = cache_manager
+        self.available = MLX_AVAILABLE and is_apple_silicon()
+
+        if self.available:
+            logger.info("MLX manager initialized - Apple Silicon detected")
+        else:
+            logger.debug("MLX manager not available - requires Apple Silicon and mlx-lm")
+
+    def create_pipeline(self, model_id: str, **kwargs) -> MLXInferencePipeline:
+        """Create an MLX inference pipeline"""
+        if not self.available:
+            raise RuntimeError("MLX not available on this platform")
+
+        config = MLXModelConfig(
+            model_id=model_id,
+            **kwargs
+        )
+
+        return MLXInferencePipeline(config, self.cache_manager)
+
+    def is_mlx_model(self, model_id: str) -> bool:
+        """Check if model should use MLX"""
+        return should_use_mlx(model_id)
+
 class MemoryEfficientAttention(nn.Module):
     """
     Memory-efficient attention using linear attention mechanism.
@@ -1286,18 +1516,27 @@ def __init__(self):
         self.device_manager = DeviceManager()
         self.model_manager = ModelManager(self.cache_manager, self.device_manager)
         self.lora_manager = LoRAManager(self.cache_manager)
+        self.mlx_manager = MLXManager(self.cache_manager)
         self.chat = self.Chat(self)
         self.models = self.Models()
 
-    def get_pipeline(self, model: str) -> 'InferencePipeline':
-        model_config = parse_model_string(model)
-        return InferencePipeline(
-            model_config,
-            self.cache_manager,
-            self.device_manager,
-            self.model_manager,
-            self.lora_manager
-        )
+    def get_pipeline(self, model: str):
+        """Get inference pipeline - automatically chooses MLX or PyTorch based on model"""
+        # Check if should use MLX
+        if self.mlx_manager.available and should_use_mlx(model):
+            logger.info(f"Using MLX pipeline for model: {model}")
+            return self.mlx_manager.create_pipeline(model)
+        else:
+            # Use existing PyTorch pipeline
+            logger.info(f"Using PyTorch pipeline for model: {model}")
+            model_config = parse_model_string(model)
+            return InferencePipeline(
+                model_config,
+                self.cache_manager,
+                self.device_manager,
+                self.model_manager,
+                self.lora_manager
+            )
 
     class Chat:
         """OpenAI-compatible chat interface"""

requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -28,4 +28,6 @@ cerebras_cloud_sdk
 outlines[transformers]
 sentencepiece
 adaptive-classifier
-mcp
+mcp
+# MLX support for Apple Silicon optimization
+mlx-lm>=0.24.0; platform_machine=="arm64" and sys_platform=="darwin"
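
The mlx-lm requirement carries a PEP 508 environment marker, so pip only installs it on arm64 macOS; other platforms skip it and optillm falls back to PyTorch. A small, hedged check of how that marker evaluates on the current interpreter (assumes the packaging library, which ships with modern pip and setuptools, is importable):

# Sketch: evaluate the same environment marker used in requirements.txt and setup.py.
from packaging.markers import Marker

marker = Marker('platform_machine == "arm64" and sys_platform == "darwin"')
print("mlx-lm would be installed here:", marker.evaluate())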

setup.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@
         "sentencepiece",
         "mcp",
         "adaptive-classifier",
+        # MLX support for Apple Silicon optimization
+        'mlx-lm>=0.24.0; platform_machine=="arm64" and sys_platform=="darwin"',
     ],
     entry_points={
         'console_scripts': [
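
Because the mlx imports in optillm/inference.py are wrapped in try/except ImportError, installs that skip mlx-lm still work; the MLX code paths simply stay disabled. A minimal post-install check (a sketch, assuming optillm and its core dependencies are importable):

# Sketch: confirm whether the MLX backend is active in this environment.
from optillm.inference import MLX_AVAILABLE, is_apple_silicon

print("Apple Silicon:", is_apple_silicon())
print("MLX backend available:", MLX_AVAILABLE)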
