22 | 22 | from optillm.cot_decoding import cot_decode |
23 | 23 | from optillm.entropy_decoding import entropy_decode |
24 | 24 | from optillm.thinkdeeper import thinkdeeper_decode |
| 25 | +from optillm.thinkdeeper_mlx import thinkdeeper_decode_mlx |
25 | 26 | from optillm.autothink import autothink_decode |
26 | 27 |
27 | 28 | # Configure logging |
33 | 34 | import mlx.core as mx |
34 | 35 | from mlx_lm import load as mlx_load, generate as mlx_generate |
35 | 36 | from mlx_lm.tokenizer_utils import TokenizerWrapper |
| 37 | + from mlx_lm.sample_utils import make_sampler |
36 | 38 | MLX_AVAILABLE = True |
37 | 39 | logger.info("MLX framework available") |
38 | 40 | except ImportError: |
@@ -349,85 +351,46 @@ def generate( |
349 | 351 | return responses, token_counts, logprobs_results |
350 | 352 |
351 | 353 | def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str: |
352 | | - """Robust MLX generation with multiple parameter combinations""" |
353 | | - |
354 | | - # Try different parameter combinations based on MLX-LM version |
355 | | - parameter_combinations = [ |
356 | | - # Version 1: Current style with positional args and temp |
357 | | - { |
358 | | - "style": "positional_temp", |
359 | | - "args": (self.model, self.tokenizer, prompt), |
360 | | - "kwargs": { |
361 | | - "max_tokens": max_tokens, |
362 | | - "temp": temperature, |
363 | | - "top_p": top_p, |
364 | | - "repetition_penalty": repetition_penalty, |
365 | | - "verbose": False |
366 | | - } |
367 | | - }, |
368 | | - # Version 2: All keyword arguments with temp |
369 | | - { |
370 | | - "style": "keyword_temp", |
371 | | - "args": (), |
372 | | - "kwargs": { |
373 | | - "model": self.model, |
374 | | - "tokenizer": self.tokenizer, |
375 | | - "prompt": prompt, |
376 | | - "max_tokens": max_tokens, |
377 | | - "temp": temperature, |
378 | | - "top_p": top_p, |
379 | | - "repetition_penalty": repetition_penalty, |
380 | | - "verbose": False |
381 | | - } |
382 | | - }, |
383 | | - # Version 3: Using temperature instead of temp |
384 | | - { |
385 | | - "style": "positional_temperature", |
386 | | - "args": (self.model, self.tokenizer, prompt), |
387 | | - "kwargs": { |
388 | | - "max_tokens": max_tokens, |
389 | | - "temperature": temperature, |
390 | | - "top_p": top_p, |
391 | | - "repetition_penalty": repetition_penalty, |
392 | | - "verbose": False |
393 | | - } |
394 | | - }, |
395 | | - # Version 4: Minimal parameters only |
396 | | - { |
397 | | - "style": "minimal", |
398 | | - "args": (self.model, self.tokenizer, prompt), |
399 | | - "kwargs": { |
400 | | - "max_tokens": max_tokens, |
401 | | - "temp": temperature, |
402 | | - "verbose": False |
403 | | - } |
404 | | - }, |
405 | | - # Version 5: Just essential parameters |
406 | | - { |
407 | | - "style": "essential", |
408 | | - "args": (self.model, self.tokenizer, prompt), |
409 | | - "kwargs": { |
410 | | - "max_tokens": max_tokens |
411 | | - } |
412 | | - } |
413 | | - ] |
| 354 | + """Robust MLX generation using sampler approach""" |
414 | 355 |
415 | | - last_error = None |
416 | | - |
417 | | - for combo in parameter_combinations: |
| 356 | + try: |
| 357 | + # Create sampler with generation parameters |
| 358 | + sampler = make_sampler( |
| 359 | + temp=temperature, |
| 360 | + top_p=top_p, |
| 361 | + min_p=0.0, # Default min_p |
| 362 | + min_tokens_to_keep=1 # Default min_tokens_to_keep |
| 363 | + ) |
| 364 | + |
| 365 | + # Generate using the sampler |
| 366 | + response = mlx_generate( |
| 367 | + self.model, |
| 368 | + self.tokenizer, |
| 369 | + prompt, |
| 370 | + max_tokens=max_tokens, |
| 371 | + sampler=sampler, |
| 372 | + verbose=False |
| 373 | + ) |
| 374 | + |
| 375 | + return response |
| 376 | + |
| 377 | + except Exception as e: |
| 378 | + logger.error(f"MLX generation with sampler failed: {str(e)}") |
| 379 | + |
| 380 | + # Fallback: Try minimal parameters without sampler |
418 | 381 | try: |
419 | | - logger.debug(f"Trying MLX generation with style: {combo['style']}") |
420 | | - response = mlx_generate(*combo["args"], **combo["kwargs"]) |
421 | | - logger.debug(f"Successfully generated with style: {combo['style']}") |
| 382 | + logger.debug("Attempting MLX generation without sampler") |
| 383 | + response = mlx_generate( |
| 384 | + self.model, |
| 385 | + self.tokenizer, |
| 386 | + prompt, |
| 387 | + max_tokens=max_tokens, |
| 388 | + verbose=False |
| 389 | + ) |
422 | 390 | return response |
423 | | - |
424 | | - except Exception as e: |
425 | | - last_error = e |
426 | | - logger.debug(f"Failed with style {combo['style']}: {str(e)}") |
427 | | - continue |
428 | | - |
429 | | - # If all combinations failed, raise the last error |
430 | | - raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}") |
| 391 | + except Exception as fallback_e: |
| 392 | + logger.error(f"MLX fallback generation also failed: {str(fallback_e)}") |
| 393 | + raise |
431 | 394 |
432 | 395 | def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str: |
433 | 396 | """Format the prompt according to model's chat template""" |
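
For context, the sampler-based call path that the rewritten _robust_mlx_generate relies on can be exercised on its own roughly as follows. This is a minimal sketch, assuming a recent mlx-lm where generate() accepts a sampler keyword; the model id is a placeholder, not something this change prescribes.

from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

# Placeholder model id; any MLX-converted model is loaded the same way.
model, tokenizer = load("mlx-community/example-model-4bit")

# Sampling knobs (temperature, top_p, ...) are bundled into a sampler object
# rather than being passed to generate() directly.
sampler = make_sampler(temp=0.7, top_p=0.9, min_p=0.0, min_tokens_to_keep=1)

text = generate(model, tokenizer, prompt="Hello", max_tokens=64,
                sampler=sampler, verbose=False)
print(text)

This mirrors the primary path above; the except branch falls back to a minimal generate() call without a sampler.
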
@@ -1691,37 +1654,47 @@ def create( |
1691 | 1654 | if decoding: |
1692 | 1655 | logger.info(f"Using specialized decoding approach: {decoding}") |
1693 | 1656 |
1694 | | - # Ensure model is in eval mode and on correct device |
1695 | | - pipeline.current_model.eval() |
1696 | | - device = pipeline.current_model.device |
| 1657 | + # Check if this decoding approach is supported for MLX |
| 1658 | + mlx_unsupported_decodings = ["cot_decoding", "entropy_decoding", "autothink"] |
| 1659 | + if isinstance(pipeline, MLXInferencePipeline) and decoding in mlx_unsupported_decodings: |
| 1660 | + logger.warning(f"{decoding} is not supported for MLX models. Falling back to standard generation.") |
| 1661 | + decoding = None |
| 1662 | + |
| 1663 | + if decoding: |
| 1664 | + # For PyTorch pipelines, ensure model is in eval mode and get device |
| 1665 | + # MLX pipelines handle this differently |
| 1666 | + if not isinstance(pipeline, MLXInferencePipeline): |
| 1667 | + pipeline.current_model.eval() |
| 1668 | + device = pipeline.current_model.device |
| 1669 | + else: |
| 1670 | + device = None # MLX doesn't use torch devices |
1697 | 1671 |
1698 | 1672 | if decoding == "cot_decoding": |
1699 | 1673 | # Use directly available parameters for CoT |
1700 | | - cot_params = { |
1701 | | - "k": k, |
1702 | | - "num_beams": num_beams, |
1703 | | - "max_new_tokens": max_tokens if max_tokens is not None else 512, |
1704 | | - "temperature": temperature, |
1705 | | - "top_p": top_p, |
1706 | | - "repetition_penalty": 1.0, |
1707 | | - "length_penalty": length_penalty, |
1708 | | - "no_repeat_ngram_size": no_repeat_ngram_size, |
1709 | | - "early_stopping": early_stopping, |
1710 | | - "aggregate_paths": aggregate_paths, |
1711 | | - } |
1712 | | - |
1713 | | - result, confidence = cot_decode( |
1714 | | - pipeline.current_model, |
1715 | | - pipeline.tokenizer, |
1716 | | - messages, |
1717 | | - **cot_params |
1718 | | - ) |
1719 | | - responses = [result] |
1720 | | - logprobs_results = [{"confidence_score": confidence} if confidence is not None else None] |
1721 | | - completion_tokens = len(pipeline.tokenizer.encode(result)) |
| 1674 | + cot_params = { |
| 1675 | + "k": k, |
| 1676 | + "num_beams": num_beams, |
| 1677 | + "max_new_tokens": max_tokens if max_tokens is not None else 512, |
| 1678 | + "temperature": temperature, |
| 1679 | + "top_p": top_p, |
| 1680 | + "repetition_penalty": 1.0, |
| 1681 | + "length_penalty": length_penalty, |
| 1682 | + "no_repeat_ngram_size": no_repeat_ngram_size, |
| 1683 | + "early_stopping": early_stopping, |
| 1684 | + "aggregate_paths": aggregate_paths, |
| 1685 | + } |
| 1686 | + |
| 1687 | + result, confidence = cot_decode( |
| 1688 | + pipeline.current_model, |
| 1689 | + pipeline.tokenizer, |
| 1690 | + messages, |
| 1691 | + **cot_params |
| 1692 | + ) |
| 1693 | + responses = [result] |
| 1694 | + logprobs_results = [{"confidence_score": confidence} if confidence is not None else None] |
| 1695 | + completion_tokens = len(pipeline.tokenizer.encode(result)) |
1722 | 1696 |
1723 | 1697 | elif decoding == "entropy_decoding": |
1724 | | - |
1725 | 1698 | # Ensure model is using full precision |
1726 | 1699 | original_dtype = pipeline.current_model.dtype |
1727 | 1700 | pipeline.current_model = pipeline.current_model.to(torch.float32) |
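
The MLX support check added in this hunk amounts to a small gate in front of the specialized decoders. A simplified sketch of the behaviour, reusing names from the diff (not the exact implementation):

# Simplified illustration; MLXInferencePipeline is the pipeline class from this module.
MLX_UNSUPPORTED_DECODINGS = {"cot_decoding", "entropy_decoding", "autothink"}

def effective_decoding(pipeline, decoding):
    # These decoders need PyTorch-level access to the model, so MLX pipelines
    # fall back to standard generation instead of raising an error.
    if isinstance(pipeline, MLXInferencePipeline) and decoding in MLX_UNSUPPORTED_DECODINGS:
        return None
    return decoding
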
@@ -1778,43 +1751,66 @@ def create( |
1778 | 1751 | } |
1779 | 1752 | thinkdeeper_config.update(custom_config) |
1780 | 1753 |
|
1781 | | - result = thinkdeeper_decode( |
1782 | | - pipeline.current_model, |
1783 | | - pipeline.tokenizer, |
1784 | | - messages, |
1785 | | - thinkdeeper_config |
| 1754 | + # Check if we're using MLX pipeline |
| 1755 | + if isinstance(pipeline, MLXInferencePipeline): |
| 1756 | + logger.info("Using MLX ThinkDeeper implementation") |
| 1757 | + |
| 1758 | + # Ensure we have enough tokens for thinking + response |
| 1759 | + user_max_tokens = max_tokens if max_tokens is not None else 512 |
| 1760 | + total_tokens_needed = max_thinking_tokens + 512 # thinking + response buffer |
| 1761 | + adjusted_max_tokens = max(user_max_tokens, total_tokens_needed) |
| 1762 | + |
| 1763 | + # Add max_tokens to thinkdeeper config |
| 1764 | + thinkdeeper_config_with_tokens = thinkdeeper_config.copy() |
| 1765 | + thinkdeeper_config_with_tokens["max_tokens"] = adjusted_max_tokens |
| 1766 | + |
| 1767 | + logger.debug(f"ThinkDeeper tokens: user={user_max_tokens}, thinking={max_thinking_tokens}, adjusted={adjusted_max_tokens}") |
| 1768 | + |
| 1769 | + result = thinkdeeper_decode_mlx( |
| 1770 | + pipeline.model, |
| 1771 | + pipeline.tokenizer, |
| 1772 | + messages, |
| 1773 | + thinkdeeper_config_with_tokens |
| 1774 | + ) |
| 1775 | + else: |
| 1776 | + logger.info("Using PyTorch ThinkDeeper implementation") |
| 1777 | + result = thinkdeeper_decode( |
| 1778 | + pipeline.current_model, |
| 1779 | + pipeline.tokenizer, |
| 1780 | + messages, |
| 1781 | + thinkdeeper_config |
1786 | 1782 | ) |
1787 | 1783 | responses = [result] |
1788 | 1784 | logprobs_results = [None] |
1789 | 1785 | completion_tokens = len(pipeline.tokenizer.encode(result)) |
1790 | 1786 | elif decoding == "autothink": |
1791 | 1787 | # Get steering dataset configuration |
1792 | | - steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors") |
1793 | | - target_layer = kwargs.get("target_layer", 19) |
1794 | | - |
1795 | | - # Prepare AutoThink configuration |
1796 | | - autothink_config = { |
1797 | | - "steering_dataset": steering_dataset, |
1798 | | - "target_layer": target_layer, |
1799 | | - "pattern_strengths": kwargs.get("pattern_strengths", { |
1800 | | - "depth_and_thoroughness": 2.5, |
1801 | | - "numerical_accuracy": 2.0, |
1802 | | - "self_correction": 3.0, |
1803 | | - "exploration": 2.0, |
1804 | | - "organization": 1.5 |
1805 | | - }) |
1806 | | - } |
1807 | | - |
1808 | | - # Process with AutoThink |
1809 | | - result = autothink_decode( |
1810 | | - pipeline.current_model, |
1811 | | - pipeline.tokenizer, |
1812 | | - messages, |
1813 | | - autothink_config |
1814 | | - ) |
1815 | | - responses = [result] |
1816 | | - logprobs_results = [None] |
1817 | | - completion_tokens = len(pipeline.tokenizer.encode(result)) |
| 1788 | + steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors") |
| 1789 | + target_layer = kwargs.get("target_layer", 19) |
| 1790 | + |
| 1791 | + # Prepare AutoThink configuration |
| 1792 | + autothink_config = { |
| 1793 | + "steering_dataset": steering_dataset, |
| 1794 | + "target_layer": target_layer, |
| 1795 | + "pattern_strengths": kwargs.get("pattern_strengths", { |
| 1796 | + "depth_and_thoroughness": 2.5, |
| 1797 | + "numerical_accuracy": 2.0, |
| 1798 | + "self_correction": 3.0, |
| 1799 | + "exploration": 2.0, |
| 1800 | + "organization": 1.5 |
| 1801 | + }) |
| 1802 | + } |
| 1803 | + |
| 1804 | + # Process with AutoThink |
| 1805 | + result = autothink_decode( |
| 1806 | + pipeline.current_model, |
| 1807 | + pipeline.tokenizer, |
| 1808 | + messages, |
| 1809 | + autothink_config |
| 1810 | + ) |
| 1811 | + responses = [result] |
| 1812 | + logprobs_results = [None] |
| 1813 | + completion_tokens = len(pipeline.tokenizer.encode(result)) |
1818 | 1814 | else: |
1819 | 1815 | raise ValueError(f"Unknown specialized decoding approach: {decoding}") |
1820 | 1816 |
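
The token budgeting in the MLX ThinkDeeper branch is worth spelling out: the caller's max_tokens is only honoured when it already covers the thinking budget plus a 512-token response buffer; otherwise it is raised to that total. A quick worked example with illustrative values:

# Illustrative numbers, not defaults from the code.
max_tokens = 1024            # what the caller requested
max_thinking_tokens = 2048   # thinking budget used by ThinkDeeper

total_tokens_needed = max_thinking_tokens + 512     # thinking + response buffer
adjusted_max_tokens = max(max_tokens, total_tokens_needed)
assert adjusted_max_tokens == 2560                  # 2048 + 512 exceeds 1024, so the budget grows
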