
Commit 7546f76

Switch backend to use llm-compressor
1 parent 0249168 commit 7546f76

File tree

4 files changed (+90 −136 lines changed)

auto_fp8/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-from .config import BaseQuantizeConfig
-from .modeling import AutoFP8ForCausalLM
+from .modeling import AutoFP8ForCausalLM, BaseQuantizeConfig
 
 __all__ = [
     "AutoFP8ForCausalLM",

auto_fp8/config.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

auto_fp8/modeling.py

Lines changed: 73 additions & 83 deletions
@@ -1,26 +1,46 @@
-import re
-from typing import List, Optional, Tuple
-
-import torch
-from transformers import AutoModelForCausalLM
-
-from auto_fp8.config import BaseQuantizeConfig
-from auto_fp8.quantize import (
-    quantize_activations,
-    quantize_weights,
-    save_quantized_model,
-)
-
-
-class AutoFP8ForCausalLM:
+import os
+from typing import List, Optional
+
+from transformers import AutoConfig, AutoTokenizer
+from datasets import Dataset
+from llmcompressor.transformers import SparseAutoModelForCausalLM
+from llmcompressor.transformers import oneshot
+from llmcompressor.modifiers.quantization import QuantizationModifier
+
+class BaseQuantizeConfig:
+    """Configuration for model quantization.
+
+    Args:
+        quant_method: Type/precision of quantization method to use.
+            At the moment, this is just "fp8" which specifically means
+            the fp8_e4m3 format in pytorch.
+        activation_scheme: Choice of either "dynamic" or "static" quantization
+            of activations. If "static", then calibration samples are required
+            during quantization to produce accurate per-tensor scales for
+            activations of Linear modules.
+        ignore_patterns: List of patterns used to ignore layers. If a string
+            starts with "re:", then everything afterwards is used as python
+            regex style matching i.e. re.search(), for each Linear layer.
+            By default, "lm_head" is included to ignore the embedding
+            Linear layer usually at the end of decoder LLMs.
+    """
     def __init__(
         self,
-        model: AutoModelForCausalLM,
-        quantize_config: BaseQuantizeConfig,
+        quant_method: str = "fp8",
+        activation_scheme: str = "static",
+        ignore_patterns: List[str] = ["lm_head"],
     ):
+        self.quant_method = quant_method
+        self.activation_scheme = activation_scheme
+        self.ignore_patterns = ignore_patterns
+
+
+class AutoFP8ForCausalLM:
+    def __init__(self, model: SparseAutoModelForCausalLM, quantize_config: BaseQuantizeConfig):
         self.model = model
         self.model_type = self.model.config.model_type
         self.config = self.model.config
+<<<<<<< HEAD
 
         # Gather the Linear module names that we want to ignore
         quantize_config.ignored_layers = get_layers_to_ignore(
@@ -45,76 +65,23 @@ def __init__(
         )
         quantize_config.kv_cache_quant_layers = kv_cache_quant_layers
 
+=======
+>>>>>>> ba7d420 (Switch backend to use llm-compressor)
         self.quantize_config = quantize_config
 
     @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: str,
-        quantize_config: BaseQuantizeConfig,
-        **model_init_kwargs,
-    ):
-        """Load the un-quantized pretrained model"""
-
-        def skip(*args, **kwargs):
-            pass
-
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
-        # Parameters related to loading from Hugging Face Hub
-        cache_dir = model_init_kwargs.pop("cache_dir", None)
-        force_download = model_init_kwargs.pop("force_download", False)
-        resume_download = model_init_kwargs.pop("resume_download", False)
-        proxies = model_init_kwargs.pop("proxies", None)
-        local_files_only = model_init_kwargs.pop("local_files_only", False)
-        use_auth_token = model_init_kwargs.pop("use_auth_token", None)
-        revision = model_init_kwargs.pop("revision", None)
-        subfolder = model_init_kwargs.pop("subfolder", "")
-        commit_hash = model_init_kwargs.pop("_commit_hash", None)
-
-        cached_file_kwargs = {
-            "cache_dir": cache_dir,
-            "force_download": force_download,
-            "proxies": proxies,
-            "resume_download": resume_download,
-            "local_files_only": local_files_only,
-            "use_auth_token": use_auth_token,
-            "revision": revision,
-            "subfolder": subfolder,
-            "_commit_hash": commit_hash,
-        }
-
-        torch.cuda.empty_cache()
-
-        # Important defaults
-        if "torch_dtype" not in model_init_kwargs:
-            model_init_kwargs["torch_dtype"] = "auto"
-
-        if "device_map" not in model_init_kwargs:
-            model_init_kwargs["device_map"] = "auto"
-
-        merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
-        print("Loading model with the following kwargs:", merged_kwargs)
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path, **merged_kwargs
+    def from_pretrained(cls, pretrained_model_name_or_path: str, quantize_config: BaseQuantizeConfig, **kwargs):
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        model = SparseAutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path,
+            config=config,
+            device_map="auto",
+            torch_dtype="auto",
+            **kwargs
         )
-
-        model_config = model.config.to_dict()
-        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
-        if any(k in model_config for k in seq_len_keys):
-            for key in seq_len_keys:
-                if key in model_config:
-                    model.seqlen = model_config[key]
-                    break
-        else:
-            print("Can't get model's sequence length, setting to 2048.")
-            model.seqlen = 2048
-        model.eval()
-
         return cls(model, quantize_config)
 
+<<<<<<< HEAD
     def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
 <<<<<<< HEAD
 <<<<<<< HEAD
@@ -161,12 +128,28 @@ def save_quantized(self, save_dir):
             self.model,
             quant_config=self.quantize_config,
             save_dir=save_dir,
+=======
+    def quantize(self, dataset: Optional[Dataset] = None):
+        assert self.quantize_config.activation_scheme == "static"
+        assert dataset is not None, "Calibration tokens required for static activation quantization"
+
+        recipe = QuantizationModifier(
+            targets="Linear",
+            scheme="FP8",
+            ignore=self.quantize_config.ignore_patterns
+>>>>>>> ba7d420 (Switch backend to use llm-compressor)
         )
 
+        oneshot(
+            model=self.model,
+            dataset=dataset,
+            recipe=recipe,
+        )
 
-def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
-    ignored_layers = set()
+    def save_quantized(self, save_directory: str):
+        self.save_pretrained(save_directory, save_compressed=True)
 
+<<<<<<< HEAD
     for name, linear in model.named_modules():
         if not isinstance(linear, torch.nn.Linear):
             continue
@@ -220,3 +203,10 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List
 
     return kv_cache_quant_layers
 >>>>>>> c3acdee (Switch from output_scale to kv_scale)
+=======
+    def save_pretrained(self, save_directory: str, save_compressed: bool = True):
+        self.model.save_pretrained(save_directory, save_compressed=save_compressed)
+        tokenizer = AutoTokenizer.from_pretrained(self.model.config._name_or_path)
+        tokenizer.save_pretrained(save_directory)
+        print(f"Saved final checkpoint to {os.path.abspath(save_directory)}")
+>>>>>>> ba7d420 (Switch backend to use llm-compressor)
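Since the hunks above still carry unresolved merge conflict markers, a consolidated sketch of the new flow may be easier to follow. The new from_pretrained(), quantize(), and save_quantized() methods are thin wrappers around llm-compressor; the sketch below drives llmcompressor directly with the same calls, assuming the package is installed. The model id, sample count, and output directory mirror example_dataset.py below and are illustrative:

    from datasets import load_dataset
    from transformers import AutoTokenizer
    from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier

    model_id = "facebook/opt-125m"
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token

    # Tokenized calibration set, prepared the same way as in example_dataset.py below.
    ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
    def preprocess(example):
        text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
        return tokenizer(text, padding=False, truncation=True, max_length=2048, add_special_tokens=False)
    ds = ds.map(preprocess, remove_columns=ds.column_names)

    # Load the dense model through llm-compressor's wrapper, as from_pretrained() now does.
    model = SparseAutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype="auto")

    # Same recipe quantize() builds: FP8 weights and static activation scales for every
    # Linear layer except the ignored ones, calibrated in a single one-shot pass.
    recipe = QuantizationModifier(targets="Linear", scheme="FP8", ignore=["lm_head"])
    oneshot(model=model, dataset=ds, recipe=recipe)

    # Same as save_quantized()/save_pretrained(): write a compressed checkpoint plus tokenizer.
    model.save_pretrained("opt-125m-FP8", save_compressed=True)
    tokenizer.save_pretrained("opt-125m-FP8")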

example_dataset.py

Lines changed: 16 additions & 9 deletions
@@ -3,20 +3,27 @@
 
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
 
-pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
+pretrained_model_dir = "facebook/opt-125m"
+quantized_model_dir = "opt-125m-FP8"
 
 tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
 tokenizer.pad_token = tokenizer.eos_token
 
-ds = load_dataset("mgoin/ultrachat_2k", split="train_sft")
-examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
-examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
+MAX_SEQUENCE_LENGTH = 2048
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
+def preprocess(example):
+    example = tokenizer.apply_chat_template(example["messages"], tokenize=False)
+    return tokenizer(
+        example,
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+ds = ds.map(preprocess, remove_columns=ds.column_names)
 
 quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
 
-model = AutoFP8ForCausalLM.from_pretrained(
-    pretrained_model_dir, quantize_config=quantize_config
-)
-model.quantize(examples)
+model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
+model.quantize(ds)
 model.save_quantized(quantized_model_dir)
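The BaseQuantizeConfig docstring also mentions a "dynamic" activation scheme that needs no calibration data, while the new quantize() currently asserts "static". A hedged sketch of how a dynamic run could be driven through llm-compressor directly; the "FP8_DYNAMIC" scheme name is an assumed llm-compressor preset, not something this commit exercises:

    from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier

    model = SparseAutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m", device_map="auto", torch_dtype="auto"
    )

    # Assumed preset: FP8 weights with dynamic per-token activation scales,
    # so no calibration dataset is passed to oneshot().
    recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
    oneshot(model=model, recipe=recipe)

    model.save_pretrained("opt-125m-FP8-Dynamic", save_compressed=True)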
