forked from foundation-model-stack/fms-acceleration
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase.py
More file actions
1318 lines (1139 loc) · 52.4 KB
/
base.py
File metadata and controls
1318 lines (1139 loc) · 52.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
###############################################################################
# Adapted from https://github.com/ModelCloud/GPTQModel
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Standard
from os.path import isfile, join
from typing import Dict, List, Optional, Union
import copy
import json
import logging
import os
import re
# Third Party
from accelerate.hooks import remove_hook_from_module
from safetensors.torch import save_file as safe_save
from tqdm import tqdm
from transformers import (
AutoConfig,
AutoModelForCausalLM,
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_utils import (
dtype_byte_size,
is_local_dist_rank_0,
no_init_weights,
)
from transformers.pytorch_utils import id_tensor_storage
from transformers.utils import WEIGHTS_NAME
from transformers.utils.generic import ContextManagers
from transformers.utils.hub import convert_file_size_to_int
import accelerate
import torch
import torch.nn as nn
import transformers
# Local
from ..quantization import GPTQ, QuantizeConfig
from ..quantization.config import (
FORMAT,
FORMAT_FIELD_JSON,
META_FIELD_QUANTIZER,
META_QUANTIZER_GPTQMODEL,
MIN_VERSION_WITH_V2,
QUANTIZE_BLACK_LIST,
)
from ..utils.backend import Backend
from ..utils.data import collate_data
from ..utils.importer import select_quant_linear
from ..utils.model import (
auto_dtype_from_config,
convert_gptq_v1_to_v2_format,
convert_gptq_v2_to_v1_format,
find_layers,
get_checkpoints,
get_device,
get_module_by_name_prefix,
get_module_by_name_suffix,
get_moe_layer_modules,
gptqmodel_post_init,
make_quant,
move_to,
nested_move_to,
pack_model,
simple_dispatch_model,
verify_model_hash,
verify_sharded_model_hashes,
)
from ._const import CPU, CUDA_0, SUPPORTED_MODELS
# Module logger: emits "LEVEL - message" lines through its own stream handler
# at INFO level, without forwarding records to ancestor loggers.
logger = logging.getLogger(__name__)
formatter = logging.Formatter("%(levelname)s - %(message)s")
# Only attach the handler once: re-executing this module (e.g. importlib.reload)
# returns the same logger object, and a second handler would duplicate every record.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
logger.propagate = False
logger.setLevel(logging.INFO)
class BaseGPTQModel(nn.Module):
    """Wrapper around a Hugging Face ``PreTrainedModel`` (``self.model``) adding
    GPTQ quantization (``quantize``), saving (``save_quantized``) and loading
    (``from_pretrained``) helpers.

    Model-specific subclasses describe their architecture through the class
    attributes below.
    """

    # these modules are non-repeating and at the root level
    # does not include the node which holds all the repeating layers
    # (quantize() temporarily moves them next to the first repeating layer to
    # capture that layer's inputs)
    base_modules: List[str] = None
    # name of lm_head
    lm_head: str = "lm_head"
    # repeating layers
    # node holding all the repeating layers (dotted module path; quantized
    # module names are recorded as f"{layers_node}.{i}.{name}")
    layers_node: str = None
    # repeating layer type (passed to accelerate as a no-split module class)
    layer_type: str = None
    # for each repeating layer there are multiple modules within each layer;
    # each inner list is a group of module names quantized together in one pass
    layer_modules: List[List[str]] = None
    # some models require trust_remote_code = True (dbrx_converted);
    # when truthy, from_pretrained refuses to load unless trust_remote_code=True
    require_trust_remote_code = None
    # TODO: use a better name and what if the value is not at the config root?
    # allow dynamic expert n-count layer extraction
    # so moe model defs do not need to write out 64 layers if expert size is 64 (Qwen2Moe)
    # usage: set to property in model.config that holds this int value: total number of experts
    dynamic_expert_index: Optional[str] = None
    # allow models to define optional notes that output messages to users that want to use this model
    # list of supported keys: [ "notes" = print the notes value on model load ]
    # NOTE: this dict is a class-level default shared by every subclass that does
    # not override it; it is only read (via .get) in this class.
    info: Dict[str, str] = {}
def __init__(
    self,
    model: PreTrainedModel,
    quantized: bool,
    quantize_config: QuantizeConfig,
    qlinear_kernel: nn.Module = None,
):
    """Bundle a Hugging Face model with its GPTQ quantization state.

    Args:
        model: the wrapped model; its ``config`` is exposed as ``self.config``
            and ``config.model_type`` as ``self.model_type``.
        quantized: whether ``model`` already carries quantized weights.
        quantize_config: settings consumed by ``quantize`` / ``save_quantized``.
        qlinear_kernel: quant-linear kernel in use, if known (as returned by
            ``pack_model``); kept to assist gptq(v1) <-> gptq_v2 conversion.
    """
    super().__init__()
    self.model = model
    hf_config = model.config
    self.model_type = hf_config.model_type
    self._quantized = quantized
    self.quantize_config = quantize_config
    self.config = hf_config
    # compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
    self.qlinear_kernel = qlinear_kernel
@property
def quantized(self):
    """Read-only flag telling whether the wrapped model holds quantized weights."""
    is_quantized = self._quantized
    return is_quantized
@property
def hf_device_map(self):
    """The wrapped model's ``hf_device_map`` attribute, or ``None`` when it has none."""
    wrapped = self.model
    if not hasattr(wrapped, "hf_device_map"):
        return None
    return wrapped.hf_device_map
def _prepare_dataset_for_quantization(
    self,
    calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
    batch_size: int = 1,
):
    """Normalise calibration examples and collate them into padded batches.

    Each example must provide ``input_ids`` and ``attention_mask``. Labels are
    taken from the first of ``labels`` / ``label`` / ``label_ids`` present,
    defaulting to a deep copy of ``input_ids``. Tensors are converted to nested
    Python lists of ints (a 1-D tensor gains a leading dimension first); any
    other value is wrapped in a one-element list.

    Args:
        calibration_dataset: list of example dicts as described above.
        batch_size: number of examples passed to each ``collate_data`` call.

    Returns:
        List of batches produced by ``collate_data``, with ``labels`` removed.

    Raises:
        ValueError: if neither ``config.pad_token_id`` nor
            ``config.eos_token_id`` is set.
    """

    def _convert_tensor_to_list(tensor):
        if isinstance(tensor, torch.Tensor):
            if len(tensor.shape) == 1:
                tensor = tensor.unsqueeze(0)
            return tensor.long().cpu().numpy().tolist()
        return [tensor]

    label_keys = ("labels", "label", "label_ids")

    new_calibration_dataset = []
    for example in calibration_dataset:
        input_ids = _convert_tensor_to_list(example["input_ids"])
        attention_mask = _convert_tensor_to_list(example["attention_mask"])
        label_key = next((key for key in label_keys if key in example), None)
        if label_key is not None:
            labels = _convert_tensor_to_list(example[label_key])
        else:
            labels = copy.deepcopy(input_ids)
        new_calibration_dataset.append(
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
            }
        )

    # A pad id of 0 is a legitimate value for many tokenizers, so only fall
    # back to EOS when the pad id is actually unset (the previous truthiness
    # test treated 0 as missing and could raise although a pad id existed).
    pad_token_id = self.config.pad_token_id
    if pad_token_id is None:
        pad_token_id = self.config.eos_token_id
    if pad_token_id is None:
        raise ValueError(
            "Calibration data requires model's `pad_token_id` or `eos_token_id` to be set: actual = `None`."
        )

    new_calibration_dataset_batched = [
        collate_data(
            new_calibration_dataset[start : start + batch_size], pad_token_id
        )
        for start in range(0, len(new_calibration_dataset), batch_size)
    ]

    # Drop labels from every batch; presumably so the calibration forward
    # pass (self.model(**batch) in quantize) does not compute a loss.
    for new_example in new_calibration_dataset_batched:
        del new_example["labels"]

    return new_calibration_dataset_batched
@torch.inference_mode()
def quantize(
    self,
    calibration_dataset: List[Dict[str, Union[List[int], torch.LongTensor]]],
    batch_size: int = 1,
    # TODO: remove use_cuda_fp16 arg..why? doesn't pass smell test @ZX-ModelCloud
    use_cuda_fp16: bool = True,
    autotune_warmup_after_quantized: bool = False,
    calibration_enable_gpu_cache: bool = True,
):
    """Quantize the wrapped model layer by layer with GPTQ and pack it in place.

    Steps:
      1. validate the quantize config and the calibration data;
      2. run each calibration batch through the model until the first repeating
         layer, where a forward pre-hook records that layer's inputs and aborts
         the pass;
      3. for every repeating layer and every group in ``layer_modules``: record
         activations of the group's sub-modules, run ``GPTQ.fasterquant`` on
         each, then replay the layer so its outputs become the next layer's
         inputs;
      4. pack the quantized weights with ``pack_model`` and mark the model as
         quantized.

    Args:
        calibration_dataset: non-empty list of examples with ``input_ids`` and
            ``attention_mask`` (see ``_prepare_dataset_for_quantization``).
        batch_size: calibration batch size.
        use_cuda_fp16: forwarded to ``pack_model``.
        autotune_warmup_after_quantized: forwarded to ``pack_model`` as
            ``warmup_triton``.
        calibration_enable_gpu_cache: keep cached layer inputs/outputs on the
            layer's device instead of CPU.

    Returns:
        List of per-module stats dicts (layer, module, avg_loss, time).

    Raises:
        EnvironmentError: if the model is already quantized.
        ValueError: for a black-listed quant method, ``lm_head`` quantization,
            an empty calibration set, or missing pad/eos token ids.
    """
    if self.quantized:
        raise EnvironmentError(
            "quantize() is called a model that is already quantized"
        )

    if self.quantize_config.quant_method in QUANTIZE_BLACK_LIST:
        raise ValueError(
            f"Unsupported quantization operation for quant method: {self.quantize_config.quant_method}"
        )

    # TODO: lm_head quantization is yet ready but pending
    if self.quantize_config.lm_head:
        raise ValueError(
            "lm_head quantization is currently inference only and not applicable for quantization. Please set `lm_head=False`."
        )

    if len(calibration_dataset) == 0:
        raise ValueError("Calibration dataset must not be empty.")

    min_calibration_dataset_size = 256
    min_calibration_dataset_input_ids_avg_length = 256

    if len(calibration_dataset) < min_calibration_dataset_size:
        logger.warning(
            f"Calibration dataset size should be greater than {min_calibration_dataset_size}. "
            f"Current size: {len(calibration_dataset)}."
        )

    # Calculate the average length of the input_ids. For tensors use the last
    # dimension: len() of a (1, seq_len) tensor is 1, which made this check
    # warn for every tokenizer output created with return_tensors="pt".
    total_input_ids_length = 0
    for e in calibration_dataset:
        ids = e["input_ids"]
        if isinstance(ids, torch.Tensor):
            total_input_ids_length += ids.shape[-1]
        else:
            total_input_ids_length += len(ids)
    avg = total_input_ids_length / len(calibration_dataset)

    if avg < min_calibration_dataset_input_ids_avg_length:
        logger.warning(
            f"The average length of input_ids of calibration_dataset should be greater than "
            f"{min_calibration_dataset_input_ids_avg_length}! Current AVG is {avg}."
        )

    device_map = self.hf_device_map
    if device_map:
        for name, device in device_map.items():
            if device == "cpu":
                logger.info(f"truly offloading {name} to cpu with hook.")
                module = get_module_by_name_suffix(self.model, name)
                remove_hook_from_module(module, recurse=True)
                accelerate.cpu_offload_with_hook(module, CUDA_0)

    # Per-batch inputs of the layer currently being quantized (index j = batch j).
    layer_inputs = []
    attention_masks = []
    position_ids = []
    layer_input_kwargs = []
    layer_outputs = []

    calibration_dataset = self._prepare_dataset_for_quantization(
        calibration_dataset, batch_size
    )

    forward_pass_use_cache = self.model.config.use_cache
    self.model.config.use_cache = False

    num_batches = len(calibration_dataset)
    layers = get_module_by_name_prefix(self.model, self.layers_node)

    cur_layer_device = get_device(layers[0])
    data_device = cur_layer_device if calibration_enable_gpu_cache else CPU

    def store_input_hook(_, args, kwargs):
        """Record the first layer's inputs, then abort the forward pass."""
        # Positional arguments.
        layer_inputs.append([move_to(inp, data_device) for inp in args])

        # Keyword arguments.
        mask = kwargs.get("attention_mask")
        attention_masks.append(None if mask is None else mask.to(data_device))

        pos_ids = kwargs.get("position_ids", None)
        if pos_ids is not None:
            position_ids.append(move_to(pos_ids, data_device))

        # make sure other arguments also be captured
        layer_input_kwargs.append(
            {
                k: nested_move_to(v, data_device)
                for k, v in kwargs.items()
                if k not in ["hidden_states", "attention_mask", "position_ids"]
            }
        )
        # Nothing past the first layer is needed; the caller catches this.
        raise ValueError

    def _layer_call_args(j, device):
        """Move the j-th cached batch (positional and keyword inputs) to ``device``."""
        args = [move_to(inp, device) for inp in layer_inputs[j]]
        mask = attention_masks[j]
        kwargs = {"attention_mask": mask if mask is None else move_to(mask, device)}
        if position_ids:
            kwargs["position_ids"] = move_to(position_ids[j], device)
        for k, v in layer_input_kwargs[j].items():
            kwargs[k] = nested_move_to(v, device)
        return args, kwargs

    force_layer_back_to_cpu = False
    if get_device(layers[0]) == CPU:
        layers[0] = layers[0].to(CUDA_0)
        force_layer_back_to_cpu = True

    ori_outside_layer_module_devices = {}
    for module_name in self.base_modules:
        module = get_module_by_name_prefix(self.model, module_name)
        if module is None:
            continue
        ori_outside_layer_module_devices[module_name] = get_device(module)
        move_to(module, cur_layer_device)

    # TODO: make this optional, backporting https://github.com/huggingface/optimum/blob/main/optimum/gptq/quantizer.py
    handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
    try:
        for example in calibration_dataset:
            for k, v in example.items():
                if len(v.shape) == 1:
                    v = v.unsqueeze(0)
                example[k] = move_to(v, cur_layer_device)
            try:
                self.model(**example)
            except ValueError:
                # raised on purpose by store_input_hook
                pass
    finally:
        # Always detach the capture hook, even if a forward pass fails with an
        # unexpected error; a leftover hook would break every later forward.
        handle.remove()

    move_to(layers[0], CPU if force_layer_back_to_cpu else cur_layer_device)
    for module_name in self.base_modules:
        module = get_module_by_name_prefix(self.model, module_name)
        if module is not None:
            move_to(module, ori_outside_layer_module_devices[module_name])

    torch.cuda.empty_cache()

    layer_modules = self.layer_modules

    # dynamic expert layer index for model defs
    if self.dynamic_expert_index is not None:
        num_experts = getattr(self.model.config, self.dynamic_expert_index)
        layer_modules = get_moe_layer_modules(
            layer_modules=layer_modules, num_experts=num_experts
        )

    # Flatten after expert expansion so that true_sequential=False is honoured
    # for MoE models too (expanding from self.layer_modules used to discard it).
    if not self.quantize_config.true_sequential:
        layer_modules = [sum(layer_modules, [])]

    quantizers = {}

    # stores all per-layer quant stats such as avg loss and processing time
    quant_log = []

    layer_count = len(layers)
    layer_pb = tqdm(range(layer_count))
    for i in layer_pb:
        layer_pb.set_description(f"Quantizing layer {i + 1} of {layer_count}")
        layer = layers[i]
        force_layer_back_to_cpu = False
        if get_device(layer) == CPU:
            move_to(layer, CUDA_0)
            force_layer_back_to_cpu = True
        cur_layer_device = get_device(layer)
        target_device = CPU if force_layer_back_to_cpu else cur_layer_device

        full = find_layers(layer)
        for names in layer_modules:
            subset = {n: full[n] for n in names if n in full}

            gptq = {}
            for name in subset:
                gptq[name] = GPTQ(subset[name])
                gptq[name].quantizer.configure(
                    self.quantize_config.bits,
                    perchannel=True,
                    sym=self.quantize_config.sym,
                    mse=False,
                )

            def add_batch(name):
                def tmp(_, inp, out):
                    # gptq is mutable.
                    gptq[name].add_batch(inp[0].data, out.data)  # noqa: F821

                return tmp

            handles = [
                subset[name].register_forward_hook(add_batch(name))
                for name in subset
            ]
            # Replay the cached batches so the hooks accumulate statistics.
            for j in range(num_batches):
                args, kwargs = _layer_call_args(j, cur_layer_device)
                layer(*args, **kwargs)
            for h in handles:
                h.remove()

            for name in subset:
                layer_pb.set_description(
                    f"Quantizing {name} in layer {i + 1} of {layer_count}"
                )

                try:
                    scale, zero, g_idx, duration, avg_loss = gptq[name].fasterquant(
                        percdamp=self.quantize_config.damp_percent,
                        group_size=self.quantize_config.group_size,
                        actorder=self.quantize_config.desc_act,
                        static_groups=self.quantize_config.static_groups,
                    )
                    stat = {
                        "layer": i + 1,
                        "module": name,
                        "avg_loss": f"{avg_loss:.4f}",
                        "time": f"{duration:.4f}",
                    }
                    quant_log.append(stat)
                    logger.info(stat)
                except torch._C._LinAlgError as e:
                    if "not positive-definite" in str(e).lower():
                        logger.warning(
                            "Please increase damp or nsamples for calibration data to avoid the following quant error. "
                        )
                    raise

                quantizers[f"{self.layers_node}.{i}.{name}"] = (
                    gptq[name].quantizer.to(target_device),
                    move_to(scale, target_device),
                    move_to(zero, target_device),
                    move_to(g_idx, target_device),
                )
                gptq[name].free()

        # Run the now-quantized layer to produce the next layer's inputs.
        for j in range(num_batches):
            args, kwargs = _layer_call_args(j, cur_layer_device)
            layer_output = move_to(
                layer(*args, **kwargs)[0],
                cur_layer_device if calibration_enable_gpu_cache else CPU,
            )
            layer_outputs.append([layer_output])

        layers[i] = move_to(layer, target_device)
        del layer
        del gptq
        # TODO: is it really OK to cache only the first positional argument?
        layer_inputs, layer_outputs = layer_outputs, []
        torch.cuda.empty_cache()

    logger.info(f"Quantization summary:\n{quant_log}")
    for module_log in quant_log:
        logger.info(module_log)

    # NOTE(review): force_layer_back_to_cpu here is the value for the last layer processed.
    self.qlinear_kernel = pack_model(
        model=self.model,
        quantizers=quantizers,
        bits=self.quantize_config.bits,
        group_size=self.quantize_config.group_size,
        backend=Backend.AUTO,
        use_cuda_fp16=use_cuda_fp16,
        desc_act=self.quantize_config.desc_act,
        warmup_triton=autotune_warmup_after_quantized,
        force_layer_back_to_cpu=force_layer_back_to_cpu,
        format=self.quantize_config.format,
    )

    if device_map:
        self.model = remove_hook_from_module(self.model, recurse=True)
        self.model = simple_dispatch_model(self.model, device_map)
    self.model.config.use_cache = forward_pass_use_cache

    self._quantized = True

    torch.cuda.empty_cache()

    return quant_log
@property
def device(self):
    """Device of the wrapped model.

    Without a device map this is ``self.model.device``; with one, it is the
    first placement in ``hf_device_map`` that is not ``"disk"``.
    """
    device_map = self.hf_device_map
    if device_map:
        placements = [d for d in device_map.values() if d != "disk"]
        return torch.device(placements[0])
    return self.model.device
def to(self, device: Union[str, torch.device]):
    """Move the wrapped model to ``device`` and return this wrapper (for chaining)."""
    wrapped = self.model
    wrapped.to(device)
    return self
def forward(self, *args, **kwargs):
    """Delegate the call unchanged to the wrapped model."""
    wrapped_model = self.model
    return wrapped_model(*args, **kwargs)
def generate(self, **kwargs):
    """Shortcut for ``self.model.generate``.

    Runs under ``torch.inference_mode`` and ``torch.amp.autocast`` for the
    device type of this wrapper.
    """
    with torch.inference_mode():
        device_type = self.device.type
        with torch.amp.autocast(device_type=device_type):
            return self.model.generate(**kwargs)
def prepare_inputs_for_generation(self, *args, **kwargs):
    """Shortcut for ``self.model.prepare_inputs_for_generation``."""
    prepare = self.model.prepare_inputs_for_generation
    return prepare(*args, **kwargs)
def save_quantized(
    self,
    save_dir: str,
    safetensors_metadata: Optional[Dict[str, str]] = None,
    use_safetensors: bool = True,
    max_shard_size: Optional[str] = None,
    model_base_name: Optional[str] = None,
):
    """Save quantized weights, the model config and the quantize config to ``save_dir``.

    Args:
        save_dir: output directory (created if missing).
        safetensors_metadata: extra metadata stored in safetensors files;
            non-string keys/values are converted with ``str()`` and
            ``"format": "pt"`` is always added.
        use_safetensors: write ``.safetensors`` (default) instead of a
            ``torch.save`` ``.bin`` checkpoint.
        max_shard_size: if set (and the quant kernel supports sharding), split
            the checkpoint into shards of at most this size (e.g. ``"5GB"``)
            and write an ``.index.json`` alongside.
        model_base_name: base file name of the weights. When omitted,
            ``quantize_config.model_file_base_name`` is used, falling back to
            ``"model"`` (safetensors) or ``"pytorch_model"`` (.bin).

    Raises:
        ValueError: if the model has not been quantized.
        TypeError: if ``safetensors_metadata`` is not a dict or an entry
            cannot be converted to ``str``.

    NOTE(review): for formats other than gptq(v1) the live ``self.model`` (not a
    copy) is moved to CPU before saving.
    """

    def _quantizer_version():
        # `__version__` is not imported at the top of this module, so the bare
        # name raised NameError. Prefer a module-level definition if one exists,
        # then the package version module (assumed to be `..version` as in
        # upstream GPTQModel -- confirm for this fork), else a placeholder.
        version = globals().get("__version__")
        if version is not None:
            return version
        try:
            from ..version import __version__ as package_version
        except ImportError:
            return "unknown"
        return package_version

    def _normalize_safetensors_metadata(metadata):
        """Return a new str->str copy of ``metadata`` with ``format='pt'`` set."""
        if metadata is None:
            metadata = {}
        elif not isinstance(metadata, dict):
            raise TypeError("safetensors_metadata must be a dictionary.")
        else:
            logger.debug(f"Received safetensors_metadata: {metadata}")
        normalized = {}
        converted_keys = False
        for key, value in metadata.items():
            if not isinstance(key, str) or not isinstance(value, str):
                converted_keys = True
            try:
                new_key = str(key)
                new_value = str(value)
            except Exception as e:
                raise TypeError(
                    f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
                ) from e
            if new_key in normalized:
                logger.warning(
                    f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
                )
            normalized[new_key] = new_value
        if converted_keys:
            logger.debug(
                f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {normalized}"
            )
        # Format is required to enable Accelerate to load the metadata
        # otherwise it raises an OSError
        normalized["format"] = "pt"
        return normalized

    # Validate before touching the filesystem or mutating the quantize config.
    if not self.quantized:
        raise ValueError(
            "Save aborted as model is not quantized. Please call `quantize()` first."
        )

    os.makedirs(save_dir, exist_ok=True)

    # write autogptq tooling fingerprint to config
    self.quantize_config.meta_set_versionable(
        key=META_FIELD_QUANTIZER,
        value=META_QUANTIZER_GPTQMODEL,
        version=_quantizer_version(),
    )

    # The config, quantize_config and model may be edited in place in save_quantized.
    config = copy.deepcopy(self.model.config)
    quantize_config = copy.deepcopy(self.quantize_config)
    model = self.model

    if quantize_config.format == FORMAT.GPTQ_V2:
        logger.warning(
            f"Using 'format = {FORMAT.GPTQ_V2}': the serialized model is only supported by GPTQModel version >= {MIN_VERSION_WITH_V2}."
        )

    # internal is always gptq v2 but allow users to pass gptq (v1) via config
    if quantize_config.format == FORMAT.GPTQ:
        # Model qzeros may be edited in place, so convert a deep copy.
        # TODO: avoid inplace modification of the weights
        if quantize_config.bits != 4:
            # fix ModelCloud/GPTQModel/issues/47: the gptqmodel_cuda extension
            # cannot be deep-copied/serialized, so detach it from the live
            # model while copying and re-attach it afterwards.
            # Third Party
            from gptqmodel.nn_modules.qlinear.qlinear_cuda import (
                BaseCudaQuantLinear,
            )

            cuda_name_modules = {}
            for name, module in self.model.named_modules():
                if isinstance(module, BaseCudaQuantLinear):
                    cuda_name_modules[name] = module.gptqmodel_cuda
                    module.gptqmodel_cuda = None
            try:
                model = copy.deepcopy(self.model)
            finally:
                # Restore on the LIVE model. The previous loop walked the copy
                # and tested a stale `module` variable, so the live model kept
                # gptqmodel_cuda=None after saving.
                for name, module in self.model.named_modules():
                    if name in cuda_name_modules:
                        module.gptqmodel_cuda = cuda_name_modules[name]
            del cuda_name_modules
        else:
            model = copy.deepcopy(self.model)
        model = convert_gptq_v2_to_v1_format(
            model,
            quantize_config=quantize_config,
            qlinear_kernel=self.qlinear_kernel,
        )

    model.to(CPU)
    state_dict = model.state_dict()

    # Honour an explicit model_base_name (it used to be unconditionally
    # overwritten); otherwise keep the previous default naming.
    if model_base_name is None:
        if quantize_config.model_file_base_name is not None:
            model_base_name = quantize_config.model_file_base_name
        elif use_safetensors:
            model_base_name = "model"
        else:
            model_base_name = "pytorch_model"

    if use_safetensors:
        # safetensors requires contiguous tensors that do not share storage.
        state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
        model_save_name = model_base_name + ".safetensors"
    else:
        model_save_name = model_base_name + ".bin"

    if not self.qlinear_kernel.SUPPORTED_SHARDS and max_shard_size is not None:
        logger.warning(
            "Sharding is not supported for this quant. Disabling sharding."
        )
        max_shard_size = None

    if use_safetensors:
        safetensors_metadata = _normalize_safetensors_metadata(safetensors_metadata)

    if max_shard_size is None:
        if use_safetensors:
            safe_save(
                state_dict, join(save_dir, model_save_name), safetensors_metadata
            )
        else:
            logger.warning(
                "We highly suggest saving quantized model using safetensors format for security reasons. Please set `use_safetensors=True` whenever possible."
            )
            torch.save(state_dict, join(save_dir, model_save_name))
    else:
        # Shard checkpoint
        shards, index = self.shard_checkpoint(
            state_dict, max_shard_size=max_shard_size, weights_name=model_save_name
        )

        # Clean the folder from a previous save: only delete files that match
        # the sharded-file pattern, e.g. pytorch_model-00001-of-00005
        shard_pattern = re.compile(r"(.*?)-\d{5}-of-\d{5}")
        for filename in os.listdir(save_dir):
            full_filename = join(save_dir, filename)
            filename_no_suffix = filename.replace(".bin", "").replace(
                ".safetensors", ""
            )
            if (
                filename.startswith(model_base_name)
                and isfile(full_filename)
                and filename not in shards
                and shard_pattern.fullmatch(filename_no_suffix) is not None
            ):
                os.remove(full_filename)

        # Save the model
        for shard_file, shard in shards.items():
            if use_safetensors:
                safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
            else:
                torch.save(shard, join(save_dir, shard_file))

        if index is not None:
            index_save_name = model_save_name + ".index.json"
            index_save_path = join(save_dir, index_save_name)
            # Save the index as well
            with open(index_save_path, "w", encoding="utf-8") as f:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                f.write(content)

    config.quantization_config = quantize_config.to_dict()
    config.save_pretrained(save_dir)

    quantize_config.model_name_or_path = save_dir
    quantize_config.model_file_base_name = model_base_name
    quantize_config.save_pretrained(save_dir)
# added by anh.uong@ibm.com
# adapted from transformers.modeling_utils.shard_checkpoint
# from transformers v4.46, removed in later versions
# TODO: split_torch_state_dict_into_shards from huggingface_hub library
def shard_checkpoint(
    self,
    state_dict: Dict[str, torch.Tensor],
    max_shard_size: Union[int, str] = "10GB",
    weights_name: str = WEIGHTS_NAME,
):
    """
    Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
    given size.

    The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no
    optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the
    limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB],
    [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB].

    <Tip warning={true}>

    If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will
    have a size greater than `max_shard_size`.

    </Tip>

    Tensors sharing storage with an earlier tensor are placed in that tensor's
    shard and not counted again; string entries are skipped.

    Args:
        state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save.
        max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
            The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit
            (like `"5MB"`).
        weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`):
            The name of the model save file.

    Returns:
        `(shards, index)`: `shards` maps each output file name to its sub-state-dict. With a single shard the file
        name is `weights_name` and `index` is `None`; otherwise files are named `<stem>-XXXXX-of-YYYYY<ext>` and
        `index` is `{"metadata": {"total_size": ...}, "weight_map": {param_name: file_name}}`.
    """
    # The transformers "deprecated, will be removed in v4.44" warning is not
    # emitted: this function is the local replacement for that API, so the
    # message was misleading and fired on every sharded save.
    max_shard_size = convert_file_size_to_int(max_shard_size)

    sharded_state_dicts = [{}]
    last_block_size = 0
    total_size = 0
    storage_id_to_block = {}

    for key, weight in state_dict.items():
        # when bnb serialization is used the weights in the state dict can be strings
        # check: https://github.com/huggingface/transformers/pull/24416 for more details
        if isinstance(weight, str):
            continue
        storage_id = id_tensor_storage(weight)

        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
        if storage_id in storage_id_to_block and weight.device != torch.device(
            "meta"
        ):
            block_id = storage_id_to_block[storage_id]
            sharded_state_dicts[block_id][key] = weight
            continue

        weight_size = weight.numel() * dtype_byte_size(weight.dtype)

        # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one
        # weight in the current shard.
        if (
            last_block_size + weight_size > max_shard_size
            and len(sharded_state_dicts[-1]) > 0
        ):
            sharded_state_dicts.append({})
            last_block_size = 0

        sharded_state_dicts[-1][key] = weight
        last_block_size += weight_size
        total_size += weight_size
        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1

    # If we only have one shard, we return it
    if len(sharded_state_dicts) == 1:
        return {weights_name: sharded_state_dicts[0]}, None

    # Otherwise, let's build the index. Insert the shard counter before the
    # file extension; unlike string replacement of ".bin"/".safetensors" this
    # never yields colliding names (which silently dropped shards) and does not
    # touch a ".bin" that appears in the middle of the name.
    num_shards = len(sharded_state_dicts)
    stem, ext = os.path.splitext(weights_name)
    weight_map = {}
    shards = {}
    for idx, shard in enumerate(sharded_state_dicts):
        shard_file = f"{stem}-{idx + 1:05d}-of-{num_shards:05d}{ext}"
        shards[shard_file] = shard
        for key in shard.keys():
            weight_map[key] = shard_file

    # Add the metadata
    metadata = {"total_size": total_size}
    index = {"metadata": metadata, "weight_map": weight_map}
    return shards, index
def save_pretrained(
    self,
    save_dir: str,
    **kwargs,
):
    """Hugging Face-style alias of ``save_quantized``.

    Logs a warning, then forwards ``save_dir`` and all keyword arguments to
    ``save_quantized`` unchanged.
    """
    logger.warning(
        "You are using save_pretrained, which will re-direct to save_quantized."
    )
    forwarded = dict(kwargs, save_dir=save_dir)
    self.save_quantized(**forwarded)
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: str,
    quantize_config: QuantizeConfig,
    max_memory: Optional[dict] = None,
    trust_remote_code: bool = False,
    torch_dtype: Union[str, torch.dtype] = "auto",
    **model_init_kwargs,
):
    """Load an un-quantized pretrained model, ready to be passed to ``quantize``.

    Without ``max_memory`` the model is loaded with ``device_map=None``; with
    it, accelerate computes a balanced device map that never splits a
    ``cls.layer_type`` module.

    Args:
        pretrained_model_name_or_path: model id or local path.
        quantize_config: quantization settings stored on the wrapper.
        max_memory: optional per-device memory budget for accelerate
            (disk offload is not supported).
        trust_remote_code: allow executing model code from the hub/repo.
        torch_dtype: ``"auto"`` (derived from the config) or a ``torch.dtype``.
        **model_init_kwargs: extra arguments for ``AutoConfig`` /
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        An instance of ``cls`` wrapping the model in eval mode with
        ``quantized=False``; ``model.seqlen`` is taken from the config
        (``max_position_embeddings``, ``seq_length`` or ``n_positions``) or 4096.

    Raises:
        EnvironmentError: if CUDA is not available.
        ValueError: if remote code is required but not trusted, or
            ``torch_dtype`` is neither ``"auto"`` nor a ``torch.dtype``.
        TypeError: if the model type is not supported.
        NotImplementedError: if ``max_memory`` requests disk offload.
    """
    if not torch.cuda.is_available():
        raise EnvironmentError(
            "Load pretrained model to do quantization requires CUDA available."
        )

    if cls.require_trust_remote_code and not trust_remote_code:
        raise ValueError(
            f"{pretrained_model_name_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model."
        )

    # allow models to define optional notes that output messages to users that want to use this model
    notes = cls.info.get("notes")
    if notes:
        logger.info(notes)

    def skip(*args, **kwargs):
        pass

    # Skip random weight initialisation while the model is built (the
    # checkpoint weights loaded by from_pretrained replace them). The original
    # functions are restored in `finally`; previously they stayed patched for
    # the rest of the process, silently disabling init of every module
    # created afterwards (e.g. adapter layers).
    saved_init_fns = {
        "kaiming_uniform_": torch.nn.init.kaiming_uniform_,
        "uniform_": torch.nn.init.uniform_,
        "normal_": torch.nn.init.normal_,
    }
    for fn_name in saved_init_fns:
        setattr(torch.nn.init, fn_name, skip)

    try:
        model_init_kwargs["trust_remote_code"] = trust_remote_code

        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, **model_init_kwargs
        )

        if torch_dtype == "auto":
            torch_dtype = auto_dtype_from_config(config)
        elif not isinstance(torch_dtype, torch.dtype):
            raise ValueError(
                f"torch_dtype value of `{torch_dtype}` is not a torch.dtype instance."
            )

        # enforce some values despite user specified
        model_init_kwargs["torch_dtype"] = torch_dtype

        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        if max_memory:
            if "disk" in max_memory:
                raise NotImplementedError("disk offload not support yet.")
            with accelerate.init_empty_weights():
                # Respect the caller's choice instead of always trusting remote code.
                model = AutoModelForCausalLM.from_config(
                    config, trust_remote_code=trust_remote_code
                )
            model.tie_weights()

            max_memory = accelerate.utils.get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
                low_zero=False,
            )
            model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
            )
            del model
        else:
            model_init_kwargs["device_map"] = None

        torch.cuda.empty_cache()

        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path, **model_init_kwargs
        )
    finally:
        for fn_name, original_fn in saved_init_fns.items():
            setattr(torch.nn.init, fn_name, original_fn)

    model_config = model.config.to_dict()
    seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
    seq_len_key = next((k for k in seq_len_keys if k in model_config), None)
    if seq_len_key is not None:
        model.seqlen = model_config[seq_len_key]
    else:
        logger.warning(
            "can't get model's sequence length from model config, will set to 4096."
        )
        model.seqlen = 4096
    model.eval()

    return cls(model, quantized=False, quantize_config=quantize_config)
@classmethod
def from_quantized(
cls,
model_name_or_path: Optional[str],
device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None,
max_memory: Optional[dict] = None,
device: Optional[Union[str, int]] = None,
low_cpu_mem_usage: bool = False,
backend: Backend = Backend.AUTO,
torch_dtype: Union[str, torch.dtype] = "auto",
use_cuda_fp16: bool = True,
quantize_config: Optional[QuantizeConfig] = None,
model_basename: Optional[str] = None,
use_safetensors: bool = True,
trust_remote_code: bool = False,
warmup_triton: bool = False,
format: Optional[FORMAT] = None,
allow_unsafe_loading: bool = False,
verify_hash: Optional[Union[str, List[str]]] = None,
**kwargs,