Commit d0dd9d0

Support calibrating kv cache scales
1 parent 4b2092c commit d0dd9d0

File tree

3 files changed: +213 additions, 0 deletions


auto_fp8/modeling.py

Lines changed: 25 additions & 0 deletions
@@ -28,7 +28,11 @@ def __init__(
         )
 
         if quantize_config.kv_cache_quant_targets:
+<<<<<<< HEAD
             kv_cache_quant_layers = get_kv_cache_quant_layers(
+=======
+            kv_cache_quant_layers = get_kv_cache_quant_layer(
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
                 self.model, quantize_config.kv_cache_quant_targets
             )
             if len(kv_cache_quant_layers) == 0:
@@ -108,6 +112,13 @@ def skip(*args, **kwargs):
         return cls(model, quantize_config)
 
     def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
+<<<<<<< HEAD
+=======
+        def _prepare_calibration_data(calibration_tokens):
+            if hasattr(calibration_tokens, "input_ids"):
+                return calibration_tokens.input_ids
+            return calibration_tokens
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
 
         # Always quantize the weights as they do not require calibration data
         quantize_weights(self.model, self.quantize_config)
@@ -116,13 +127,16 @@ def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
         assert (
             calibration_tokens is not None
         ), "Calibration tokens required for activation quantization"
+<<<<<<< HEAD
 
 
         def _prepare_calibration_data(calibration_tokens):
             if hasattr(calibration_tokens, "input_ids"):
                 return calibration_tokens.input_ids
             return calibration_tokens
 
+=======
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
         quantize_activations(
             self.model,
             self.quantize_config,
@@ -159,15 +173,26 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
     return list(ignored_layers)
 
 
+<<<<<<< HEAD
 def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
     kv_cache_quant_layers = []
+=======
+def get_kv_cache_quant_layer(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
+    kv_cache_quant_layers = set()
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
 
     for name, linear in model.named_modules():
         if not isinstance(linear, torch.nn.Linear):
             continue
 
         for output_quant_target in kv_cache_quant_targets:
             if name.endswith(output_quant_target):
+<<<<<<< HEAD
                 kv_cache_quant_layers.append(name)
 
     return kv_cache_quant_layers
+=======
+                kv_cache_quant_layers.add(name)
+
+    return list(kv_cache_quant_layers)
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
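
Both sides of the conflicted helper above collect the names of Linear modules whose outputs feed the kv cache; the incoming side only renames the function and deduplicates via a set. For readability, here is how the helper reads if the 3ee9283 side of the markers is kept. This is an editorial sketch of the resolved state, not a line of the committed diff:

    from typing import List, Tuple

    import torch


    def get_kv_cache_quant_layer(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
        # Gather every torch.nn.Linear whose name ends with one of the kv-cache
        # targets (e.g. "k_proj", "v_proj"); a set keeps the result duplicate-free.
        kv_cache_quant_layers = set()

        for name, linear in model.named_modules():
            if not isinstance(linear, torch.nn.Linear):
                continue

            for output_quant_target in kv_cache_quant_targets:
                if name.endswith(output_quant_target):
                    kv_cache_quant_layers.add(name)

        return list(kv_cache_quant_layers)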

auto_fp8/quantize.py

Lines changed: 132 additions & 0 deletions
@@ -72,11 +72,19 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
         # Deal with empty tensors (triggeted by empty MoE experts)
         return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
 
+<<<<<<< HEAD
     # TODO: Disable native fp8 gemm for now, always just dequantize
     # native_fp8_support = (
     #     torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
     # )
     native_fp8_support = False
+=======
+    native_fp8_support = (
+        torch.cuda.is_available()
+        and torch.cuda.get_device_capability() >= (8, 9)
+        and False
+    )
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
     if native_fp8_support:
         need_reshape = A.dim() == 3
         if need_reshape:
@@ -108,6 +116,7 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
 
 # Class responsible for quantizing weights
 class FP8DynamicLinear(torch.nn.Module):
+<<<<<<< HEAD
     def __init__(
         self,
         weight: torch.Tensor,
@@ -125,13 +134,114 @@ def forward(self, x):
             A=qinput,
             A_scale=x_scale,
             B=self.weight,
+=======
+    def __init__(
+        self,
+        qweight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        bias: torch.nn.Parameter,
+    ):
+        super().__init__()
+        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
+        self.bias = bias
+
+    def forward(self, x):
+        qinput, x_scale = per_tensor_quantize(x)
+        output = fp8_gemm(
+            A=qinput,
+            A_scale=x_scale,
+            B=self.qweight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
         )
         return output
 
 
+# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
+class FP8StaticLinearQuantizer(torch.nn.Module):
+    def __init__(
+        self,
+        qweight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        bias: torch.nn.Parameter,
+        quantize_output: bool = False,
+    ):
+        super().__init__()
+        self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
+        self.bias = bias
+        self.input_scale = None
+        self.output_scale = None
+        self.quantize_output = quantize_output
+
+    def forward(self, x):
+        qinput, x_input_scale = per_tensor_quantize(x)
+        if self.input_scale is None:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
+        elif x_input_scale > self.input_scale:
+            self.input_scale = torch.nn.Parameter(x_input_scale)
+        output = fp8_gemm(
+            A=qinput,
+            A_scale=self.input_scale,
+            B=self.qweight,
+            B_scale=self.weight_scale,
+            bias=self.bias,
+            out_dtype=x.dtype,
+        )
+
+        # Optionally, quantize output and record scale
+        if self.quantize_output:
+            qoutput, output_scale = per_tensor_quantize(output)
+            if self.output_scale is None:
+                self.output_scale = torch.nn.Parameter(output_scale)
+            elif output_scale > self.output_scale:
+                self.output_scale = torch.nn.Parameter(output_scale)
+            output = qoutput.to(output.dtype) * output_scale
+
+        return output
+
+
+# Module responsible for representing the final checkpoint representation
+class FP8StaticLinear(torch.nn.Module):
+    def __init__(
+        self,
+        qweight: torch.nn.Parameter,
+        weight_scale: torch.nn.Parameter,
+        bias: torch.nn.Parameter,
+        input_scale: torch.nn.Parameter,
+        output_scale: Optional[torch.nn.Parameter] = None,
+    ):
+        super().__init__()
+        self.qweight = qweight
+        self.weight_scale = weight_scale
+        self.bias = bias
+        self.input_scale = input_scale
+        self.output_scale = output_scale
+
+    def forward(self, x):
+        qinput = static_per_tensor_quantize(x, self.input_scale)
+        output = fp8_gemm(
+            A=qinput,
+            A_scale=self.input_scale,
+            B=self.qweight,
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
+            B_scale=self.weight_scale,
+            bias=self.bias,
+            out_dtype=x.dtype,
+        )
+<<<<<<< HEAD
+=======
+
+        if self.output_scale:
+            qoutput = static_per_tensor_quantize(output, self.output_scale)
+            output = qoutput.to(output.dtype) * self.output_scale
+
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
+        return output
+
+
 # Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
 class FP8StaticLinearQuantizer(torch.nn.Module):
     def __init__(
@@ -237,7 +347,11 @@ def quantize_weights(
         quant_weight, weight_scale = per_tensor_quantize(linear.weight)
         bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
         quant_linear = FP8DynamicLinear(
+<<<<<<< HEAD
             weight=quant_weight, weight_scale=weight_scale, bias=bias
+=======
+            qweight=quant_weight, weight_scale=weight_scale, bias=bias
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
         )
         replace_module(model, name, quant_linear)
         del linear.weight
@@ -259,7 +373,11 @@ def quantize_activations(
         ):
             continue
         quantizer = FP8StaticLinearQuantizer(
+<<<<<<< HEAD
            weight=dynamic_quant_linear.weight,
+=======
+            qweight=dynamic_quant_linear.qweight,
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
            weight_scale=dynamic_quant_linear.weight_scale,
            bias=dynamic_quant_linear.bias,
            quantize_output=(
@@ -272,12 +390,22 @@ def quantize_activations(
     cleanup_memory()
 
     # Pass through calibration data to measure activation scales
+<<<<<<< HEAD
     with torch.inference_mode():
         with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
             for row_idx in range(calibration_tokens.shape[0]):
                 model(calibration_tokens[row_idx].reshape(1, -1))
                 cleanup_memory()
                 pbar.update(1)
+=======
+    with tqdm.tqdm(
+        total=calibration_tokens.shape[0], desc="Calibrating activation scales"
+    ) as pbar:
+        for row_idx in range(calibration_tokens.shape[0]):
+            model(calibration_tokens[row_idx].reshape(1, -1))
+            cleanup_memory()
+            pbar.update(1)
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
 
     # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
@@ -287,7 +415,11 @@ def quantize_activations(
         ):
             continue
         static_proj = FP8StaticLinear(
+<<<<<<< HEAD
            weight=quantizer.weight,
+=======
+            qweight=quantizer.qweight,
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
            weight_scale=quantizer.weight_scale,
            bias=quantizer.bias,
            input_scale=quantizer.input_scale,
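
The new FP8StaticLinearQuantizer and FP8StaticLinear modules lean on per_tensor_quantize and static_per_tensor_quantize, which are not part of these hunks. As a reference, a minimal sketch of how such per-tensor FP8 helpers are commonly written, assuming torch.float8_e4m3fn and max-abs scaling; this is not necessarily AutoFP8's exact implementation:

    import torch


    def per_tensor_quantize_sketch(tensor: torch.Tensor):
        # Pick one scale for the whole tensor so its max-abs value maps onto the
        # fp8 range; the returned scale is the dequantization scale.
        finfo = torch.finfo(torch.float8_e4m3fn)
        amax = tensor.abs().max().clamp(min=1e-12)
        scale = amax.float() / finfo.max
        qtensor = (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
        return qtensor, scale


    def static_per_tensor_quantize_sketch(tensor: torch.Tensor, scale: torch.Tensor):
        # Same mapping, but with a previously calibrated (frozen) scale.
        finfo = torch.finfo(torch.float8_e4m3fn)
        return (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

Under this convention, the observer in FP8StaticLinearQuantizer simply keeps the largest scale it sees across calibration batches, which is what the input_scale/output_scale max-tracking in forward() does.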

tests/test_auto_fp8.py

Lines changed: 56 additions & 0 deletions
@@ -1,7 +1,10 @@
 import os
 import shutil
 
+<<<<<<< HEAD
 import pytest
+=======
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
 import safetensors.torch
 from transformers import AutoTokenizer
 
@@ -12,9 +15,15 @@
     ("Qwen/Qwen2-0.5B-Instruct", 620),
 ]
 
+<<<<<<< HEAD
 @pytest.mark.parametrize("model_id,target_size", MODELS)
 def test_dynamic_quantization(model_id, target_size):
     quantized_model_dir = model_id.split("/")[-1] + "-fp8-dynamic"
+=======
+def test_dynamic_quantization():
+    model_id = "facebook/opt-125m"
+    quantized_model_dir = "opt-125m-fp8-dynamic"
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
 
     quantize_config = BaseQuantizeConfig(
         quant_method="fp8", activation_scheme="dynamic"
@@ -30,6 +39,7 @@ def test_dynamic_quantization(model_id, target_size):
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
+<<<<<<< HEAD
     # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
@@ -38,6 +48,16 @@ def test_dynamic_quantization(model_id, target_size):
 @pytest.mark.parametrize("model_id,target_size", MODELS)
 def test_static_quantization(model_id, target_size):
     quantized_model_dir = model_id.split("/")[-1] + "-fp8-static"
+=======
+    # We expect the model to be < 160MB
+    target_size = 160 * (1024 * 1024)
+    assert model_size < target_size
+
+
+def test_static_quantization():
+    model_id = "facebook/opt-125m"
+    quantized_model_dir = "opt-125m-fp8-static"
+>>>>>>> 3ee9283 (Support calibrating kv cache scales)
 
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
     examples = ["auto-fp8 is an easy-to-use model quantization library"]
@@ -96,3 +116,39 @@ def test_kv_cache_static_quantization(model_id, target_size):
     # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
+
+
+def test_kv_cache_static_quantization():
+    model_id = "facebook/opt-125m"
+    quantized_model_dir = "opt-125m-fp8-static-kv"
+
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+    examples = ["auto-fp8 is an easy-to-use model quantization library"]
+    examples = tokenizer(examples, return_tensors="pt")
+
+    quantize_config = BaseQuantizeConfig(
+        quant_method="fp8",
+        activation_scheme="static",
+        kv_cache_quant_targets=("k_proj", "v_proj"),
+    )
+
+    model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
+    model.model.to("cpu")
+
+    model.quantize(examples)
+    model.save_quantized(quantized_model_dir)
+
+    tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors")
+    count_matches = 0
+    for name, tensor in tensors.items():
+        if name.endswith("k_proj.output_scale") or name.endswith("v_proj.output_scale"):
+            count_matches += 1
+    assert count_matches == 24
+
+    # Measure checkpoint size and cleanup
+    model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
+    shutil.rmtree(quantized_model_dir)
+
+    # We expect the model to be < 160MB
+    target_size = 160 * (1024 * 1024)
+    assert model_size < target_size
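
The new test also doubles as the usage example for this feature: passing kv_cache_quant_targets makes quantize() record an output_scale on each matching projection, and those scales end up in the exported checkpoint. A condensed sketch of that flow, reusing the values from the test (the import line assumes the package's public exports as used elsewhere in the repo):

    from transformers import AutoTokenizer

    from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

    model_id = "facebook/opt-125m"
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    examples = tokenizer(
        ["auto-fp8 is an easy-to-use model quantization library"], return_tensors="pt"
    )

    # Static activation quantization plus kv-cache scale calibration on the
    # attention k/v projections.
    quantize_config = BaseQuantizeConfig(
        quant_method="fp8",
        activation_scheme="static",
        kv_cache_quant_targets=("k_proj", "v_proj"),
    )

    model = AutoFP8ForCausalLM.from_pretrained(model_id, quantize_config)
    model.quantize(examples)  # calibrates input scales and k_proj/v_proj output scales
    model.save_quantized("opt-125m-fp8-static-kv")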
