diff --git a/libs/community/langchain_community/chat_models/mlx.py b/libs/community/langchain_community/chat_models/mlx.py
index 64fd45ac..49c800a3 100644
--- a/libs/community/langchain_community/chat_models/mlx.py
+++ b/libs/community/langchain_community/chat_models/mlx.py
@@ -180,6 +180,7 @@ def _stream(
         top_p: float = model_kwargs.get("top_p", 1.0)
         min_p: float = model_kwargs.get("min_p", 0.0)
         min_tokens_to_keep: int = model_kwargs.get("min_tokens_to_keep", 1)
+        top_k: int = model_kwargs.get("top_k", 0)
 
         llm_input = self._to_chat_prompt(messages, tokenize=True, return_tensors="np")
 
@@ -187,7 +188,7 @@
 
         eos_token_id = self.tokenizer.eos_token_id
 
-        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)
+        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep, top_k)
         logits_processors = make_logits_processors(
             None, repetition_penalty, repetition_context_size
         )
diff --git a/libs/community/langchain_community/llms/mlx_pipeline.py b/libs/community/langchain_community/llms/mlx_pipeline.py
index 4ef69d4c..780bb129 100644
--- a/libs/community/langchain_community/llms/mlx_pipeline.py
+++ b/libs/community/langchain_community/llms/mlx_pipeline.py
@@ -72,7 +72,12 @@ class MLXPipeline(LLM):
         for applying repetition penalty, default is None.
         - top_p (float): The cumulative probability threshold for
         top-p filtering, default is 1.0.
-
+        - min_p (float): The minimum probability threshold for
+        min-p sampling, default is 0.0.
+        - min_tokens_to_keep (int): The minimum number of tokens to keep
+        for min-p sampling, default is 1.
+        - top_k (int): The number of highest probability vocabulary tokens
+        to keep for top-k filtering, default is 0.
     """
 
     model_config = ConfigDict(
@@ -166,8 +171,9 @@ def _call(
         top_p: float = pipeline_kwargs.get("top_p", 1.0)
         min_p: float = pipeline_kwargs.get("min_p", 0.0)
         min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)
+        top_k: int = pipeline_kwargs.get("top_k", 0)
 
-        sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
+        sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep, top_k)
         logits_processors = make_logits_processors(
             None, repetition_penalty, repetition_context_size
         )
@@ -214,6 +220,7 @@ def _stream(
         top_p: float = pipeline_kwargs.get("top_p", 1.0)
         min_p: float = pipeline_kwargs.get("min_p", 0.0)
         min_tokens_to_keep: int = pipeline_kwargs.get("min_tokens_to_keep", 1)
+        top_k: int = pipeline_kwargs.get("top_k", 0)
 
         prompt = self.tokenizer.encode(prompt, return_tensors="np")
 
@@ -223,7 +230,7 @@
         detokenizer = self.tokenizer.detokenizer
         detokenizer.reset()
 
-        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep)
+        sampler = make_sampler(temp or 0.0, top_p, min_p, min_tokens_to_keep, top_k)
         logits_processors = make_logits_processors(
             None, repetition_penalty, repetition_context_size
         )
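
For reference, a minimal usage sketch exercising the new `top_k` parameter through `MLXPipeline.from_model_id`. The model id and sampling values below are placeholders chosen for illustration, not part of this patch; per the defaults above, `top_k=0` preserves the previous behavior (no top-k filtering).

```python
# Hypothetical usage sketch, assuming mlx-lm is installed.
# The model id and sampling values are illustrative placeholders.
from langchain_community.llms.mlx_pipeline import MLXPipeline

llm = MLXPipeline.from_model_id(
    model_id="mlx-community/quantized-gemma-2b-it",
    pipeline_kwargs={
        "temp": 0.7,
        "top_p": 0.95,
        "top_k": 40,  # new: restrict sampling to the 40 highest-probability tokens
        "max_tokens": 128,
    },
)

print(llm.invoke("What is top-k sampling?"))
```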