
Commit 0d40b99

Support calibrating kv cache scales (#17)
* Support calibrating kv cache scales
* Add comment
* Fix weight name
* Add Qwen test
* Fix kv cache test count
* Add fixed target sizes
* Fix proj linear count
* Switch from output_scale to kv_scale
* Add example
1 parent b1c6ad6 commit 0d40b99

File tree

5 files changed: +249 -69 lines changed


auto_fp8/config.py

Lines changed: 7 additions & 2 deletions

@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Optional, Tuple
 
 
 class BaseQuantizeConfig:
@@ -17,13 +17,17 @@ class BaseQuantizeConfig:
            regex style matching i.e. re.search(), for each Linear layer.
            By default, "re:.*lm_head" is included to ignore the embedding
            Linear layer usually at the end of decoder LLMs
+        kv_cache_quant_targets: Tuple of Linear module names to target for
+            calibration of the output scales for KV cache quantization.
+            Usually, these should be `("k_proj", "v_proj")`.
     """
 
     def __init__(
         self,
         quant_method: str = "fp8",
         activation_scheme: str = "static",
-        ignore_patterns: List[str] = [],
+        ignore_patterns: List[str] = ["re:.*lm_head"],
+        kv_cache_quant_targets: Optional[Tuple[str]] = None,
     ):
         if quant_method != "fp8":
             raise ValueError("Only FP8 quantization is supported.")
@@ -34,4 +38,5 @@ def __init__(
         self.quant_method = quant_method
         self.activation_scheme = activation_scheme
         self.ignore_patterns = ignore_patterns
+        self.kv_cache_quant_targets = kv_cache_quant_targets
         self.ignored_layers = []
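
The new config surface would be exercised roughly as follows; this is a minimal sketch based only on the signature above, with the import path mirroring the file location:

    # Minimal sketch; values shown are the documented defaults.
    from auto_fp8.config import BaseQuantizeConfig

    quantize_config = BaseQuantizeConfig(
        quant_method="fp8",
        activation_scheme="static",
        ignore_patterns=["re:.*lm_head"],             # now the default instead of being appended internally
        kv_cache_quant_targets=("k_proj", "v_proj"),  # Linear names whose output scales get calibrated
    )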

auto_fp8/modeling.py

Lines changed: 36 additions & 13 deletions

@@ -1,5 +1,5 @@
 import re
-from typing import List
+from typing import List, Optional, Tuple
 
 import torch
 from transformers import AutoModelForCausalLM
@@ -27,6 +27,16 @@ def __init__(
             self.model, quantize_config.ignore_patterns
         )
 
+        if quantize_config.kv_cache_quant_targets:
+            kv_cache_quant_layers = get_kv_cache_quant_layers(
+                self.model, quantize_config.kv_cache_quant_targets
+            )
+            if len(kv_cache_quant_layers) == 0:
+                raise ValueError(
+                    f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
+                )
+            quantize_config.kv_cache_quant_layers = kv_cache_quant_layers
+
         self.quantize_config = quantize_config
 
     @classmethod
@@ -97,26 +107,28 @@ def skip(*args, **kwargs):
 
         return cls(model, quantize_config)
 
-    def quantize(self, calibration_tokens):
-        def _prepare_calibration_data(calibration_tokens):
-            if hasattr(calibration_tokens, "input_ids"):
-                return calibration_tokens.input_ids
-            return calibration_tokens
+    def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
 
         # Always quantize the weights as they do not require calibration data
         quantize_weights(self.model, self.quantize_config)
 
         if self.quantize_config.activation_scheme == "static":
+            assert (
+                calibration_tokens is not None
+            ), "Calibration tokens required for activation quantization"
+
+            def _prepare_calibration_data(calibration_tokens):
+                if hasattr(calibration_tokens, "input_ids"):
+                    return calibration_tokens.input_ids
+                return calibration_tokens
+
             quantize_activations(
                 self.model,
                 self.quantize_config,
                 _prepare_calibration_data(calibration_tokens),
             )
 
-        # import copy
-        # for layer in self.model.model.layers:
-        #     layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.input_scale)
-
     def save_quantized(self, save_dir):
         save_quantized_model(
             self.model,
@@ -128,9 +140,6 @@ def save_quantized(self, save_dir):
 def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
     ignored_layers = set()
 
-    # TODO: don't always ignore lm_head
-    ignore_patterns.append("re:.*lm_head")
-
     for name, linear in model.named_modules():
         if not isinstance(linear, torch.nn.Linear):
             continue
@@ -148,3 +157,17 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
             ignored_layers.add(name)
 
     return list(ignored_layers)
+
+
+def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
+    kv_cache_quant_layers = []
+
+    for name, linear in model.named_modules():
+        if not isinstance(linear, torch.nn.Linear):
+            continue
+
+        for output_quant_target in kv_cache_quant_targets:
+            if name.endswith(output_quant_target):
+                kv_cache_quant_layers.append(name)
+
+    return kv_cache_quant_layers
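
For a sense of what get_kv_cache_quant_layers collects, here is a standalone sketch of the suffix-matching loop; the module names are hypothetical Llama-style names, not taken from this diff:

    # Hypothetical module names for a 2-layer decoder; only the endswith() logic mirrors the helper.
    names = [
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
        "model.layers.1.self_attn.k_proj",
        "model.layers.1.self_attn.v_proj",
        "model.layers.0.mlp.gate_proj",
    ]
    targets = ("k_proj", "v_proj")

    selected = []
    for name in names:                    # stands in for iterating model.named_modules()
        for target in targets:
            if name.endswith(target):     # same suffix match as the helper above
                selected.append(name)

    # selected keeps the k_proj/v_proj entries in traversal order, which the
    # kv_scale post-processing in quantize.py later relies on.
    print(selected)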

auto_fp8/quantize.py

Lines changed: 109 additions & 47 deletions

@@ -1,6 +1,6 @@
 import gc
 import re
-from typing import List, Tuple
+from typing import Optional, Tuple
 import copy
 
 import torch
@@ -61,14 +61,22 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
     return qweight, scale
 
 
+def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
+    return qweight.to(torch.float8_e4m3fn)
+
+
 def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
     if A.numel() == 0:
         # Deal with empty tensors (triggeted by empty MoE experts)
         return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
-
-    native_fp8_support = (
-        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-    )
+
+    # TODO: Disable native fp8 gemm for now, always just dequantize
+    # native_fp8_support = (
+    #     torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+    # )
+    native_fp8_support = False
     if native_fp8_support:
         need_reshape = A.dim() == 3
         if need_reshape:
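
The two quantization helpers split the work: per_tensor_quantize derives a scale from the tensor's maximum magnitude, while static_per_tensor_quantize applies a scale that was fixed ahead of time. A self-contained sketch of that static path, reimplemented locally under the assumption of a torch build that provides float8_e4m3fn:

    import torch

    def apply_static_scale(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Same math as static_per_tensor_quantize: divide, clamp to the E4M3 range, cast.
        finfo = torch.finfo(torch.float8_e4m3fn)
        return (tensor / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

    x = torch.randn(4, 8, dtype=torch.float16)
    # Per-tensor scale from the max magnitude, analogous to what per_tensor_quantize computes.
    scale = x.abs().max().to(torch.float32) / torch.finfo(torch.float8_e4m3fn).max
    q = apply_static_scale(x, scale)
    print((q.to(torch.float16) * scale - x).abs().max())  # rough round-trip error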
@@ -98,25 +106,24 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
     return output
 
 
-class FP8StaticLinearQuantizer(torch.nn.Module):
+# Class responsible for quantizing weights
+class FP8DynamicLinear(torch.nn.Module):
     def __init__(
-        self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
+        self,
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        bias: torch.nn.Parameter,
     ):
         super().__init__()
-        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.input_scale = None
         self.bias = bias
 
     def forward(self, x):
-        qinput, x_input_scale = per_tensor_quantize(x)
-        if self.input_scale is None:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
-        elif x_input_scale > self.input_scale:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
+        qinput, x_scale = per_tensor_quantize(x)
         output = fp8_gemm(
             A=qinput,
-            A_scale=self.input_scale,
+            A_scale=x_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
@@ -125,29 +132,29 @@ def forward(self, x):
         return output
 
 
-class FP8StaticLinear(torch.nn.Module):
+# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
+class FP8StaticLinearQuantizer(torch.nn.Module):
     def __init__(
         self,
-        qweight: torch.Tensor,
+        weight: torch.Tensor,
         weight_scale: torch.Tensor,
-        bias: torch.Tensor,
-        input_scale: float = 1.0,
+        bias: torch.nn.Parameter,
+        quantize_output: bool = False,
     ):
         super().__init__()
-        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
+        self.weight = torch.nn.Parameter(weight, requires_grad=False)
         self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
-        self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
         self.bias = bias
-
-    def per_tensor_quantize(
-        self, tensor: torch.Tensor, inv_scale: float
-    ) -> torch.Tensor:
-        finfo = torch.finfo(torch.float8_e4m3fn)
-        qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
-        return qweight.to(torch.float8_e4m3fn)
+        self.input_scale = None
+        self.output_scale = None
+        self.quantize_output = quantize_output
 
     def forward(self, x):
-        qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
+        qinput, x_input_scale = per_tensor_quantize(x)
+        if self.input_scale is None:
+            self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
+        elif x_input_scale > self.input_scale:
+            self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
@@ -156,26 +163,51 @@ def forward(self, x):
             bias=self.bias,
             out_dtype=x.dtype,
         )
+
+        # Optionally, quantize output and record scale
+        if self.quantize_output:
+            qoutput, output_scale = per_tensor_quantize(output)
+            if self.output_scale is None:
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
+            elif output_scale > self.output_scale:
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
+            output = qoutput.to(output.dtype) * output_scale
+
         return output
 
 
-class FP8DynamicLinear(torch.nn.Module):
-    def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
+# Module responsible for representing the final checkpoint representation
+class FP8StaticLinear(torch.nn.Module):
+    def __init__(
+        self,
+        weight: torch.nn.Parameter,
+        weight_scale: torch.nn.Parameter,
+        bias: torch.nn.Parameter,
+        input_scale: torch.nn.Parameter,
+        output_scale: Optional[torch.nn.Parameter] = None,
+    ):
         super().__init__()
-        self.weight = torch.nn.Parameter(qweight, requires_grad=False)
-        self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
+        self.weight = weight
+        self.weight_scale = weight_scale
         self.bias = bias
+        self.input_scale = input_scale
+        self.output_scale = output_scale
 
     def forward(self, x):
-        qinput, x_scale = per_tensor_quantize(x)
+        qinput = static_per_tensor_quantize(x, self.input_scale)
         output = fp8_gemm(
             A=qinput,
-            A_scale=x_scale,
+            A_scale=self.input_scale,
             B=self.weight,
             B_scale=self.weight_scale,
             bias=self.bias,
             out_dtype=x.dtype,
         )
+
+        if self.output_scale:
+            qoutput = static_per_tensor_quantize(output, self.output_scale)
+            output = qoutput.to(output.dtype) * self.output_scale
+
         return output
 
 
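The observer's calibration rule is a running max over per-tensor scales: the first calibration batch initializes input_scale (and output_scale when quantize_output is set), and later batches replace it only if they produce a larger scale. A standalone sketch of that update rule, detached from the module:

    from typing import Optional
    import torch

    def update_max_scale(current: Optional[torch.Tensor], observed: torch.Tensor) -> torch.Tensor:
        # Mirrors the input_scale / output_scale updates in FP8StaticLinearQuantizer.forward:
        # keep the largest scale seen across calibration batches.
        if current is None or observed > current:
            return observed
        return current

    scale = None
    for batch_amax in (1.5, 3.0, 2.2):  # per-batch max activation magnitudes (made up)
        observed = torch.tensor(batch_amax) / torch.finfo(torch.float8_e4m3fn).max
        scale = update_max_scale(scale, observed)
    print(scale)  # driven by the largest magnitude seen (3.0)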

@@ -194,7 +226,6 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.
 def quantize_weights(
     model: AutoModelForCausalLM,
     quantize_config: BaseQuantizeConfig,
-    ignored_layers: List[str] = [],
 ):
     named_modules = list(model.named_modules())
     for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
@@ -203,9 +234,11 @@
             or name in quantize_config.ignored_layers
         ):
             continue
-        quant_weight, quant_scale = per_tensor_quantize(linear.weight)
+        quant_weight, weight_scale = per_tensor_quantize(linear.weight)
         bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
-        quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
+        quant_linear = FP8DynamicLinear(
+            weight=quant_weight, weight_scale=weight_scale, bias=bias
+        )
         replace_module(model, name, quant_linear)
         del linear.weight
         del linear.bias
@@ -217,7 +250,6 @@ def quantize_activations(
     model: AutoModelForCausalLM,
     quantize_config: BaseQuantizeConfig,
     calibration_tokens,
-    ignored_layers: List[str] = [],
 ):
     # Replace weight quantizer with a dynamic activation quantizer observer
     for name, dynamic_quant_linear in model.named_modules():
@@ -227,9 +259,13 @@
         ):
             continue
         quantizer = FP8StaticLinearQuantizer(
-            dynamic_quant_linear.weight,
-            dynamic_quant_linear.weight_scale,
-            dynamic_quant_linear.bias,
+            weight=dynamic_quant_linear.weight,
+            weight_scale=dynamic_quant_linear.weight_scale,
+            bias=dynamic_quant_linear.bias,
+            quantize_output=(
+                hasattr(quantize_config, "kv_cache_quant_layers")
+                and name in quantize_config.kv_cache_quant_layers
+            ),
         )
         replace_module(model, name, quantizer)
         del dynamic_quant_linear
@@ -251,21 +287,45 @@ def quantize_activations(
         ):
             continue
         static_proj = FP8StaticLinear(
-            quantizer.weight,
-            quantizer.weight_scale,
-            quantizer.bias,
-            quantizer.input_scale,
+            weight=quantizer.weight,
+            weight_scale=quantizer.weight_scale,
+            bias=quantizer.bias,
+            input_scale=quantizer.input_scale,
+            output_scale=quantizer.output_scale,
         )
         replace_module(model, name, static_proj)
         del quantizer
     cleanup_memory()
 
+    # Post-process step for kv cache scales to take the k/v module
+    # `output_scale` parameters, take the max of them, and store them in
+    # the parent attention module as `kv_scale`
+    # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
+    if hasattr(quantize_config, "kv_cache_quant_layers"):
+        # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
+        # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
+        kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)]*2)
+        for k_proj_name, v_proj_name in kv_proj_pairs:
+            parent_module_name = ".".join(k_proj_name.split(".")[:-1])
+            assert parent_module_name == ".".join(v_proj_name.split(".")[:-1])
+            parent_module = dict(model.named_modules())[parent_module_name]
+
+            k_proj = dict(model.named_modules())[k_proj_name]
+            v_proj = dict(model.named_modules())[v_proj_name]
+
+            kv_scale = max(k_proj.output_scale, v_proj.output_scale)
+            parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)
+
+            # Remove output_scale from k_proj and v_proj
+            k_proj.output_scale = None
+            v_proj.output_scale = None
+    cleanup_memory()
+
 
 def save_quantized_model(
     model: AutoModelForCausalLM,
     quant_config: BaseQuantizeConfig,
     save_dir: str,
-    ignored_layers: List[str] = [],
 ):
     print(model)
     print(f"Saving the model to {save_dir}")
@@ -276,6 +336,8 @@
             "ignored_layers": quant_config.ignored_layers,
         }
     }
+    if hasattr(quant_config, "kv_cache_quant_layers"):
+        static_q_dict["quantization_config"]["kv_cache_scheme"] = "static"
     model.config.update(static_q_dict)
     model.save_pretrained(save_dir)
     tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
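
End to end, the calibration flow added by this commit would look roughly like the sketch below. The AutoFP8ForCausalLM entry point, its from_pretrained keyword, and the checkpoint names are assumptions inferred from the methods in modeling.py rather than shown in this diff; save_quantized() additionally records "kv_cache_scheme": "static" in the saved config when KV cache targets were calibrated.

    # Hypothetical end-to-end usage; class name, keyword, and checkpoints are assumptions.
    from transformers import AutoTokenizer
    from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

    pretrained_model_dir = "Qwen/Qwen2-0.5B-Instruct"   # placeholder model
    quantized_model_dir = "Qwen2-0.5B-Instruct-FP8-KV"  # placeholder output dir

    quantize_config = BaseQuantizeConfig(
        quant_method="fp8",
        activation_scheme="static",
        kv_cache_quant_targets=("k_proj", "v_proj"),
    )

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)
    examples = tokenizer(
        ["auto_fp8 is an easy-to-use library for FP8 quantization"],
        return_tensors="pt",
    )

    model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config=quantize_config)
    model.quantize(examples)                  # calibrates input scales and k/v output scales
    model.save_quantized(quantized_model_dir)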
