- import re
- from typing import List, Optional, Tuple
+ import os
+ from typing import List, Optional
+
+ from transformers import AutoConfig, AutoTokenizer
+ from datasets import Dataset
+ from llmcompressor.transformers import SparseAutoModelForCausalLM
+ from llmcompressor.transformers import oneshot
+ from llmcompressor.modifiers.quantization import QuantizationModifier
+
+
+ class BaseQuantizeConfig:
+     """Configuration for model quantization.
+
+     Args:
+         quant_method: Type/precision of quantization method to use.
+             At the moment, this is just "fp8", which specifically means
+             the fp8_e4m3 format in PyTorch.
+         activation_scheme: Choice of either "dynamic" or "static" quantization
+             of activations. If "static", then calibration samples are required
+             during quantization to produce accurate per-tensor scales for
+             the activations of Linear modules.
+         ignore_patterns: List of patterns used to ignore layers. If a string
+             starts with "re:", then everything after it is used as a Python
+             regex, matched via re.search() against each Linear layer name.
+             By default, "lm_head" is included to ignore the embedding
+             Linear layer usually found at the end of decoder LLMs.
+     """
 
- import torch
- from transformers import AutoModelForCausalLM
-
- from auto_fp8.config import BaseQuantizeConfig
- from auto_fp8.quantize import (
-     quantize_activations,
-     quantize_weights,
-     save_quantized_model,
- )
+     def __init__(
+         self,
+         quant_method: str = "fp8",
+         activation_scheme: str = "static",
+         ignore_patterns: List[str] = ["lm_head"],
+     ):
+         self.quant_method = quant_method
+         self.activation_scheme = activation_scheme
+         self.ignore_patterns = ignore_patterns
 
 
  class AutoFP8ForCausalLM:
      def __init__(
-         self,
-         model: AutoModelForCausalLM,
-         quantize_config: BaseQuantizeConfig,
+         self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig
      ):
          self.model = model
          self.model_type = self.model.config.model_type
          self.config = self.model.config
-
-         # Gather the Linear module names that we want to ignore
-         quantize_config.ignored_layers = get_layers_to_ignore(
-             self.model, quantize_config.ignore_patterns
-         )
-
-         if quantize_config.kv_cache_quant_targets:
-             kv_cache_quant_layers = get_kv_cache_quant_layers(
-                 self.model, quantize_config.kv_cache_quant_targets
-             )
-             if len(kv_cache_quant_layers) == 0:
-                 raise ValueError(
-                     f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
-                 )
-             quantize_config.kv_cache_quant_layers = kv_cache_quant_layers
-
          self.quantize_config = quantize_config
 
      @classmethod
      def from_pretrained(
          cls,
          pretrained_model_name_or_path: str,
          quantize_config: BaseQuantizeConfig,
-         **model_init_kwargs,
+         **kwargs,
      ):
-         """Load the un-quantized pretrained model"""
-
-         def skip(*args, **kwargs):
-             pass
-
-         torch.nn.init.kaiming_uniform_ = skip
-         torch.nn.init.uniform_ = skip
-         torch.nn.init.normal_ = skip
-
-         # Parameters related to loading from Hugging Face Hub
-         cache_dir = model_init_kwargs.pop("cache_dir", None)
-         force_download = model_init_kwargs.pop("force_download", False)
-         resume_download = model_init_kwargs.pop("resume_download", False)
-         proxies = model_init_kwargs.pop("proxies", None)
-         local_files_only = model_init_kwargs.pop("local_files_only", False)
-         use_auth_token = model_init_kwargs.pop("use_auth_token", None)
-         revision = model_init_kwargs.pop("revision", None)
-         subfolder = model_init_kwargs.pop("subfolder", "")
-         commit_hash = model_init_kwargs.pop("_commit_hash", None)
-
-         cached_file_kwargs = {
-             "cache_dir": cache_dir,
-             "force_download": force_download,
-             "proxies": proxies,
-             "resume_download": resume_download,
-             "local_files_only": local_files_only,
-             "use_auth_token": use_auth_token,
-             "revision": revision,
-             "subfolder": subfolder,
-             "_commit_hash": commit_hash,
-         }
-
-         torch.cuda.empty_cache()
-
-         # Important defaults
-         if "torch_dtype" not in model_init_kwargs:
-             model_init_kwargs["torch_dtype"] = "auto"
-
-         if "device_map" not in model_init_kwargs:
-             model_init_kwargs["device_map"] = "auto"
-
-         merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-         print("Loading model with the following kwargs:", merged_kwargs)
-         model = AutoModelForCausalLM.from_pretrained(
-             pretrained_model_name_or_path, **merged_kwargs
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+         model = SparseAutoModelForCausalLM.from_pretrained(
+             pretrained_model_name_or_path,
+             config=config,
+             device_map="auto",
+             torch_dtype="auto",
+             **kwargs,
          )
-
-         model_config = model.config.to_dict()
-         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
-         if any(k in model_config for k in seq_len_keys):
-             for key in seq_len_keys:
-                 if key in model_config:
-                     model.seqlen = model_config[key]
-                     break
-         else:
-             print("Can't get model's sequence length, setting to 2048.")
-             model.seqlen = 2048
-         model.eval()
-
          return cls(model, quantize_config)
 
-     def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
-
-         # Always quantize the weights as they do not require calibration data
-         quantize_weights(self.model, self.quantize_config)
-
-         if self.quantize_config.activation_scheme == "static":
-             assert (
-                 calibration_tokens is not None
-             ), "Calibration tokens required for activation quantization"
-
-             def _prepare_calibration_data(calibration_tokens):
-                 if hasattr(calibration_tokens, "input_ids"):
-                     return calibration_tokens.input_ids
-                 return calibration_tokens
-
-             quantize_activations(
-                 self.model,
-                 self.quantize_config,
-                 _prepare_calibration_data(calibration_tokens),
-             )
+     def quantize(self, dataset: Optional[Dataset] = None):
+         assert (
+             self.quantize_config.activation_scheme == "static"
+         ), "Dynamic isn't supported yet"
+         assert (
+             dataset is not None
+         ), "Calibration tokens required for static activation quantization"
 
-     def save_quantized(self, save_dir):
-         save_quantized_model(
-             self.model,
-             quant_config=self.quantize_config,
-             save_dir=save_dir,
+         recipe = QuantizationModifier(
+             targets="Linear", scheme="FP8", ignore=self.quantize_config.ignore_patterns
          )
 
+         oneshot(
+             model=self.model,
+             dataset=dataset,
+             recipe=recipe,
+         )
 
- def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
-     ignored_layers = set()
-
-     for name, linear in model.named_modules():
-         if not isinstance(linear, torch.nn.Linear):
-             continue
-
-         for ignore_pattern in ignore_patterns:
-             regex_prefix = "re:"
-             if ignore_pattern.startswith(regex_prefix):
-                 # check if name matches regex and add to set if true
-                 regex_pattern = ignore_pattern[len(regex_prefix):]
-                 if re.search(regex_pattern, name):
-                     ignored_layers.add(name)
-             else:
-                 # else, exact match
-                 if ignore_pattern == name:
-                     ignored_layers.add(name)
-
-     return list(ignored_layers)
-
-
- def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
-     kv_cache_quant_layers = []
-
-     for name, linear in model.named_modules():
-         if not isinstance(linear, torch.nn.Linear):
-             continue
-
-         for output_quant_target in kv_cache_quant_targets:
-             if name.endswith(output_quant_target):
-                 kv_cache_quant_layers.append(name)
+     def save_quantized(self, save_directory: str):
+         self.save_pretrained(save_directory, save_compressed=True)
 
-     return kv_cache_quant_layers
+     def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+         self.model.save_pretrained(save_directory, save_compressed=save_compressed)
+         tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
+         tokenizer.save_pretrained(save_directory)
+         print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")