@@ -317,29 +317,27 @@ def generate(
             try:
                 logger.debug(f"Generating with MLX: max_tokens={max_tokens}, temp={temperature}")

-                # Use MLX generate function
-                response = mlx_generate(
-                    model=self.model,
-                    tokenizer=self.tokenizer,
-                    prompt=prompt,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    top_p=top_p,
-                    repetition_penalty=repetition_penalty,
-                    verbose=False
+                # Use robust MLX generation with multiple fallback approaches
+                response = self._robust_mlx_generate(
+                    prompt, max_tokens, temperature, top_p, repetition_penalty
                 )

                 responses.append(response)

-                # Count tokens (approximate)
-                token_count = len(self.tokenizer.encode(response))
+                # Count tokens (approximate) - check if response is string
+                if isinstance(response, str):
+                    token_count = len(self.tokenizer.encode(response))
+                else:
+                    # Some MLX versions return a token sequence rather than text;
+                    # fall back to the sequence length as a rough token count
+                    token_count = len(response) if hasattr(response, '__len__') else 0
                 token_counts.append(token_count)

                 # MLX doesn't provide logprobs by default
                 logprobs_results.append(None)

             except Exception as e:
                 logger.error(f"Error during MLX generation: {str(e)}")
+                logger.error(f"MLX generation parameters: max_tokens={max_tokens}, temp={temperature}, top_p={top_p}")
                 responses.append("")
                 token_counts.append(0)
                 logprobs_results.append(None)
@@ -349,6 +347,87 @@ def generate(

         return responses, token_counts, logprobs_results

+    def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str:
+        """Robust MLX generation with multiple parameter combinations"""
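+        # NOTE: mlx-lm has changed the keyword names accepted by its generate()
+        # across releases (e.g. `temp` vs `temperature`), so we probe several
+        # call signatures from most specific to most minimal. `mlx_generate` is
+        # assumed here to be mlx-lm's generate function, imported at module level.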
+
+        # Try different parameter combinations based on MLX-LM version
+        parameter_combinations = [
+            # Version 1: Current style with positional args and temp
+            {
+                "style": "positional_temp",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 2: All keyword arguments with temp
+            {
+                "style": "keyword_temp",
+                "args": (),
+                "kwargs": {
+                    "model": self.model,
+                    "tokenizer": self.tokenizer,
+                    "prompt": prompt,
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 3: Using temperature instead of temp
+            {
+                "style": "positional_temperature",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temperature": temperature,
+                    "top_p": top_p,
+                    "repetition_penalty": repetition_penalty,
+                    "verbose": False
+                }
+            },
+            # Version 4: Minimal parameters only
+            {
+                "style": "minimal",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens,
+                    "temp": temperature,
+                    "verbose": False
+                }
+            },
+            # Version 5: Just essential parameters
+            {
+                "style": "essential",
+                "args": (self.model, self.tokenizer, prompt),
+                "kwargs": {
+                    "max_tokens": max_tokens
+                }
+            }
+        ]
+
+        last_error = None
+
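+        # Probe each signature in order; the first call the installed
+        # mlx-lm version accepts wins.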
+        for combo in parameter_combinations:
+            try:
+                logger.debug(f"Trying MLX generation with style: {combo['style']}")
+                response = mlx_generate(*combo["args"], **combo["kwargs"])
+                logger.debug(f"Successfully generated with style: {combo['style']}")
+                return response
+
+            except Exception as e:
+                last_error = e
+                logger.debug(f"Failed with style {combo['style']}: {str(e)}")
+                continue
+
+        # If all combinations failed, raise the last error
+        raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}")
+
     def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str:
         """Format the prompt according to model's chat template"""
         if hasattr(self.tokenizer, 'apply_chat_template'):