- import re
- from typing import List, Optional, Tuple
-
- import torch
- from transformers import AutoModelForCausalLM
-
- from auto_fp8.config import BaseQuantizeConfig
- from auto_fp8.quantize import (
-     quantize_activations,
-     quantize_weights,
-     save_quantized_model,
- )
-
-
- class AutoFP8ForCausalLM:
+ import os
+ from typing import List, Optional
+
+ from transformers import AutoConfig, AutoTokenizer
+ from datasets import Dataset
+ from llmcompressor.transformers import SparseAutoModelForCausalLM
+ from llmcompressor.transformers import oneshot
+ from llmcompressor.modifiers.quantization import QuantizationModifier
+
+ class BaseQuantizeConfig:
+     """Configuration for model quantization.
+
+     Args:
+         quant_method: Type/precision of quantization method to use.
+             At the moment, this is just "fp8", which specifically means
+             the fp8_e4m3 format in PyTorch.
+         activation_scheme: Choice of either "dynamic" or "static" quantization
+             of activations. If "static", then calibration samples are required
+             during quantization to produce accurate per-tensor scales for
+             activations of Linear modules.
+         ignore_patterns: List of patterns used to ignore layers. If a string
+             starts with "re:", then everything after it is used as Python
+             regex-style matching, i.e. re.search(), for each Linear layer.
+             By default, "lm_head" is included to ignore the embedding
+             Linear layer usually found at the end of decoder LLMs.
+     """
      def __init__(
          self,
-         model: AutoModelForCausalLM,
-         quantize_config: BaseQuantizeConfig,
+         quant_method: str = "fp8",
+         activation_scheme: str = "static",
+         ignore_patterns: List[str] = ["lm_head"],
      ):
+         self.quant_method = quant_method
+         self.activation_scheme = activation_scheme
+         self.ignore_patterns = ignore_patterns
+
+
+ class AutoFP8ForCausalLM:
+     def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig):
          self.model = model
          self.model_type = self.model.config.model_type
          self.config = self.model.config
+ <<<<<<< HEAD

          # Gather the Linear module names that we want to ignore
          quantize_config.ignored_layers = get_layers_to_ignore(
@@ -45,76 +65,23 @@ def __init__(
          )
          quantize_config.kv_cache_quant_layers = kv_cache_quant_layers

+ =======
+ >>>>>>> ba7d420 (Switch backend to use llm-compressor)
          self.quantize_config = quantize_config

      @classmethod
-     def from_pretrained(
-         cls,
-         pretrained_model_name_or_path: str,
-         quantize_config: BaseQuantizeConfig,
-         **model_init_kwargs,
-     ):
-         """Load the un-quantized pretrained model"""
-
-         def skip(*args, **kwargs):
-             pass
-
-         torch.nn.init.kaiming_uniform_ = skip
-         torch.nn.init.uniform_ = skip
-         torch.nn.init.normal_ = skip
-
-         # Parameters related to loading from Hugging Face Hub
-         cache_dir = model_init_kwargs.pop("cache_dir", None)
-         force_download = model_init_kwargs.pop("force_download", False)
-         resume_download = model_init_kwargs.pop("resume_download", False)
-         proxies = model_init_kwargs.pop("proxies", None)
-         local_files_only = model_init_kwargs.pop("local_files_only", False)
-         use_auth_token = model_init_kwargs.pop("use_auth_token", None)
-         revision = model_init_kwargs.pop("revision", None)
-         subfolder = model_init_kwargs.pop("subfolder", "")
-         commit_hash = model_init_kwargs.pop("_commit_hash", None)
-
-         cached_file_kwargs = {
-             "cache_dir": cache_dir,
-             "force_download": force_download,
-             "proxies": proxies,
-             "resume_download": resume_download,
-             "local_files_only": local_files_only,
-             "use_auth_token": use_auth_token,
-             "revision": revision,
-             "subfolder": subfolder,
-             "_commit_hash": commit_hash,
-         }
-
-         torch.cuda.empty_cache()
-
-         # Important defaults
-         if "torch_dtype" not in model_init_kwargs:
-             model_init_kwargs["torch_dtype"] = "auto"
-
-         if "device_map" not in model_init_kwargs:
-             model_init_kwargs["device_map"] = "auto"
-
-         merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-         print("Loading model with the following kwargs:", merged_kwargs)
-         model = AutoModelForCausalLM.from_pretrained(
-             pretrained_model_name_or_path, **merged_kwargs
+     def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs):
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+         model = SparseAutoModelForCausalLM.from_pretrained(
+             pretrained_model_name_or_path,
+             config=config,
+             device_map="auto",
+             torch_dtype="auto",
+             **kwargs
          )
-
-         model_config = model.config.to_dict()
-         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
-         if any(k in model_config for k in seq_len_keys):
-             for key in seq_len_keys:
-                 if key in model_config:
-                     model.seqlen = model_config[key]
-                     break
-         else:
-             print("Can't get model's sequence length, setting to 2048.")
-             model.seqlen = 2048
-         model.eval()
-
          return cls(model, quantize_config)

+ <<<<<<< HEAD
      def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
  <<<<<<< HEAD
  <<<<<<< HEAD
@@ -161,12 +128,28 @@ def save_quantized(self, save_dir):
              self.model,
              quant_config=self.quantize_config,
              save_dir=save_dir,
+ =======
+     def quantize(self, dataset: Optional[Dataset] = None):
+         assert self.quantize_config.activation_scheme == "static"
+         assert dataset is not None, "Calibration tokens required for static activation quantization"
+
+         recipe = QuantizationModifier(
+             targets="Linear",
+             scheme="FP8",
+             ignore=self.quantize_config.ignore_patterns
+ >>>>>>> ba7d420 (Switch backend to use llm-compressor)
          )

+         oneshot(
+             model=self.model,
+             dataset=dataset,
+             recipe=recipe,
+         )

- def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
-     ignored_layers = set()
+     def save_quantized(self, save_directory: str):
+         self.save_pretrained(save_directory, save_compressed=True)

+ <<<<<<< HEAD
      for name, linear in model.named_modules():
          if not isinstance(linear, torch.nn.Linear):
              continue
@@ -220,3 +203,10 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List
      return kv_cache_quant_layers
  >>>>>>> c3acdee (Switch from output_scale to kv_scale)
+ =======
+     def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+         self.model.save_pretrained(save_directory, save_compressed=save_compressed)
+         tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
+         tokenizer.save_pretrained(save_directory)
+         print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")
+ >>>>>>> ba7d420 (Switch backend to use llm-compressor)
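
For context, a minimal usage sketch of the llm-compressor-backed API as it stands after this change. The model id, calibration text, `max_length`, and the `re:` ignore pattern are illustrative placeholders; the top-level `auto_fp8` imports assume the package keeps re-exporting these classes, and whether `oneshot` accepts a pre-tokenized dataset directly may depend on the installed llm-compressor version.

```python
from datasets import Dataset
from transformers import AutoTokenizer

# Assumed top-level imports; adjust if the package exposes these classes elsewhere.
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder model id

# "static" activation quantization requires calibration samples; "re:" patterns
# are matched with re.search() against each Linear module name.
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["lm_head", "re:.*gate"],  # illustrative regex pattern
)

# Tiny pre-tokenized calibration dataset (illustrative only).
tokenizer = AutoTokenizer.from_pretrained(pretrained)
texts = ["The quick brown fox jumps over the lazy dog."] * 8
tokens = tokenizer(texts, truncation=True, max_length=64)
calib_ds = Dataset.from_dict(
    {"input_ids": tokens["input_ids"], "attention_mask": tokens["attention_mask"]}
)

model = AutoFP8ForCausalLM.from_pretrained(pretrained, quantize_config)
model.quantize(calib_ds)  # applies the FP8 QuantizationModifier recipe via oneshot
model.save_quantized("Meta-Llama-3-8B-Instruct-FP8")  # compressed weights + tokenizer
```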