Commit 1e8fca4

fix dequant + minor refactor (#572)
Parent: 61669b2

9 files changed: +92, -86 lines

mlx_lm/benchmark.py

Lines changed: 2 additions & 3 deletions
@@ -68,9 +68,8 @@ def main():
     prompt_tokens = args.prompt_tokens
     generation_tokens = args.generation_tokens
     batch_size = args.batch_size
-    prompts = mx.random.randint(
-        0, config["vocab_size"], (batch_size, prompt_tokens)
-    ).tolist()
+    vocab_size = config.get("vocab_size") or config["text_config"]["vocab_size"]
+    prompts = mx.random.randint(0, vocab_size, (batch_size, prompt_tokens)).tolist()
     prompt = prompts[0]

     def single_bench():
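
The benchmark now resolves the vocabulary size with a fallback to the nested text config used by multimodal models. A small illustration of the lookup order; the config dicts below are made up for the example and are not from the repo:

    # Made-up configs for illustration; real ones come from the model's config.json.
    text_only = {"vocab_size": 32000}
    multimodal = {"text_config": {"vocab_size": 152064}}

    def vocab_size(config):
        # Top-level key first, then the nested text_config fallback.
        return config.get("vocab_size") or config["text_config"]["vocab_size"]

    print(vocab_size(text_only))   # 32000
    print(vocab_size(multimodal))  # 152064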

mlx_lm/fuse.py

Lines changed: 6 additions & 6 deletions
@@ -4,8 +4,8 @@
 from mlx.utils import tree_flatten, tree_unflatten

 from .gguf import convert_to_gguf
-from .tuner.utils import dequantize, load_adapters
 from .utils import (
+    dequantize_model,
     load,
     save,
     upload_to_hub,
@@ -39,8 +39,8 @@ def parse_arguments() -> argparse.Namespace:
         default=None,
     )
     parser.add_argument(
-        "--de-quantize",
-        help="Generate a de-quantized model.",
+        "--dequantize",
+        help="Generate a dequantized model.",
         action="store_true",
     )
     parser.add_argument(
@@ -66,16 +66,16 @@ def main() -> None:
     )

     fused_linears = [
-        (n, m.fuse(de_quantize=args.de_quantize))
+        (n, m.fuse(dequantize=args.dequantize))
         for n, m in model.named_modules()
         if hasattr(m, "fuse")
     ]

     if fused_linears:
         model.update_modules(tree_unflatten(fused_linears))

-    if args.de_quantize:
-        print("De-quantizing model")
+    if args.dequantize:
+        print("Dequantizing model")
         model = dequantize(model)
         config.pop("quantization", None)

mlx_lm/models/afm7.py

Lines changed: 2 additions & 2 deletions
@@ -50,7 +50,7 @@ def __init__(
         ]
         self.lora_b = [mx.zeros((r, od)) for od in output_dims]

-    def fuse(self, de_quantize: bool = False):
+    def fuse(self, dequantize: bool = False):
         linear = self.linear
         weight = linear.weight
         is_quantized = isinstance(linear, FusedQuantizedLinear)
@@ -79,7 +79,7 @@ def fuse(self, de_quantize: bool = False):
         delta = mx.concatenate(deltas, axis=0)
         fused_linear.weight = weight + delta

-        if is_quantized and not de_quantize:
+        if is_quantized and not dequantize:
             fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits)

         return fused_linear

mlx_lm/perplexity.py

Lines changed: 1 addition & 2 deletions
@@ -13,8 +13,7 @@
 import numpy as np

 from mlx_lm.tuner.datasets import load_dataset
-from mlx_lm.tuner.utils import get_total_parameters
-from mlx_lm.utils import load
+from mlx_lm.utils import get_total_parameters, load


 def load_data(

mlx_lm/tuner/dora.py

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,7 @@ def from_base(
         dora_lin.set_linear(linear)
         return dora_lin

-    def fuse(self, de_quantize: bool = False):
+    def fuse(self, dequantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = self._dequantized_weight()
@@ -49,7 +49,7 @@ def fuse(self, de_quantize: bool = False):
         if bias:
             fused_linear.bias = linear.bias

-        if self._is_quantized() and not de_quantize:
+        if self._is_quantized() and not dequantize:
             fused_linear = nn.QuantizedLinear.from_linear(
                 fused_linear,
                 linear.group_size,
@@ -151,7 +151,7 @@ def from_base(
         dora_embedding.set_embedding(embedding)
         return dora_embedding

-    def fuse(self, de_quantize: bool = False):
+    def fuse(self, dequantize: bool = False):
         embedding = self.embedding
         weight = embedding.weight

mlx_lm/tuner/lora.py

Lines changed: 6 additions & 6 deletions
@@ -31,7 +31,7 @@ def from_base(
         lora_lin.linear = linear
         return lora_lin

-    def fuse(self, de_quantize: bool = False):
+    def fuse(self, dequantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -57,7 +57,7 @@ def fuse(self, de_quantize: bool = False):
         if bias:
             fused_linear.bias = linear.bias

-        if is_quantized and not de_quantize:
+        if is_quantized and not dequantize:
             fused_linear = nn.QuantizedLinear.from_linear(
                 fused_linear,
                 linear.group_size,
@@ -119,7 +119,7 @@ def from_base(
         lora_lin.linear = linear
         return lora_lin

-    def fuse(self, de_quantize: bool = False):
+    def fuse(self, dequantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -146,7 +146,7 @@ def fuse(self, de_quantize: bool = False):
         if bias:
             fused_linear.bias = linear.bias

-        if is_quantized and not de_quantize:
+        if is_quantized and not dequantize:
             fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits)

         return fused_linear
@@ -219,7 +219,7 @@ def from_base(
         lora_embedding.embedding = embedding
         return lora_embedding

-    def fuse(self, de_quantize: bool = False):
+    def fuse(self, dequantize: bool = False):
         embedding = self.embedding
         weight = embedding.weight
         is_quantized = isinstance(embedding, nn.QuantizedEmbedding)
@@ -243,7 +243,7 @@ def fuse(self, de_quantize: bool = False):
         lora_b = self.lora_b.astype(dtype)
         fused_embedding.weight = weight + lora_a @ lora_b

-        if is_quantized and not de_quantize:
+        if is_quantized and not dequantize:
             fused_embedding = nn.QuantizedEmbedding.from_embedding(
                 fused_embedding,
                 embedding.group_size,
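
The keyword rename applies uniformly across the LoRA and DoRA fuse() implementations. A minimal sketch of fusing a quantized LoRA linear back into its base layer; the dimensions and rank are arbitrary and chosen only for illustration:

    import mlx.nn as nn

    from mlx_lm.tuner.lora import LoRALinear

    # Arbitrary sizes and rank for illustration.
    base = nn.QuantizedLinear(512, 512, bias=True)
    lora = LoRALinear.from_base(base, r=8)

    # The default keeps the fused result quantized; dequantize=True returns a
    # plain nn.Linear with the LoRA update folded into the weights.
    fused_quantized = lora.fuse()
    fused_float = lora.fuse(dequantize=True)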

mlx_lm/tuner/utils.py

Lines changed: 2 additions & 58 deletions
@@ -7,9 +7,10 @@
 import mlx.core as mx
 import mlx.nn as nn
 import mlx.optimizers as opt
-from mlx.utils import tree_flatten, tree_map_with_path, tree_unflatten
+from mlx.utils import tree_flatten, tree_unflatten

 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
+from ..utils import get_total_parameters
 from .dora import DoRAEmbedding, DoRALinear
 from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear

@@ -137,49 +138,6 @@ def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
     return model


-def dequantize(model: nn.Module) -> nn.Module:
-    """
-    Dequantize the quantized linear layers in the model.
-
-    Args:
-        model (nn.Module): The model with quantized linear layers.
-
-    Returns:
-        nn.Module: The model with dequantized layers.
-    """
-    dequantize_layers = []
-    for name, module in model.named_modules():
-        bias = "bias" in module
-        if isinstance(module, nn.QuantizedLinear):
-            cls = nn.Linear
-            kwargs = {"bias": bias}
-        elif isinstance(module, nn.QuantizedEmbedding):
-            kwargs = {}
-            cls = nn.Embedding
-        elif isinstance(module, QuantizedSwitchLinear):
-            kwargs = {"bias": bias}
-            cls = SwitchLinear
-        else:
-            continue
-        weight = mx.dequantize(
-            module.weight,
-            module.scales,
-            module.biases,
-            module.group_size,
-            module.bits,
-        )
-        args = weight.shape[::-1]
-        m = cls(*args, **kwargs)
-        if bias:
-            m.bias = module.bias
-        m.weight = weight
-        dequantize_layers.append((name, m))
-
-    if len(dequantize_layers) > 0:
-        model.update_modules(tree_unflatten(dequantize_layers))
-    return model
-
-
 def remove_lora_layers(model: nn.Module) -> nn.Module:
     """
     Remove the LoRA layers from the model.
@@ -199,20 +157,6 @@ def remove_lora_layers(model: nn.Module) -> nn.Module:
     return model


-def get_total_parameters(model):
-    leaf_modules = tree_flatten(
-        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
-    )
-
-    def nparams(m):
-        if hasattr(m, "bits"):
-            n = 0 if not hasattr(m, "bias") else m.bias.size
-            return n + m.weight.size * 32 // m.bits
-        return sum(v.size for _, v in tree_flatten(m.parameters()))
-
-    return sum(nparams(m) for _, m in leaf_modules)
-
-
 def print_trainable_parameters(model):
     total_p = get_total_parameters(model) / 1e6
     trainable_p = (

mlx_lm/utils.py

Lines changed: 67 additions & 3 deletions
@@ -31,13 +31,11 @@
 else:
     from huggingface_hub import snapshot_download

-from mlx.utils import tree_flatten, tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_map, tree_reduce, tree_unflatten
 from transformers import PreTrainedTokenizer

 # Local imports
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
-from .tuner.utils import dequantize as dequantize_model
-from .tuner.utils import get_total_parameters, load_adapters

 # Constants
 MODEL_REMAPPING = {
@@ -74,6 +72,20 @@ def _get_classes(config: dict):
     return arch.Model, arch.ModelArgs


+def get_total_parameters(model):
+    leaf_modules = tree_flatten(
+        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
+    )
+
+    def nparams(m):
+        if hasattr(m, "bits"):
+            n = 0 if not hasattr(m, "bias") else m.bias.size
+            return n + m.weight.size * 32 // m.bits
+        return sum(v.size for _, v in tree_flatten(m.parameters()))
+
+    return sum(nparams(m) for _, m in leaf_modules)
+
+
 def compute_bits_per_weight(model):
     model_bytes = tree_reduce(
         lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
@@ -225,6 +237,12 @@ def class_predicate(p, m):
     return model, config


+def load_adapters(model: nn.Module, adapter_path: str) -> nn.Module:
+    from .tuner.utils import load_adapters as _load_adapters
+
+    return _load_adapters(model, adapter_path)
+
+
 def load(
     path_or_hf_repo: str,
     tokenizer_config={},
@@ -520,6 +538,52 @@ def wrapped_predicate(path, module):
     return model, quantized_config


+def dequantize_model(model: nn.Module) -> nn.Module:
+    """
+    Dequantize the quantized layers in the model.
+
+    Args:
+        model (nn.Module): The model with quantized layers.
+
+    Returns:
+        nn.Module: The model with dequantized layers.
+    """
+    from .models.switch_layers import QuantizedSwitchLinear, SwitchLinear
+
+    dequantize_layers = []
+    for name, module in model.named_modules():
+        bias = "bias" in module
+        if isinstance(module, nn.QuantizedLinear):
+            cls = nn.Linear
+            kwargs = {"bias": bias}
+        elif isinstance(module, nn.QuantizedEmbedding):
+            kwargs = {}
+            cls = nn.Embedding
+        elif isinstance(module, QuantizedSwitchLinear):
+            kwargs = {"bias": bias}
+            cls = SwitchLinear
+        else:
+            continue
+        weight = mx.dequantize(
+            module.weight,
+            module.scales,
+            module.biases,
+            module.group_size,
+            module.bits,
+            module.mode,
+        )
+        args = weight.shape[::-1]
+        m = cls(*args, **kwargs)
+        if bias:
+            m.bias = module.bias
+        m.weight = weight
+        dequantize_layers.append((name, m))
+
+    if len(dequantize_layers) > 0:
+        model.update_modules(tree_unflatten(dequantize_layers))
+    return model
+
+
 def save_config(
     config: dict,
     config_path: Union[str, Path],
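
With both helpers now exported from mlx_lm.utils, dequantization no longer needs to reach into mlx_lm.tuner.utils. A short usage sketch; the repo id is a placeholder for any quantized MLX model:

    from mlx_lm.utils import dequantize_model, get_total_parameters, load

    # Placeholder repo id; substitute any quantized MLX model.
    model, tokenizer = load("mlx-community/<quantized-model>")

    print(f"total parameters: {get_total_parameters(model) / 1e6:.1f}M")

    # Replace QuantizedLinear/QuantizedEmbedding/QuantizedSwitchLinear layers
    # with their float equivalents.
    model = dequantize_model(model)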

tests/test_finetune.py

Lines changed: 3 additions & 3 deletions
@@ -123,7 +123,7 @@ def test_lora_embedding(self):
             embedding.bits,
         )
         lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
-        new_embedding = lora_emb.fuse(de_quantize=True)
+        new_embedding = lora_emb.fuse(dequantize=True)
         self.assertTrue(mx.array_equal(dequantized_weight, new_embedding.weight))
         self.assertTrue(mx.array_equal(embedding(tokens), lora_emb(tokens)))

@@ -137,7 +137,7 @@ def test_lora_embedding(self):

         # change the value of lora_b and the embeddings will no longer be equal
         lora_emb.lora_b = mx.random.uniform(shape=lora_emb.lora_b.shape)
-        new_embedding = lora_emb.fuse(de_quantize=True)
+        new_embedding = lora_emb.fuse(dequantize=True)
         self.assertFalse(mx.array_equal(dequantized_weight, new_embedding.weight))
         self.assertFalse(mx.array_equal(embedding(tokens), lora_emb(tokens)))

@@ -300,7 +300,7 @@ def dequantize_weight(quantized_linear):
         quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
         dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)
         # Dequantize
-        to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
+        to_linear_from_quantized = dora_quantized_linear.fuse(dequantize=True)
         self.assertTrue(
             mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
         )
