
Commit d69a57f

Merge pull request #7 from neuralmagic/ignore-layers
Add `ignore_patterns` arg for ignoring layers
2 parents a111911 + 0e8fe08 commit d69a57f

File tree

10 files changed (+175, -36 lines)

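In short, `BaseQuantizeConfig` gains an `ignore_patterns` argument, and the weight/activation quantization passes now skip any `torch.nn.Linear` module whose name matches a pattern (an exact module name, or a regex when the pattern is prefixed with `re:`); `lm_head` is currently always added to the ignore list. A minimal usage sketch based on the examples in this diff (the model id and the second pattern are purely illustrative):

from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative model id

quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    # "re:" prefix -> regex match against module names; otherwise exact-name match
    ignore_patterns=["re:.*lm_head", "model.layers.0.mlp.gate_proj"],
)

tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=True)
examples = tokenizer(["calibration sample text"], return_tensors="pt").to("cuda")

model = AutoFP8ForCausalLM.from_pretrained(pretrained, quantize_config=quantize_config)
model.quantize(examples)
model.save_quantized("Meta-Llama-3-8B-Instruct-FP8")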

auto_fp8/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
-from .modeling import AutoFP8ForCausalLM
 from .config import BaseQuantizeConfig
+from .modeling import AutoFP8ForCausalLM
 
 __all__ = [
     "AutoFP8ForCausalLM",

auto_fp8/config.py

Lines changed: 11 additions & 1 deletion
@@ -1,5 +1,13 @@
+from typing import List
+
+
 class BaseQuantizeConfig:
-    def __init__(self, quant_method="fp8", activation_scheme="static"):
+    def __init__(
+        self,
+        quant_method: str = "fp8",
+        activation_scheme: str = "static",
+        ignore_patterns: List[str] = [],
+    ):
         if quant_method != "fp8":
             raise ValueError("Only FP8 quantization is supported.")
         if activation_scheme not in ["static", "dynamic"]:
@@ -8,3 +16,5 @@ def __init__(self, quant_method="fp8", activation_scheme="static"):
             )
         self.quant_method = quant_method
         self.activation_scheme = activation_scheme
+        self.ignore_patterns = ignore_patterns
+        self.ignored_layers = []

auto_fp8/modeling.py

Lines changed: 49 additions & 8 deletions
@@ -1,24 +1,34 @@
+import re
+from typing import List
+
 import torch
-from transformers import AutoModelForCausalLM, PreTrainedModel
+from transformers import AutoModelForCausalLM
+
+from auto_fp8.config import BaseQuantizeConfig
 from auto_fp8.quantize import (
-    quantize_weights,
     quantize_activations,
+    quantize_weights,
     save_quantized_model,
 )
-from auto_fp8.config import BaseQuantizeConfig
 
 
 class AutoFP8ForCausalLM:
     def __init__(
         self,
-        model: PreTrainedModel,
+        model: AutoModelForCausalLM,
         quantize_config: BaseQuantizeConfig,
     ):
         self.model = model
         self.model_type = self.model.config.model_type
-        self.quantize_config = quantize_config
         self.config = self.model.config
 
+        # Gather the Linear module names that we want to ignore
+        quantize_config.ignored_layers = get_layers_to_ignore(
+            self.model, quantize_config.ignore_patterns
+        )
+
+        self.quantize_config = quantize_config
+
     @classmethod
     def from_pretrained(
         cls,
@@ -94,16 +104,47 @@ def _prepare_calibration_data(calibration_tokens):
             return calibration_tokens
 
         # Always quantize the weights as they do not require calibration data
-        quantize_weights(self.model)
+        quantize_weights(self.model, self.quantize_config)
 
         if self.quantize_config.activation_scheme == "static":
             quantize_activations(
-                self.model, _prepare_calibration_data(calibration_tokens)
+                self.model,
+                self.quantize_config,
+                _prepare_calibration_data(calibration_tokens),
             )
 
+        # import copy
+        # for layer in self.model.model.layers:
+        #     layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.act_scale)
+
     def save_quantized(self, save_dir):
         save_quantized_model(
             self.model,
-            activation_scheme=self.quantize_config.activation_scheme,
+            quant_config=self.quantize_config,
             save_dir=save_dir,
         )
+
+
+def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
+    ignored_layers = set()
+
+    # TODO: don't always ignore lm_head
+    ignore_patterns.append("re:.*lm_head")
+
+    for name, linear in model.named_modules():
+        if not isinstance(linear, torch.nn.Linear):
+            continue
+
+        for ignore_pattern in ignore_patterns:
+            regex_prefix = "re:"
+            if ignore_pattern.startswith(regex_prefix):
+                # check if name matches regex and add to set if true
+                regex_pattern = ignore_pattern[len(regex_prefix) :]
+                if re.search(regex_pattern, name):
+                    ignored_layers.add(name)
+            else:
+                # else, exact match
+                if ignore_pattern == name:
+                    ignored_layers.add(name)
+
+    return list(ignored_layers)
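
As a self-contained illustration of the matching rules used by `get_layers_to_ignore` (the toy module names below are illustrative, not from the repository):

import re

import torch

# Toy module tree: named_modules() yields "", "proj", and "head".
toy = torch.nn.ModuleDict({
    "proj": torch.nn.Linear(8, 8),
    "head": torch.nn.Linear(8, 2),
})

patterns = ["re:.*head", "proj"]  # one regex pattern, one exact-name pattern
ignored = set()
for name, module in toy.named_modules():
    if not isinstance(module, torch.nn.Linear):
        continue
    for pattern in patterns:
        if pattern.startswith("re:"):
            if re.search(pattern[len("re:"):], name):
                ignored.add(name)
        elif pattern == name:
            ignored.add(name)

print(sorted(ignored))  # ['head', 'proj']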

auto_fp8/quantize.py

Lines changed: 56 additions & 21 deletions
@@ -1,10 +1,13 @@
 import gc
 import re
-from typing import Tuple
+from typing import List, Tuple
+
 import torch
-import transformers
 import tqdm
-from transformers import AutoTokenizer
+import transformers
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from .config import BaseQuantizeConfig
 
 
 # HACK: Override the dtype_byte_size function in transformers to support float8 types
@@ -39,8 +42,8 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
     if tensor.numel() == 0:
         # Deal with empty tensors (triggered by empty MoE experts)
         min_val, max_val = (
-            torch.tensor(0.0, dtype=tensor.dtype),
-            torch.tensor(1.0, dtype=tensor.dtype),
+            torch.tensor(-16.0, dtype=tensor.dtype),
+            torch.tensor(16.0, dtype=tensor.dtype),
         )
     else:
         min_val, max_val = tensor.aminmax()
@@ -80,7 +83,9 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
 
 
 class FP8StaticLinearQuantizer(torch.nn.Module):
-    def __init__(self, qweight, weight_scale, bias):
+    def __init__(
+        self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
+    ):
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
@@ -105,7 +110,13 @@ def forward(self, x):
 
 
 class FP8StaticLinear(torch.nn.Module):
-    def __init__(self, qweight, weight_scale, bias, act_scale=0.0):
+    def __init__(
+        self,
+        qweight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        bias: torch.Tensor,
+        act_scale: float = 1.0,
+    ):
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
@@ -133,7 +144,7 @@ def forward(self, x):
 
 
 class FP8DynamicLinear(torch.nn.Module):
-    def __init__(self, qweight, scale, bias):
+    def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
         super().__init__()
         self.weight = torch.nn.Parameter(qweight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
@@ -152,21 +163,28 @@ def forward(self, x):
         return output
 
 
-def replace_module(model, name, new_module):
+def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.Module):
     if "." in name:
         parent_name = name.rsplit(".", 1)[0]
         child_name = name[len(parent_name) + 1 :]
-        parent = model.model.get_submodule(parent_name)
+        parent = model.get_submodule(parent_name)
     else:
         parent_name = ""
-        parent = model.model
+        parent = model
         child_name = name
     setattr(parent, child_name, new_module)
 
 
-def quantize_weights(model):
-    for name, linear in model.model.named_modules():
-        if not isinstance(linear, torch.nn.Linear):
+def quantize_weights(
+    model: AutoModelForCausalLM,
+    quantize_config: BaseQuantizeConfig,
+    ignored_layers: List[str] = [],
+):
+    for name, linear in model.named_modules():
+        if (
+            not isinstance(linear, torch.nn.Linear)
+            or name in quantize_config.ignored_layers
+        ):
             continue
         quant_weight, quant_scale = per_tensor_quantize(linear.weight)
         quant_linear = FP8DynamicLinear(quant_weight, quant_scale, linear.bias)
@@ -175,9 +193,17 @@ def quantize_weights(model):
     cleanup_memory()
 
 
-def quantize_activations(model, calibration_tokens):
-    for name, dynamic_quant_linear in model.model.named_modules():
-        if not isinstance(dynamic_quant_linear, FP8DynamicLinear):
+def quantize_activations(
+    model: AutoModelForCausalLM,
+    quantize_config: BaseQuantizeConfig,
+    calibration_tokens,
+    ignored_layers: List[str] = [],
+):
+    for name, dynamic_quant_linear in model.named_modules():
+        if (
+            not isinstance(dynamic_quant_linear, FP8DynamicLinear)
+            or name in quantize_config.ignored_layers
+        ):
             continue
         quantizer = FP8StaticLinearQuantizer(
             dynamic_quant_linear.weight,
@@ -196,8 +222,11 @@ def quantize_activations(model, calibration_tokens):
             pbar.update(1)
 
     # Replace dynamic quantizer with StaticLinear for export
-    for name, quantizer in model.model.named_modules():
-        if not isinstance(quantizer, FP8StaticLinearQuantizer):
+    for name, quantizer in model.named_modules():
+        if (
+            not isinstance(quantizer, FP8StaticLinearQuantizer)
+            or name in quantize_config.ignored_layers
+        ):
             continue
         static_proj = FP8StaticLinear(
             quantizer.weight,
@@ -210,13 +239,19 @@ def quantize_activations(model, calibration_tokens):
     cleanup_memory()
 
 
-def save_quantized_model(model, activation_scheme, save_dir):
+def save_quantized_model(
+    model: AutoModelForCausalLM,
+    quant_config: BaseQuantizeConfig,
+    save_dir: str,
+    ignored_layers: List[str] = [],
+):
     print(model)
     print(f"Saving the model to {save_dir}")
     static_q_dict = {
         "quantization_config": {
             "quant_method": "fp8",
-            "activation_scheme": activation_scheme,
+            "activation_scheme": quant_config.activation_scheme,
+            "ignored_layers": quant_config.ignored_layers,
         }
     }
     model.config.update(static_q_dict)
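
For context on the `(qweight, scale)` pairs that `quantize_weights` feeds into `FP8DynamicLinear`, here is a minimal per-tensor FP8-E4M3 quantization sketch. It assumes a recent PyTorch with `torch.float8_e4m3fn` and shows the usual scale-from-absmax recipe; it is not a copy of the repository's `per_tensor_quantize`:

import torch

def per_tensor_fp8_sketch(tensor: torch.Tensor):
    # Scale so that the tensor's absolute max maps onto the FP8-E4M3 max value.
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = tensor.abs().max().clamp(min=1e-12)
    qweight = (tensor * (finfo.max / amax)).clamp(min=finfo.min, max=finfo.max)
    qweight = qweight.to(torch.float8_e4m3fn)
    # Return the dequantization scale: qweight.float() * scale ~ tensor.
    scale = (amax / finfo.max).float()
    return qweight, scale

w = torch.randn(16, 16)
qw, s = per_tensor_fp8_sketch(w)
print(qw.dtype, s.item())  # torch.float8_e4m3fn, small positive scale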

example.py

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,5 @@
 from transformers import AutoTokenizer
+
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
 
 pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -9,8 +10,10 @@
 examples = tokenizer(examples, return_tensors="pt").to("cuda")
 
 quantize_config = BaseQuantizeConfig(
-    quant_method="fp8", activation_scheme="dynamic"
-)  # or "static"
+    quant_method="fp8",
+    activation_scheme="dynamic",  # or "static"
+    ignore_patterns=["re:.*lm_head"],
+)
 
 model = AutoFP8ForCausalLM.from_pretrained(
     pretrained_model_dir, quantize_config=quantize_config

example_dataset.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
+
+pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
+quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"
+
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
+examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
+examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
+
+quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")
+
+model = AutoFP8ForCausalLM.from_pretrained(
+    pretrained_model_dir, quantize_config=quantize_config
+)
+model.quantize(examples)
+model.save_quantized(quantized_model_dir)

examples/example_mixtral.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
+
+pretrained_model_dir = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+quantized_model_dir = "Mixtral-8x7B-Instruct-v0.1-FP8"
+
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+tokenizer.pad_token = tokenizer.eos_token
+
+ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(10))
+examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
+examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")
+
+quantize_config = BaseQuantizeConfig(
+    quant_method="fp8",
+    activation_scheme="static",
+    ignore_patterns=["re:.*lm_head", "re:.*gate"],
+)
+
+model = AutoFP8ForCausalLM.from_pretrained(
+    pretrained_model_dir, quantize_config=quantize_config
+)
+model.quantize(examples)
+model.save_quantized(quantized_model_dir)

examples/quantize.py

Lines changed: 1 addition & 1 deletion
@@ -4,8 +4,8 @@
 from typing import Tuple
 
 import torch
-import transformers
 import tqdm
+import transformers
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
setup.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 
 setup(
     name="auto_fp8",

tests/test_auto_fp8.py

Lines changed: 3 additions & 1 deletion
@@ -1,7 +1,9 @@
 import os
+import shutil
+
 from transformers import AutoTokenizer
+
 from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
-import shutil
 
 
 def test_quantization():
