diff --git a/mlx_lm/__init__.py b/mlx_lm/__init__.py
index 7180a422..7fb41942 100644
--- a/mlx_lm/__init__.py
+++ b/mlx_lm/__init__.py
@@ -9,3 +9,12 @@
 from .convert import convert
 from .generate import batch_generate, generate, stream_generate
 from .utils import load
+
+__all__ = [
+    "__version__",
+    "convert",
+    "batch_generate",
+    "generate",
+    "stream_generate",
+    "load",
+]
diff --git a/mlx_lm/tokenizer_utils.py b/mlx_lm/tokenizer_utils.py
index f6bcd98f..c612c4bc 100644
--- a/mlx_lm/tokenizer_utils.py
+++ b/mlx_lm/tokenizer_utils.py
@@ -1,7 +1,7 @@
 import json
 from functools import partial
 from json import JSONDecodeError
-from typing import List
+from typing import Any, Dict, List, Optional
 
 from transformers import AutoTokenizer, PreTrainedTokenizerFast
 
@@ -424,8 +424,11 @@ def _is_bpe_decoder(decoder):
 
 
 def load_tokenizer(
-    model_path, tokenizer_config_extra={}, return_tokenizer=True, eos_token_ids=None
-):
+    model_path,
+    tokenizer_config_extra: Optional[Dict[str, Any]] = None,
+    return_tokenizer=True,
+    eos_token_ids=None,
+) -> TokenizerWrapper:
     """Load a huggingface tokenizer and try to infer the type of streaming
     detokenizer to use.
 
@@ -454,8 +457,9 @@ def load_tokenizer(
         eos_token_ids = [eos_token_ids]
 
     if return_tokenizer:
+        kwargs = tokenizer_config_extra or {}
         return TokenizerWrapper(
-            AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
+            AutoTokenizer.from_pretrained(model_path, **kwargs),
             detokenizer_class,
             eos_token_ids=eos_token_ids,
         )
diff --git a/mlx_lm/utils.py b/mlx_lm/utils.py
index 54d60dab..9021e212 100644
--- a/mlx_lm/utils.py
+++ b/mlx_lm/utils.py
@@ -136,7 +136,7 @@ def load_model(
     model_path: Path,
     lazy: bool = False,
     strict: bool = True,
-    model_config: dict = {},
+    model_config: Optional[Dict[str, Any]] = None,
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> Tuple[nn.Module, dict]:
     """
@@ -163,7 +163,8 @@
         ValueError: If the model class or args class are not found or cannot be
             instantiated.
     """
     config = load_config(model_path)
-    config.update(model_config)
+    if model_config is not None:
+        config.update(model_config)
 
     weight_files = glob.glob(str(model_path / "model*.safetensors"))
@@ -227,12 +228,12 @@ def class_predicate(p, m):
 
 def load(
     path_or_hf_repo: str,
-    tokenizer_config={},
-    model_config={},
+    tokenizer_config: Optional[Dict[str, Any]] = None,
+    model_config: Optional[Dict[str, Any]] = None,
     adapter_path: Optional[str] = None,
     lazy: bool = False,
     return_config: bool = False,
-    revision: str = None,
+    revision: Optional[str] = None,
 ) -> Union[
     Tuple[nn.Module, TokenizerWrapper],
     Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]],
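
Note for reviewers on the pattern recurring across these hunks: a mutable default such as tokenizer_config={} is evaluated once at function-definition time, so the same dict object is shared by every call and any mutation leaks across calls. The Optional[...] = None sentinel, resolved inside the function with `x or {}` or an explicit `is not None` check, gives each call a fresh dict. A minimal sketch of the difference, using hypothetical function names that are not part of mlx_lm:

    def risky(path, config={}):       # one dict object, shared by all calls
        config.setdefault("seen", []).append(path)
        return config

    def safe(path, config=None):      # the sentinel pattern used in this patch
        config = config or {}         # fresh dict on every default call
        config.setdefault("seen", []).append(path)
        return config

    assert risky("a") is risky("b")       # mutations accumulate across calls
    assert safe("a") is not safe("b")     # each call gets its own dict

One subtlety: `config or {}` also treats an explicitly passed empty dict as None. That distinction is harmless for the kwargs-expansion and config.update call sites changed here, since spreading or merging an empty dict is a no-op either way.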