|
31 | 31 | else: |
32 | 32 |     from huggingface_hub import snapshot_download |
33 | 33 |
|
34 | | -from mlx.utils import tree_flatten, tree_map, tree_reduce |
| 34 | +from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten |
35 | 35 | from transformers import PreTrainedTokenizer |
36 | 36 |
|
37 | 37 | # Local imports |
38 | 38 | from .tokenizer_utils import TokenizerWrapper, load_tokenizer |
39 | | -from .tuner.utils import dequantize as dequantize_model |
40 | | -from .tuner.utils import get_total_parameters, load_adapters |
41 | 39 |
|
42 | 40 | # Constants |
43 | 41 | MODEL_REMAPPING = { |
@@ -74,6 +72,20 @@ def _get_classes(config: dict): |
74 | 72 |     return arch.Model, arch.ModelArgs |
75 | 73 |
|
76 | 74 |
|
| 75 | +def get_total_parameters(model): |
| 76 | +    leaf_modules = tree_flatten( |
| 77 | +        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) |
| 78 | +    ) |
| 79 | + |
| 80 | +    def nparams(m): |
| 81 | +        if hasattr(m, "bits"): |
| 82 | +            n = 0 if not hasattr(m, "bias") else m.bias.size |
| 83 | +            return n + m.weight.size * 32 // m.bits |
| 84 | +        return sum(v.size for _, v in tree_flatten(m.parameters())) |
| 85 | + |
| 86 | +    return sum(nparams(m) for _, m in leaf_modules) |
| 87 | + |
| 88 | + |
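For context, a quick sketch of how this parameter-counting helper might be exercised once the change lands. The repo id below is a placeholder and the `mlx_lm.utils` import path is assumed:

```python
# Hypothetical usage of get_total_parameters; the model id is a placeholder.
from mlx_lm.utils import load, get_total_parameters

model, tokenizer = load("mlx-community/My-Model-4bit")  # placeholder repo
total = get_total_parameters(model)
print(f"Total parameters: {total / 1e9:.3f}B")
```

For quantized layers the helper scales `weight.size` by `32 // bits`, so the count reflects the original (unpacked) number of weights rather than the packed storage.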
77 | 89 | def compute_bits_per_weight(model): |
78 | 90 |     model_bytes = tree_reduce( |
79 | 91 |         lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0 |
@@ -225,6 +237,12 @@ def class_predicate(p, m): |
225 | 237 |     return model, config |
226 | 238 |
|
227 | 239 |
|
| 240 | +def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module: |
| 241 | +    from .tuner.utils import load_adapters as _load_adapters |
| 242 | + |
| 243 | +    return _load_adapters(model, adapter_path) |
| 244 | + |
| 245 | + |
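A minimal sketch of the thin wrapper in use; both paths below are placeholders, and the deferred import presumably keeps `utils` free of a module-level dependency on `tuner.utils`:

```python
# Hypothetical sketch: attach trained (Q)LoRA adapters to an already-loaded model.
# Both paths below are placeholders.
from mlx_lm.utils import load, load_adapters

model, tokenizer = load("path/to/base-model")
model = load_adapters(model, "path/to/adapters")
```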
228 | 246 | def load( |
229 | 247 |     path_or_hf_repo: str, |
230 | 248 |     tokenizer_config={}, |
@@ -520,6 +538,52 @@ def wrapped_predicate(path, module): |
520 | 538 |     return model, quantized_config |
521 | 539 |
|
522 | 540 |
|
| 541 | +def dequantize_model(model: nn.Module) -> nn.Module: |
| 542 | +    """ |
| 543 | +    Dequantize the quantized layers in the model. |
| 544 | + |
| 545 | +    Args: |
| 546 | +        model (nn.Module): The model with quantized layers. |
| 547 | + |
| 548 | +    Returns: |
| 549 | +        nn.Module: The model with dequantized layers. |
| 550 | +    """ |
| 551 | +    from .models.switch_layers import QuantizedSwitchLinear, SwitchLinear |
| 552 | + |
| 553 | +    dequantize_layers = [] |
| 554 | +    for name, module in model.named_modules(): |
| 555 | +        bias = "bias" in module |
| 556 | +        if isinstance(module, nn.QuantizedLinear): |
| 557 | +            cls = nn.Linear |
| 558 | +            kwargs = {"bias": bias} |
| 559 | +        elif isinstance(module, nn.QuantizedEmbedding): |
| 560 | +            kwargs = {} |
| 561 | +            cls = nn.Embedding |
| 562 | +        elif isinstance(module, QuantizedSwitchLinear): |
| 563 | +            kwargs = {"bias": bias} |
| 564 | +            cls = SwitchLinear |
| 565 | +        else: |
| 566 | +            continue |
| 567 | +        weight = mx.dequantize( |
| 568 | +            module.weight, |
| 569 | +            module.scales, |
| 570 | +            module.biases, |
| 571 | +            module.group_size, |
| 572 | +            module.bits, |
| 573 | +            module.mode, |
| 574 | +        ) |
| 575 | +        args = weight.shape[::-1] |
| 576 | +        m = cls(*args, **kwargs) |
| 577 | +        if bias: |
| 578 | +            m.bias = module.bias |
| 579 | +        m.weight = weight |
| 580 | +        dequantize_layers.append((name, m)) |
| 581 | + |
| 582 | +    if len(dequantize_layers) > 0: |
| 583 | +        model.update_modules(tree_unflatten(dequantize_layers)) |
| 584 | +    return model |
| 585 | + |
| 586 | + |
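As a rough usage sketch (placeholder repo id, `mlx_lm.utils` import path assumed):

```python
# Hypothetical sketch: load a quantized checkpoint and strip the quantization.
import mlx.nn as nn
from mlx_lm.utils import load, dequantize_model

model, tokenizer = load("mlx-community/Some-Model-4bit")  # placeholder repo
model = dequantize_model(model)

# After the update, no quantized linear layers should remain in the module tree.
assert not any(
    isinstance(m, nn.QuantizedLinear) for _, m in model.named_modules()
)
```

The constructor arguments come from the reversed `weight.shape` (e.g. `nn.Linear(input_dims, output_dims)` from a weight of shape `(output_dims, input_dims)`), and the dequantized weight is assigned afterwards, replacing the constructor's freshly initialized one.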
523 | 587 | def save_config( |
524 | 588 |     config: dict, |
525 | 589 |     config_path: Union[str, Path], |
|