Commit 3e3c4ef

Add robust MLX generation with multiple fallbacks
Introduces a _robust_mlx_generate method that attempts MLX text generation using several parameter combinations to handle different MLX-LM versions. Improves error handling and logging for easier debugging, and ensures token counting is robust to different response types.
1 parent c50394e commit 3e3c4ef
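For context, the core workaround can be sketched in a few lines. The snippet below is an illustrative reduction, not part of the commit: it probes a few call signatures of mlx_lm's generate() until one is accepted, since which keyword a given mlx-lm release takes for sampling temperature ("temp" vs "temperature") has varied; the function and parameter names here are placeholders.

    # Illustrative sketch of the probe-and-fallback idea (not the committed code).
    from mlx_lm import generate as mlx_generate

    def generate_with_fallback(model, tokenizer, prompt, max_tokens=256, temperature=0.7):
        candidate_kwargs = [
            {"max_tokens": max_tokens, "temp": temperature},         # releases accepting "temp"
            {"max_tokens": max_tokens, "temperature": temperature},  # releases accepting "temperature"
            {"max_tokens": max_tokens},                              # last resort: library defaults
        ]
        last_error = None
        for kwargs in candidate_kwargs:
            try:
                return mlx_generate(model, tokenizer, prompt, **kwargs)
            except Exception as exc:  # e.g. TypeError for an unexpected keyword
                last_error = exc
        raise RuntimeError(f"All MLX generation attempts failed: {last_error}")

The committed _robust_mlx_generate follows the same pattern, but enumerates five richer parameter combinations and logs each attempt.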

File tree

1 file changed: +91 additions, -12 deletions

optillm/inference.py

Lines changed: 91 additions & 12 deletions
@@ -317,29 +317,27 @@ def generate(
             try:
                 logger.debug(f"Generating with MLX: max_tokens={max_tokens}, temp={temperature}")
 
-                # Use MLX generate function
-                response = mlx_generate(
-                    model=self.model,
-                    tokenizer=self.tokenizer,
-                    prompt=prompt,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    repetition_penalty=repetition_penalty,
-                    verbose=False
+                # Use robust MLX generation with multiple fallback approaches
+                response = self._robust_mlx_generate(
+                    prompt, max_tokens, temperature, top_p, repetition_penalty
                 )
 
                 responses.append(response)
 
-                # Count tokens (approximate)
-                token_count = len(self.tokenizer.encode(response))
+                # Count tokens (approximate) - check if response is string
+                if isinstance(response, str):
+                    token_count = len(self.tokenizer.encode(response))
+                else:
+                    # Sometimes MLX returns just the new tokens, get the actual text
+                    token_count = len(response) if hasattr(response, '__len__') else 0
                 token_counts.append(token_count)
 
                 # MLX doesn't provide logprobs by default
                 logprobs_results.append(None)
 
             except Exception as e:
                 logger.error(f"Error during MLX generation: {str(e)}")
+                logger.error(f"MLX generation parameters: max_tokens={max_tokens}, temp={temperature}, top_p={top_p}")
                 responses.append("")
                 token_counts.append(0)
                 logprobs_results.append(None)
@@ -349,6 +347,87 @@ def generate(
 
         return responses, token_counts, logprobs_results
 
+    def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str:
+        """Robust MLX generation with multiple parameter combinations"""
+
+        # Try different parameter combinations based on MLX-LM version
+        parameter_combinations = [
+            # Version 1: Current style with positional args and temp
+            {
+                "style": "positional_temp",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 2: All keyword arguments with temp
+            {
+                "style": "keyword_temp",
+                "args": (),
+                "kwargs": {
+                    "model": self.model,
+                    "tokenizer": self.tokenizer,
+                    "prompt": prompt,
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 3: Using temperature instead of temp
+            {
+                "style": "positional_temperature",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 4: Minimal parameters only
+            {
+                "style": "minimal",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "verbose": False
+                }
+            },
+            # Version 5: Just essential parameters
+            {
+                "style": "essential",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens
+                }
+            }
+        ]
+
+        last_error = None
+
+        for combo in parameter_combinations:
+            try:
+                logger.debug(f"Trying MLX generation with style: {combo['style']}")
+                response = mlx_generate(*combo["args"], **combo["kwargs"])
+                logger.debug(f"Successfully generated with style: {combo['style']}")
+                return response
+
+            except Exception as e:
+                last_error = e
+                logger.debug(f"Failed with style {combo['style']}: {str(e)}")
+                continue
+
+        # If all combinations failed, raise the last error
+        raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}")
+
     def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
         """Format the prompt according to model's chat template"""
         if hasattr(self.tokenizer, 'apply_chat_template'):
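Since each fallback attempt is logged at DEBUG level, a failing installation can be diagnosed without modifying code. One hypothetical way to surface those messages, assuming inference.py creates its logger with logging.getLogger(__name__):

    import logging

    # Show the "Trying MLX generation with style: ..." / "Failed with style ..."
    # messages emitted by the fallback loop. The logger name below is an
    # assumption based on the usual logging.getLogger(__name__) pattern.
    logging.basicConfig(level=logging.DEBUG)
    logging.getLogger("optillm.inference").setLevel(logging.DEBUG)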
