Commit a95b81d

Add more improvements and requirements.txt
1 parent b0c1f78 commit a95b81d

2 files changed: +76 lines, -42 lines

quantize.py

Lines changed: 72 additions & 42 deletions
@@ -1,4 +1,5 @@
 import argparse
+import gc
 import re
 from typing import Tuple

@@ -10,6 +11,7 @@


 # HACK: override the dtype_byte_size function in transformers to support float8 types
+# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
 def new_dtype_byte_size(dtype):
     if dtype == torch.bool:
         return 1 / 8
@@ -23,6 +25,11 @@ def new_dtype_byte_size(dtype):
 transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size


+def cleanup_memory():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
 def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
     """Quantize a tensor using per-tensor static scaling factor.

@@ -33,7 +40,14 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
     # Calculate the scale as dtype max divided by absmax.
     # Since .abs() creates a new tensor, we use aminmax to get
     # the min and max first and then calculate the absmax.
-    min_val, max_val = tensor.aminmax()
+    if tensor.numel() == 0:
+        # Deal with empty tensors (triggered by empty MoE experts)
+        min_val, max_val = (
+            torch.tensor(0.0, dtype=tensor.dtype),
+            torch.tensor(1.0, dtype=tensor.dtype),
+        )
+    else:
+        min_val, max_val = tensor.aminmax()
     amax = min_val.abs().max(max_val.abs())
     scale = finfo.max / amax.clamp(min=1e-12)
     # scale and clamp the tensor to bring it to
@@ -145,68 +159,80 @@ def forward(self, x):
         return output


+def replace_module(model, name, new_module):
+    if "." in name:
+        parent_name = name.rsplit(".", 1)[0]
+        child_name = name[len(parent_name) + 1 :]
+        parent = model.model.get_submodule(parent_name)
+    else:
+        parent_name = ""
+        parent = model.model
+        child_name = name
+    setattr(parent, child_name, new_module)
+
+
 def quantize_weights(model):
     for name, linear in model.model.named_modules():
+        # if "gate" in name or not isinstance(linear, torch.nn.Linear):
         if not isinstance(linear, torch.nn.Linear):
             continue
         quant_weight, quant_scale = per_tensor_quantize(linear.weight)
         quant_linear = FP8DynamicLinear(quant_weight, quant_scale)
-        if "." in name:
-            parent_name = name.rsplit(".", 1)[0]
-            child_name = name[len(parent_name) + 1 :]
-            parent = model.model.get_submodule(parent_name)
-        else:
-            parent_name = ""
-            parent = model.model
-            child_name = name
-        setattr(parent, child_name, quant_linear)
+        replace_module(model, name, quant_linear)
+        del linear
+    cleanup_memory()


 def quantize_activations(model, calibration_tokens):
     # Replace layers with quantizer.
     for name, dynamic_quant_linear in model.model.named_modules():
+        # if "gate" in name or not isinstance(dynamic_quant_linear, FP8DynamicLinear):
         if not isinstance(dynamic_quant_linear, FP8DynamicLinear):
             continue
         quantizer = FP8StaticLinearQuantizer(
             dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale
         )
-        if "." in name:
-            parent_name = name.rsplit(".", 1)[0]
-            child_name = name[len(parent_name) + 1 :]
-            parent = model.model.get_submodule(parent_name)
-        else:
-            parent_name = ""
-            parent = model.model
-            child_name = name
-        setattr(parent, child_name, quantizer)
+        replace_module(model, name, quantizer)
+        del dynamic_quant_linear
+    cleanup_memory()

     # Calibration.
     for row_idx in range(calibration_tokens.shape[0]):
         _ = model(calibration_tokens[row_idx].reshape(1, -1))

     # Replace quantizer with StaticLayer.
     for name, quantizer in model.model.named_modules():
+        # if "gate" in name or not isinstance(quantizer, FP8StaticLinearQuantizer):
         if not isinstance(quantizer, FP8StaticLinearQuantizer):
             continue
         static_proj = FP8StaticLinear(
             quantizer.weight, quantizer.weight_scale, quantizer.act_scale
         )
-        if "." in name:
-            parent_name = name.rsplit(".", 1)[0]
-            child_name = name[len(parent_name) + 1 :]
-            parent = model.model.get_submodule(parent_name)
-        else:
-            parent_name = ""
-            parent = model.model
-            child_name = name
-        setattr(parent, child_name, static_proj)
+        replace_module(model, name, static_proj)
+        del quantizer
+    cleanup_memory()
+
+
+def save_quantized_model(model, activation_scheme, save_dir):
+    print(f"Saving the model to {save_dir}")
+    static_q_dict = {
+        "quantization_config": {
+            "quant_method": "fp8",
+            "activation_scheme": activation_scheme,
+        }
+    }
+    model.config.update(static_q_dict)
+    model.save_pretrained(save_dir)
+    tokenizer.save_pretrained(save_dir)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model-id", type=str)
     parser.add_argument("--save-dir", type=str)
-    # parser.add_argument("--static-act", action="store_true")
+    parser.add_argument(
+        "--activation-scheme", type=str, default="static", choices=["static", "dynamic"]
+    )
     parser.add_argument("--num-samples", type=int, default=512)
     parser.add_argument("--max-seq-len", type=int, default=512)
     args = parser.parse_args()
@@ -240,22 +266,26 @@ def quantize_activations(model, calibration_tokens):
     model = AutoModelForCausalLM.from_pretrained(
         args.model_id, torch_dtype="auto", device_map="auto"
     )
+    print("Original model graph:\n", model)
     output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
-    print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n")
+    print("ORIGINAL OUTPUT:\n", tokenizer.decode(output[0]), "\n\n")

     # Quantize weights.
     quantize_weights(model)
+    print("Weight-quantized model graph:\n", model)
     output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
-    print("WEIGHT QUANT:\n", tokenizer.decode(output[0]), "\n\n")
+    print("WEIGHT QUANT OUTPUT:\n", tokenizer.decode(output[0]), "\n\n")

-    # Quantize activations.
-    quantize_activations(model, calibration_tokens=calibration_tokens)
-    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
-    print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n")
-
-    # Save the model fully quantized
-    print(f"Saving the model to {args.save_dir}")
-    static_q_dict = {"quantization_config": {"quant_method": "fp8", "scheme": "static"}}
-    model.config.update(static_q_dict)
-    model.save_pretrained(args.save_dir)
-    tokenizer.save_pretrained(args.save_dir)
+    if args.activation_scheme in "dynamic":
+        print("Exporting model with static weights and dynamic activations")
+        save_quantized_model(model, args.activation_scheme, args.save_dir)
+    else:
+        assert args.activation_scheme in "static"
+        # Quantize activations.
+        quantize_activations(model, calibration_tokens=calibration_tokens)
+        print("Weight and activation quantized model graph:\n", model)
+        output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
+        print("ACT QUANT OUTPUT:\n", tokenizer.decode(output[0]), "\n\n")
+
+        print("Exporting model with static weights and static activations")
+        save_quantized_model(model, args.activation_scheme, args.save_dir)
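
For context on what the code above computes: per_tensor_quantize reduces to a per-tensor absmax scale into the FP8 e4m3 range, and the new branch guards against empty weight tensors (e.g. unused MoE experts). Below is a minimal standalone sketch of that recipe, assuming the FP8 dtype is torch.float8_e4m3fn and that the tail of the function (only partly visible in the hunk) follows the usual scale, clamp, and cast steps; whether the real function returns the scale or its reciprocal is not shown in this diff, so the sketch returns the dequantization scale.

import torch

def per_tensor_quantize_sketch(tensor: torch.Tensor):
    """Per-tensor FP8 (e4m3) quantization sketch: absmax scale, clamp, cast."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    if tensor.numel() == 0:
        # Empty tensors (e.g. unused MoE experts) get a dummy [0, 1] range,
        # mirroring the branch added in this commit.
        min_val = torch.tensor(0.0, dtype=tensor.dtype)
        max_val = torch.tensor(1.0, dtype=tensor.dtype)
    else:
        min_val, max_val = tensor.aminmax()
    amax = min_val.abs().max(max_val.abs())
    scale = finfo.max / amax.clamp(min=1e-12)
    # Scale into the representable FP8 range, clamp, and cast down.
    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    # Dequantization is then roughly qweight.to(tensor.dtype) * inv_scale.
    return qweight, scale.float().reciprocal()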

requirements.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+torch>=2.2
+transformers
+datasets
+accelerate
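
With these four dependencies installed, the script is driven by the flags added above (python quantize.py --model-id ... --save-dir ... --activation-scheme static|dynamic). For readers who prefer to call the pieces directly, here is a rough sketch of the "dynamic" path, assuming quantize.py sits in the working directory and exposes the functions shown in the diff; the model ID and output directory are placeholders.

from transformers import AutoModelForCausalLM, AutoTokenizer
from quantize import quantize_weights  # assumes the script above is importable

model_id = "org/some-model"    # placeholder
save_dir = "./some-model-fp8"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

# FP8-quantize all Linear weights; activations stay dynamically quantized.
quantize_weights(model)

# Mirror save_quantized_model() for the "dynamic" scheme.
model.config.update(
    {"quantization_config": {"quant_method": "fp8", "activation_scheme": "dynamic"}}
)
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)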
