-import re
-from typing import List, Optional, Tuple
-
-import torch
-from transformers import AutoModelForCausalLM
-
-from auto_fp8.config import BaseQuantizeConfig
-from auto_fp8.quantize import (
-    quantize_activations,
-    quantize_weights,
-    save_quantized_model,
-)
-
-
-class AutoFP8ForCausalLM:
+import os
+from typing import List, Optional
+
+from transformers import AutoConfig, AutoTokenizer
+from datasets import Dataset
+from llmcompressor.transformers import SparseAutoModelForCausalLM
+from llmcompressor.transformers import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+class BaseQuantizeConfig:
+    """Configuration for model quantization.
+
+    Args:
+        quant_method: Type/precision of quantization method to use.
+            At the moment, this is just "fp8", which specifically means
+            the fp8_e4m3 format in PyTorch.
+        activation_scheme: Choice of either "dynamic" or "static" quantization
+            of activations. If "static", then calibration samples are required
+            during quantization to produce accurate per-tensor scales for
+            the activations of Linear modules.
+        ignore_patterns: List of patterns used to ignore layers. If a string
+            starts with "re:", everything after it is used as a Python regex
+            and matched with re.search() against each Linear layer name;
+            otherwise an exact name match is required. By default, "lm_head"
+            is included to ignore the output embedding Linear layer usually
+            found at the end of decoder LLMs.
+    """
     def __init__(
         self,
-        model: AutoModelForCausalLM,
-        quantize_config: BaseQuantizeConfig,
+        quant_method: str = "fp8",
+        activation_scheme: str = "static",
+        ignore_patterns: List[str] = ["lm_head"],
     ):
+        self.quant_method = quant_method
+        self.activation_scheme = activation_scheme
+        self.ignore_patterns = ignore_patterns
+
+
+class AutoFP8ForCausalLM:
+    def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig):
         self.model = model
         self.model_type = self.model.config.model_type
         self.config = self.model.config
-
-        # Gather the Linear module names that we want to ignore
-        quantize_config.ignored_layers = get_layers_to_ignore(
-            self.model, quantize_config.ignore_patterns
-        )
-
-        if quantize_config.kv_cache_quant_targets:
-            kv_cache_quant_layers = get_kv_cache_quant_layers(
-                self.model, quantize_config.kv_cache_quant_targets
-            )
-            if len(kv_cache_quant_layers) == 0:
-                raise ValueError(
-                    f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
-                )
-            quantize_config.kv_cache_quant_layers = kv_cache_quant_layers
-
         self.quantize_config = quantize_config
 
     @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: str,
-        quantize_config: BaseQuantizeConfig,
-        **model_init_kwargs,
-    ):
-        """Load the un-quantized pretrained model"""
-
-        def skip(*args, **kwargs):
-            pass
-
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
-        # Parameters related to loading from Hugging Face Hub
-        cache_dir = model_init_kwargs.pop("cache_dir", None)
-        force_download = model_init_kwargs.pop("force_download", False)
-        resume_download = model_init_kwargs.pop("resume_download", False)
-        proxies = model_init_kwargs.pop("proxies", None)
-        local_files_only = model_init_kwargs.pop("local_files_only", False)
-        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
-        revision = model_init_kwargs.pop("revision", None)
-        subfolder = model_init_kwargs.pop("subfolder", "")
-        commit_hash = model_init_kwargs.pop("_commit_hash", None)
-
-        cached_file_kwargs = {
-            "cache_dir": cache_dir,
-            "force_download": force_download,
-            "proxies": proxies,
-            "resume_download": resume_download,
-            "local_files_only": local_files_only,
-            "use_auth_token": use_auth_token,
-            "revision": revision,
-            "subfolder": subfolder,
-            "_commit_hash": commit_hash,
-        }
-
-        torch.cuda.empty_cache()
-
-        # Important defaults
-        if "torch_dtype" not in model_init_kwargs:
-            model_init_kwargs["torch_dtype"] = "auto"
-
-        if "device_map" not in model_init_kwargs:
-            model_init_kwargs["device_map"] = "auto"
-
-        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-        print("Loading model with the following kwargs:", merged_kwargs)
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path, **merged_kwargs
+    def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs):
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path,
+            config=config,
+            device_map="auto",
+            torch_dtype="auto",
+            **kwargs
         )
-
-        model_config = model.config.to_dict()
-        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
-        if any(k in model_config for k in seq_len_keys):
-            for key in seq_len_keys:
-                if key in model_config:
-                    model.seqlen = model_config[key]
-                    break
-        else:
-            print("Can't get model's sequence length, setting to 2048.")
-            model.seqlen = 2048
-        model.eval()
-
         return cls(model, quantize_config)
 
-    def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
-
-        # Always quantize the weights as they do not require calibration data
-        quantize_weights(self.model, self.quantize_config)
-
-        if self.quantize_config.activation_scheme == "static":
-            assert (
-                calibration_tokens is not None
-            ), "Calibration tokens required for activation quantization"
-
-
-            def _prepare_calibration_data(calibration_tokens):
-                if hasattr(calibration_tokens, "input_ids"):
-                    return calibration_tokens.input_ids
-                return calibration_tokens
+    def quantize(self, dataset: Optional[Dataset] = None):
+        assert self.quantize_config.activation_scheme == "static"
+        assert dataset is not None, "Calibration tokens required for static activation quantization"
 
-            quantize_activations(
-                self.model,
-                self.quantize_config,
-                _prepare_calibration_data(calibration_tokens),
-            )
-
-    def save_quantized(self, save_dir):
-        save_quantized_model(
-            self.model,
-            quant_config=self.quantize_config,
-            save_dir=save_dir,
+        recipe = QuantizationModifier(
+            targets="Linear",
+            scheme="FP8",
+            ignore=self.quantize_config.ignore_patterns
         )
 
+        oneshot(
+            model=self.model,
+            dataset=dataset,
+            recipe=recipe,
+        )
 
-def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
-    ignored_layers = set()
-
-    for name, linear in model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
-            continue
-
-        for ignore_pattern in ignore_patterns:
-            regex_prefix = "re:"
-            if ignore_pattern.startswith(regex_prefix):
-                # check if name matches regex and add to set if true
-                regex_pattern = ignore_pattern[len(regex_prefix) :]
-                if re.search(regex_pattern, name):
-                    ignored_layers.add(name)
-            else:
-                # else, exact match
-                if ignore_pattern == name:
-                    ignored_layers.add(name)
-
-    return list(ignored_layers)
-
-
-def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
-    kv_cache_quant_layers = []
-
-    for name, linear in model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
-            continue
-
-        for output_quant_target in kv_cache_quant_targets:
-            if name.endswith(output_quant_target):
-                kv_cache_quant_layers.append(name)
+    def save_quantized(self, save_directory: str):
+        self.save_pretrained(save_directory, save_compressed=True)
 
-    return kv_cache_quant_layers
+    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+        self.model.save_pretrained(save_directory, save_compressed=save_compressed)
+        tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
+        tokenizer.save_pretrained(save_directory)
+        print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")
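For reference, a minimal usage sketch of the new llmcompressor-backed API as it appears in this diff. It assumes the classes are still importable from the auto_fp8 package; the model name, calibration dataset, and preprocessing are illustrative, and the exact tokenized format expected by llmcompressor's oneshot may differ:

from datasets import load_dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"  # example model
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

# Build a small calibration dataset of tokenized chat samples (illustrative).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)
tokenizer.pad_token = tokenizer.eos_token
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft").select(range(512))
ds = ds.map(
    lambda example: tokenizer(
        tokenizer.apply_chat_template(example["messages"], tokenize=False),
        max_length=2048,
        truncation=True,
    ),
    remove_columns=ds.column_names,
)

# "static" activation quantization requires the calibration dataset above.
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config=quantize_config)
model.quantize(ds)  # runs the FP8 QuantizationModifier via oneshot
model.save_quantized(quantized_model_dir)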