Skip to content

Commit d3bf847

Browse files
authored
Make mlx-lm more type-checker friendly (#573)
* Fix type annotation for `load` parameter
* Add type annotations to all `load` parameters
* Avoid using mutable types for `load` default parameters
* Add return type annotation to `load_tokenizer`
* Export public module attributes
1 parent df64341 commit d3bf847

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

mlx_lm/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,3 +9,12 @@
99
from .convert import convert
1010
from .generate import batch_generate, generate, stream_generate
1111
from .utils import load
12+
13+
__all__ = [
14+
"__version__",
15+
"convert",
16+
"batch_generate",
17+
"generate",
18+
"stream_generate",
19+
"load",
20+
]

mlx_lm/tokenizer_utils.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import json
22
from functools import partial
33
from json import JSONDecodeError
4-
from typing import List
4+
from typing import Any, Dict, List, Optional
55

66
from transformers import AutoTokenizer, PreTrainedTokenizerFast
77

@@ -424,8 +424,11 @@ def _is_bpe_decoder(decoder):
424424

425425

426426
def load_tokenizer(
427-
model_path, tokenizer_config_extra={}, return_tokenizer=True, eos_token_ids=None
428-
):
427+
model_path,
428+
tokenizer_config_extra: Optional[Dict[str, Any]] = None,
429+
return_tokenizer=True,
430+
eos_token_ids=None,
431+
) -> TokenizerWrapper:
429432
"""Load a huggingface tokenizer and try to infer the type of streaming
430433
detokenizer to use.
431434
@@ -454,8 +457,9 @@ def load_tokenizer(
454457
eos_token_ids = [eos_token_ids]
455458

456459
if return_tokenizer:
460+
kwargs = tokenizer_config_extra or {}
457461
return TokenizerWrapper(
458-
AutoTokenizer.from_pretrained(model_path, **tokenizer_config_extra),
462+
AutoTokenizer.from_pretrained(model_path, **kwargs),
459463
detokenizer_class,
460464
eos_token_ids=eos_token_ids,
461465
)

mlx_lm/utils.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -148,7 +148,7 @@ def load_model(
148148
model_path: Path,
149149
lazy: bool = False,
150150
strict: bool = True,
151-
model_config: dict = {},
151+
model_config: Optional[Dict[str, Any]] = None,
152152
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
153153
) -> Tuple[nn.Module, dict]:
154154
"""
@@ -175,7 +175,8 @@ def load_model(
175175
ValueError: If the model class or args class are not found or cannot be instantiated.
176176
"""
177177
config = load_config(model_path)
178-
config.update(model_config)
178+
if model_config is not None:
179+
config.update(model_config)
179180

180181
weight_files = glob.glob(str(model_path / "model*.safetensors"))
181182

@@ -245,12 +246,12 @@ def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
245246

246247
def load(
247248
path_or_hf_repo: str,
248-
tokenizer_config={},
249-
model_config={},
249+
tokenizer_config: Optional[Dict[str, Any]] = None,
250+
model_config: Optional[Dict[str, Any]] = None,
250251
adapter_path: Optional[str] = None,
251252
lazy: bool = False,
252253
return_config: bool = False,
253-
revision: str = None,
254+
revision: Optional[str] = None,
254255
) -> Union[
255256
Tuple[nn.Module, TokenizerWrapper],
256257
Tuple[nn.Module, TokenizerWrapper, Dict[str, Any]],

0 commit comments

Comments (0)