
Commit e286fa9

Switch backend with llm-compressor
1 parent 4b2092c commit e286fa9

6 files changed: +358, -201 lines

auto_fp8/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-from .config import BaseQuantizeConfig
-from .modeling import AutoFP8ForCausalLM
+from .modeling import AutoFP8ForCausalLM, BaseQuantizeConfig
 
 __all__ = [
     "AutoFP8ForCausalLM",

auto_fp8/config.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

auto_fp8/modeling.py

Lines changed: 66 additions & 148 deletions
@@ -1,173 +1,91 @@
-import re
-from typing import List, Optional, Tuple
+import os
+from typing import List, Optional
+
+from transformers import AutoConfig, AutoTokenizer
+from datasets import Dataset
+from llmcompressor.transformers import SparseAutoModelForCausalLM
+from llmcompressor.transformers import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+
+class BaseQuantizeConfig:
+    """Configuration for model quantization.
+
+    Args:
+        quant_method: Type/precision of quantization method to use.
+            At the moment, this is just "fp8" which specifically means
+            the fp8_e4m3 format in pytorch.
+        activation_scheme: Choice of either "dynamic" or "static" quantization
+            of activations. If "static", then calibration samples are required
+            during quantization to produce accurate per-tensor scales for
+            activations of Linear modules.
+        ignore_patterns: List of patterns used to ignore layers. If a string
+            starts with "re:", then everything afterwards is used as python
+            regex style matching i.e. re.search(), for each Linear layer.
+            By default, "lm_head" is included to ignore the embedding
+            Linear layer usually at the end of decoder LLMs.
+    """
 
-import torch
-from transformers import AutoModelForCausalLM
-
-from auto_fp8.config import BaseQuantizeConfig
-from auto_fp8.quantize import (
-    quantize_activations,
-    quantize_weights,
-    save_quantized_model,
-)
+    def __init__(
+        self,
+        quant_method: str = "fp8",
+        activation_scheme: str = "static",
+        ignore_patterns: List[str] = ["lm_head"],
+    ):
+        self.quant_method = quant_method
+        self.activation_scheme = activation_scheme
+        self.ignore_patterns = ignore_patterns
 
 
 class AutoFP8ForCausalLM:
     def __init__(
-        self,
-        model: AutoModelForCausalLM,
-        quantize_config: BaseQuantizeConfig,
+        self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig
     ):
         self.model = model
         self.model_type = self.model.config.model_type
         self.config = self.model.config
-
-        # Gather the Linear module names that we want to ignore
-        quantize_config.ignored_layers = get_layers_to_ignore(
-            self.model, quantize_config.ignore_patterns
-        )
-
-        if quantize_config.kv_cache_quant_targets:
-            kv_cache_quant_layers = get_kv_cache_quant_layers(
-                self.model, quantize_config.kv_cache_quant_targets
-            )
-            if len(kv_cache_quant_layers) == 0:
-                raise ValueError(
-                    f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
-                )
-            quantize_config.kv_cache_quant_layers = kv_cache_quant_layers
-
         self.quantize_config = quantize_config
 
     @classmethod
     def from_pretrained(
         cls,
         pretrained_model_name_or_path: str,
         quantize_config: BaseQuantizeConfig,
-        **model_init_kwargs,
+        **kwargs,
     ):
-        """Load the un-quantized pretrained model"""
-
-        def skip(*args, **kwargs):
-            pass
-
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
-        # Parameters related to loading from Hugging Face Hub
-        cache_dir = model_init_kwargs.pop("cache_dir", None)
-        force_download = model_init_kwargs.pop("force_download", False)
-        resume_download = model_init_kwargs.pop("resume_download", False)
-        proxies = model_init_kwargs.pop("proxies", None)
-        local_files_only = model_init_kwargs.pop("local_files_only", False)
-        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
-        revision = model_init_kwargs.pop("revision", None)
-        subfolder = model_init_kwargs.pop("subfolder", "")
-        commit_hash = model_init_kwargs.pop("_commit_hash", None)
-
-        cached_file_kwargs = {
-            "cache_dir": cache_dir,
-            "force_download": force_download,
-            "proxies": proxies,
-            "resume_download": resume_download,
-            "local_files_only": local_files_only,
-            "use_auth_token": use_auth_token,
-            "revision": revision,
-            "subfolder": subfolder,
-            "_commit_hash": commit_hash,
-        }
-
-        torch.cuda.empty_cache()
-
-        # Important defaults
-        if "torch_dtype" not in model_init_kwargs:
-            model_init_kwargs["torch_dtype"] = "auto"
-
-        if "device_map" not in model_init_kwargs:
-            model_init_kwargs["device_map"] = "auto"
-
-        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-        print("Loading model with the following kwargs:", merged_kwargs)
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path, **merged_kwargs
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path,
+            config=config,
+            device_map="auto",
+            torch_dtype="auto",
+            **kwargs,
         )
-
-        model_config = model.config.to_dict()
-        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
-        if any(k in model_config for k in seq_len_keys):
-            for key in seq_len_keys:
-                if key in model_config:
-                    model.seqlen = model_config[key]
-                    break
-        else:
-            print("Can't get model's sequence length, setting to 2048.")
-            model.seqlen = 2048
-        model.eval()
-
         return cls(model, quantize_config)
 
-    def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
-
-        # Always quantize the weights as they do not require calibration data
-        quantize_weights(self.model, self.quantize_config)
-
-        if self.quantize_config.activation_scheme == "static":
-            assert (
-                calibration_tokens is not None
-            ), "Calibration tokens required for activation quantization"
-
-            def _prepare_calibration_data(calibration_tokens):
-                if hasattr(calibration_tokens, "input_ids"):
-                    return calibration_tokens.input_ids
-                return calibration_tokens
-
-            quantize_activations(
-                self.model,
-                self.quantize_config,
-                _prepare_calibration_data(calibration_tokens),
-            )
+    def quantize(self, dataset: Optional[Dataset] = None):
+        assert (
+            self.quantize_config.activation_scheme == "static"
+        ), "Dynamic isn't supported yet"
+        assert (
+            dataset is not None
+        ), "Calibration tokens required for static activation quantization"
 
-    def save_quantized(self, save_dir):
-        save_quantized_model(
-            self.model,
-            quant_config=self.quantize_config,
-            save_dir=save_dir,
+        recipe = QuantizationModifier(
+            targets="Linear", scheme="FP8", ignore=self.quantize_config.ignore_patterns
         )
 
+        oneshot(
+            model=self.model,
+            dataset=dataset,
+            recipe=recipe,
+        )
 
-def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
-    ignored_layers = set()
-
-    for name, linear in model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
-            continue
-
-        for ignore_pattern in ignore_patterns:
-            regex_prefix = "re:"
-            if ignore_pattern.startswith(regex_prefix):
-                # check if name matches regex and add to set if true
-                regex_pattern = ignore_pattern[len(regex_prefix) :]
-                if re.search(regex_pattern, name):
-                    ignored_layers.add(name)
-            else:
-                # else, exact match
-                if ignore_pattern == name:
-                    ignored_layers.add(name)
-
-    return list(ignored_layers)
-
-
-def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
-    kv_cache_quant_layers = []
-
-    for name, linear in model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
-            continue
-
-        for output_quant_target in kv_cache_quant_targets:
-            if name.endswith(output_quant_target):
-                kv_cache_quant_layers.append(name)
+    def save_quantized(self, save_directory: str):
+        self.save_pretrained(save_directory, save_compressed=True)
 
-    return kv_cache_quant_layers
+    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+        self.model.save_pretrained(save_directory, save_compressed=save_compressed)
+        tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
+        tokenizer.save_pretrained(save_directory)
+        print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")
