
Commit ba7d420

Switch backend to use llm-compressor
1 parent 2f4f28a commit ba7d420

File tree

4 files changed: +79 -208 lines changed

4 files changed

+79
-208
lines changed

auto_fp8/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-from .config import BaseQuantizeConfig
-from .modeling import AutoFP8ForCausalLM
+from .modeling import AutoFP8ForCausalLM, BaseQuantizeConfig

 __all__ = [
     "AutoFP8ForCausalLM",
auto_fp8/config.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

auto_fp8/modeling.py

Lines changed: 63 additions & 156 deletions
@@ -1,173 +1,80 @@
-import re
-from typing import List, Optional, Tuple
-
-import torch
-from transformers import AutoModelForCausalLM
-
-from auto_fp8.config import BaseQuantizeConfig
-from auto_fp8.quantize import (
-    quantize_activations,
-    quantize_weights,
-    save_quantized_model,
-)
-
-
-class AutoFP8ForCausalLM:
+import os
+from typing import List, Optional
+
+from transformers import AutoConfig, AutoTokenizer
+from datasets import Dataset
+from llmcompressor.transformers import SparseAutoModelForCausalLM
+from llmcompressor.transformers import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+class BaseQuantizeConfig:
+    """Configuration for model quantization.
+
+    Args:
+        quant_method: Type/precision of quantization method to use.
+            At the moment, this is just "fp8" which specifically means
+            the fp8_e4m3 format in pytorch.
+        activation_scheme: Choice of either "dynamic" or "static" quantization
+            of activations. If "static", then calibration samples are required
+            during quantization to produce accurate per-tensor scales for
+            activations of Linear modules.
+        ignore_patterns: List of patterns used to ignore layers. If a string
+            starts with "re:", then everything afterwards is used as python
+            regex style matching i.e. re.search(), for each Linear layer.
+            By default, "lm_head" is included to ignore the embedding
+            Linear layer usually at the end of decoder LLMs.
+    """
     def __init__(
         self,
-        model: AutoModelForCausalLM,
-        quantize_config: BaseQuantizeConfig,
+        quant_method: str = "fp8",
+        activation_scheme: str = "static",
+        ignore_patterns: List[str] = ["lm_head"],
     ):
+        self.quant_method = quant_method
+        self.activation_scheme = activation_scheme
+        self.ignore_patterns = ignore_patterns
+
+
+class AutoFP8ForCausalLM:
+    def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig):
         self.model = model
         self.model_type = self.model.config.model_type
         self.config = self.model.config
-
-        # Gather the Linear module names that we want to ignore
-        quantize_config.ignored_layers = get_layers_to_ignore(
-            self.model, quantize_config.ignore_patterns
-        )
-
-        if quantize_config.kv_cache_quant_targets:
-            kv_cache_quant_layers = get_kv_cache_quant_layers(
-                self.model, quantize_config.kv_cache_quant_targets
-            )
-            if len(kv_cache_quant_layers) == 0:
-                raise ValueError(
-                    f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
-                )
-            quantize_config.kv_cache_quant_layers = kv_cache_quant_layers
-
         self.quantize_config = quantize_config

     @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: str,
-        quantize_config: BaseQuantizeConfig,
-        **model_init_kwargs,
-    ):
-        """Load the un-quantized pretrained model"""
-
-        def skip(*args, **kwargs):
-            pass
-
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
-        # Parameters related to loading from Hugging Face Hub
-        cache_dir = model_init_kwargs.pop("cache_dir", None)
-        force_download = model_init_kwargs.pop("force_download", False)
-        resume_download = model_init_kwargs.pop("resume_download", False)
-        proxies = model_init_kwargs.pop("proxies", None)
-        local_files_only = model_init_kwargs.pop("local_files_only", False)
-        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
-        revision = model_init_kwargs.pop("revision", None)
-        subfolder = model_init_kwargs.pop("subfolder", "")
-        commit_hash = model_init_kwargs.pop("_commit_hash", None)
-
-        cached_file_kwargs = {
-            "cache_dir": cache_dir,
-            "force_download": force_download,
-            "proxies": proxies,
-            "resume_download": resume_download,
-            "local_files_only": local_files_only,
-            "use_auth_token": use_auth_token,
-            "revision": revision,
-            "subfolder": subfolder,
-            "_commit_hash": commit_hash,
-        }
-
-        torch.cuda.empty_cache()
-
-        # Important defaults
-        if "torch_dtype" not in model_init_kwargs:
-            model_init_kwargs["torch_dtype"] = "auto"
-
-        if "device_map" not in model_init_kwargs:
-            model_init_kwargs["device_map"] = "auto"
-
-        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-        print("Loading model with the following kwargs:", merged_kwargs)
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path, **merged_kwargs
+    def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs):
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path,
+            config=config,
+            device_map="auto",
+            torch_dtype="auto",
+            **kwargs
         )
-
-        model_config = model.config.to_dict()
-        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
-        if any(k in model_config for k in seq_len_keys):
-            for key in seq_len_keys:
-                if key in model_config:
-                    model.seqlen = model_config[key]
-                    break
-        else:
-            print("Can't get model's sequence length, setting to 2048.")
-            model.seqlen = 2048
-        model.eval()
-
         return cls(model, quantize_config)

-    def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
-
-        # Always quantize the weights as they do not require calibration data
-        quantize_weights(self.model, self.quantize_config)
-
-        if self.quantize_config.activation_scheme == "static":
-            assert (
-                calibration_tokens is not None
-            ), "Calibration tokens required for activation quantization"
-
-
-            def _prepare_calibration_data(calibration_tokens):
-                if hasattr(calibration_tokens, "input_ids"):
-                    return calibration_tokens.input_ids
-                return calibration_tokens
+    def quantize(self, dataset: Optional[Dataset] = None):
+        assert self.quantize_config.activation_scheme == "static"
+        assert dataset is not None, "Calibration tokens required for static activation quantization"

-            quantize_activations(
-                self.model,
-                self.quantize_config,
-                _prepare_calibration_data(calibration_tokens),
-            )
-
-    def save_quantized(self, save_dir):
-        save_quantized_model(
-            self.model,
-            quant_config=self.quantize_config,
-            save_dir=save_dir,
+        recipe = QuantizationModifier(
+            targets="Linear",
+            scheme="FP8",
+            ignore=self.quantize_config.ignore_patterns
         )

+        oneshot(
+            model=self.model,
+            dataset=dataset,
+            recipe=recipe,
+        )

-def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
-    ignored_layers = set()
-
-    for name, linear in model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
-            continue
-
-        for ignore_pattern in ignore_patterns:
-            regex_prefix = "re:"
-            if ignore_pattern.startswith(regex_prefix):
-                # check if name matches regex and add to set if true
-                regex_pattern = ignore_pattern[len(regex_prefix) :]
-                if re.search(regex_pattern, name):
-                    ignored_layers.add(name)
-            else:
-                # else, exact match
-                if ignore_pattern == name:
-                    ignored_layers.add(name)
-
-    return list(ignored_layers)
-
-
-def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
-    kv_cache_quant_layers = []
-
-    for name, linear in model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
-            continue
-
-        for output_quant_target in kv_cache_quant_targets:
-            if name.endswith(output_quant_target):
-                kv_cache_quant_layers.append(name)
+    def save_quantized(self, save_directory: str):
+        self.save_pretrained(save_directory, save_compressed=True)

-    return kv_cache_quant_layers
+    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+        self.model.save_pretrained(save_directory, save_compressed=save_compressed)
+        tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
+        tokenizer.save_pretrained(save_directory)
+        print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")

example_dataset.py

Lines changed: 15 additions & 8 deletions
@@ -3,20 +3,27 @@

 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

-pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
+pretrained_model_dir = "facebook/opt-125m"
+quantized_model_dir = "opt-125m-FP8"

 tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
 tokenizer.pad_token = tokenizer.eos_token

+MAX_SEQUENCE_LENGTH = 2048
 ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
-examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
-examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
+def preprocess(example):
+    example = tokenizer.apply_chat_template(example["messages"], tokenize=False)
+    return tokenizer(
+        example,
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+ds = ds.map(preprocess, remove_columns=ds.column_names)

 quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

-model = AutoFP8ForCausalLM.from_pretrained(
-    pretrained_model_dir, quantize_config=quantize_config
-)
-model.quantize(examples)
+model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
+model.quantize(ds)
 model.save_quantized(quantized_model_dir)
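
The ignore_patterns argument documented in the new BaseQuantizeConfig also accepts regex entries prefixed with "re:". A hypothetical variation of the config above (the extra pattern is illustrative only, not part of the commit):

    # Skip lm_head plus any Linear layer whose name matches the regex after "re:".
    quantize_config = BaseQuantizeConfig(
        quant_method="fp8",
        activation_scheme="static",
        ignore_patterns=["lm_head", "re:.*gate_proj"],
    )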
