Commit 81c33df

Initial quantize script
1 parent 280fa32 commit 81c33df

File tree

1 file changed
+259 -0 lines changed

quantize.py

Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
import torch
import torch.nn.functional as F
from typing import Tuple
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import re

MODEL_ID = "facebook/opt-125m"
# MODEL_ID = "echarlaix/tiny-random-mistral"


NUM_PROMPTS = 512
MAX_SEQ_LEN = 512

# HACK: override the dtype_byte_size function in transformers to support float8 types
def new_dtype_byte_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

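# For example, new_dtype_byte_size(torch.float8_e4m3fn) parses the "8" out of the dtype
# name and reports 1 byte per element, which lets save_pretrained() compute shard sizes
# for float8 checkpoints instead of raising on an unrecognized dtype.
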
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
    """Quantize a tensor using per-tensor static scaling factor.

    Args:
        tensor: The input tensor.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate the scale as dtype max divided by absmax.
    # Since .abs() creates a new tensor, we use aminmax to get
    # the min and max first and then calculate the absmax.
    min_val, max_val = tensor.aminmax()
    amax = min_val.abs().max(max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # Scale and clamp the tensor to bring it to
    # the representative range of the float8 data type
    # (as the default cast is unsaturated).
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
    # Return both the float8 data and the inverse scale (as float),
    # as both are required as inputs to torch._scaled_mm.
    qweight = qweight.to(torch.float8_e4m3fn)
    scale = scale.float().reciprocal()
    return qweight, scale

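# Minimal sanity-check sketch (not called anywhere in this script): quantize a random
# tensor and dequantize it with the returned inverse scale. The helper name and the
# printed metric are illustrative only; a PyTorch build with float8_e4m3fn support
# is assumed.
def _check_per_tensor_quantize_roundtrip() -> None:
    w = torch.randn(64, 64, dtype=torch.float16)
    qweight, inv_scale = per_tensor_quantize(w)
    # Dequantize: float8 values times the inverse scale approximate the original tensor.
    w_dequant = qweight.to(torch.float16) * inv_scale
    print("max abs round-trip error:", (w - w_dequant).abs().max().item())
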
def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
    cuda_compute_capability = torch.cuda.get_device_capability()
    if cuda_compute_capability >= (9, 0):
        # Hopper (SM 9.0+) has native FP8 tensor cores, exposed via torch._scaled_mm.
        output, _ = torch._scaled_mm(
            A,
            B.t(),
            out_dtype=out_dtype,
            scale_a=A_scale,
            scale_b=B_scale,
            bias=bias,
        )
    else:
        # Fallback for older GPUs: dequantize both operands and run a regular linear.
        output = torch.nn.functional.linear(
            A.to(out_dtype) * A_scale,
            B.to(out_dtype) * B_scale.to(out_dtype),
            bias=bias,
        )
    return output

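# Hedged sketch (not called in this script): compare fp8_gemm against a plain
# half-precision linear. The helper name, shapes, and device are assumptions;
# a CUDA device is required because fp8_gemm queries the GPU compute capability.
def _check_fp8_gemm_against_reference() -> None:
    x = torch.randn(16, 128, dtype=torch.float16, device="cuda")
    w = torch.randn(256, 128, dtype=torch.float16, device="cuda")
    qx, x_scale = per_tensor_quantize(x)
    qw, w_scale = per_tensor_quantize(w)
    out = fp8_gemm(
        A=qx, A_scale=x_scale, B=qw, B_scale=w_scale, bias=None, out_dtype=torch.float16
    )
    ref = torch.nn.functional.linear(x, w)
    print("max abs error vs fp16 linear:", (out - ref).abs().max().item())
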
class FP8StaticLinearQuantizer(torch.nn.Module):
    def __init__(self, qweight, weight_scale):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.act_scale = None

    def forward(self, x):
        # Dynamically quantize
        qinput, x_act_scale = per_tensor_quantize(x)

        # Update scale if needed.
        if self.act_scale is None:
            self.act_scale = torch.nn.Parameter(x_act_scale)
        elif x_act_scale > self.act_scale:
            self.act_scale = torch.nn.Parameter(x_act_scale)

        # Pass quantized to next layer so it has realistic data.
        output = fp8_gemm(
            A=qinput,
            A_scale=self.act_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=None,
            out_dtype=x.dtype,
        )
        return output

class FP8StaticLinear(torch.nn.Module):
    def __init__(self, qweight, weight_scale, act_scale=0.0):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
        self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False)

    def per_tensor_quantize(
        self, tensor: torch.Tensor, inv_scale: float
    ) -> torch.Tensor:
        # Scale and clamp the tensor to bring it to
        # the representative range of the float8 data type
        # (as the default cast is unsaturated).
        finfo = torch.finfo(torch.float8_e4m3fn)
        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
        return qweight.to(torch.float8_e4m3fn)

    def forward(self, x):
        qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale)
        output = fp8_gemm(
            A=qinput,
            A_scale=self.act_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=None,
            out_dtype=x.dtype,
        )
        return output

class FP8DynamicLinear(torch.nn.Module):
    def __init__(self, qweight, scale):
        super().__init__()
        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
        self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)

    def forward(self, x):
        qinput, x_scale = per_tensor_quantize(x)
        output = fp8_gemm(
            A=qinput,
            A_scale=x_scale,
            B=self.weight,
            B_scale=self.weight_scale,
            bias=None,
            out_dtype=x.dtype,
        )
        return output

def quantize_weights(model):
    for name, linear in model.model.named_modules():
        if not isinstance(linear, torch.nn.Linear):
            continue
        quant_weight, quant_scale = per_tensor_quantize(linear.weight)
        quant_linear = FP8DynamicLinear(quant_weight, quant_scale)
        # Swap the nn.Linear for the quantized replacement on its parent module.
        # Note: any bias on the original nn.Linear is not carried over.
        if "." in name:
            parent_name = name.rsplit(".", 1)[0]
            child_name = name[len(parent_name) + 1 :]
            parent = model.model.get_submodule(parent_name)
        else:
            parent_name = ""
            parent = model.model
            child_name = name
        setattr(parent, child_name, quant_linear)

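# Hedged sketch (not invoked here): after quantize_weights(), every nn.Linear inside
# model.model should have been replaced by FP8DynamicLinear. The helper name is
# illustrative and assumes a causal-LM wrapper exposing a .model attribute, as the
# transformers classes used below do.
def _count_quantized_linears(model) -> int:
    return sum(
        1
        for _, module in model.model.named_modules()
        if isinstance(module, FP8DynamicLinear)
    )
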
def quantize_activations(model, calibration_tokens):
    # Replace layers with quantizer.
    for name, dynamic_quant_linear in model.model.named_modules():
        if not isinstance(dynamic_quant_linear, FP8DynamicLinear):
            continue
        quantizer = FP8StaticLinearQuantizer(
            dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale
        )
        if "." in name:
            parent_name = name.rsplit(".", 1)[0]
            child_name = name[len(parent_name) + 1 :]
            parent = model.model.get_submodule(parent_name)
        else:
            parent_name = ""
            parent = model.model
            child_name = name
        setattr(parent, child_name, quantizer)

    # Calibration.
    for row_idx in range(calibration_tokens.shape[0]):
        _ = model(calibration_tokens[row_idx].reshape(1, -1))

    # Replace quantizer with StaticLayer.
    for name, quantizer in model.model.named_modules():
        if not isinstance(quantizer, FP8StaticLinearQuantizer):
            continue
        static_proj = FP8StaticLinear(
            quantizer.weight, quantizer.weight_scale, quantizer.act_scale
        )
        if "." in name:
            parent_name = name.rsplit(".", 1)[0]
            child_name = name[len(parent_name) + 1 :]
            parent = model.model.get_submodule(parent_name)
        else:
            parent_name = ""
            parent = model.model
            child_name = name
        setattr(parent, child_name, static_proj)

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    sample_input_tokens = tokenizer.apply_chat_template(
        [{"role": "user", "content": "What is your name?"}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda")

    ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
    ds = ds.shuffle(seed=42).select(range(NUM_PROMPTS))
    ds = ds.map(
        lambda batch: {
            "text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
        }
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id
    calibration_tokens = tokenizer(
        ds["text"],
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_SEQ_LEN,
        add_special_tokens=False,
    ).input_ids.to("cuda")
    print("Calibration tokens:", calibration_tokens.shape)

    # Load and test the model.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, torch_dtype="auto", device_map="auto"
    )
    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
    print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n")

    # Quantize weights.
    quantize_weights(model)
    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
    print("WEIGHT QUANT:\n", tokenizer.decode(output[0]), "\n\n")

    # Quantize activations.
    quantize_activations(model, calibration_tokens=calibration_tokens)
    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
    print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n")

    # Save the fully quantized model.
    output_path = "fp8-static-quant"
    print(f"Saving the model to {output_path}")
    static_q_dict = {"quantization_config": {"quant_method": "fp8", "scheme": "static"}}
    model.config.update(static_q_dict)
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)
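
# Usage sketch (assumptions, not part of the committed script): run on a machine with a
# CUDA GPU and the transformers, datasets, and accelerate packages installed:
#
#   python quantize.py
#
# The weight- and activation-quantized checkpoint, the tokenizer, and the updated
# quantization_config are written to ./fp8-static-quant.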
