Skip to content

Commit 3ee9283

Browse files
committed
Support calibrating kv cache scales
1 parent 9474526 commit 3ee9283

File tree

4 files changed

+186
-59
lines changed

4 files changed

+186
-59
lines changed

auto_fp8/config.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Optional, Tuple
22

33

44
class BaseQuantizeConfig:
@@ -17,13 +17,17 @@ class BaseQuantizeConfig:
1717
regex style matching i.e. re.search(), for each Linear layer.
1818
By default, "re:.*lm_head" is included to ignore the embedding
1919
Linear layer usually at the end of decoder LLMs
20+
kv_cache_quant_targets: Tuple of Linear module names to target for
21+
calibration of the output scales for KV cache quantization.
22+
Usually, these should be `("k_proj", "v_proj")`.
2023
"""
2124

2225
def __init__(
2326
self,
2427
quant_method: str = "fp8",
2528
activation_scheme: str = "static",
26-
ignore_patterns: List[str] = [],
29+
ignore_patterns: List[str] = ["re:.*lm_head"],
30+
kv_cache_quant_targets: Optional[Tuple[str]] = None,
2731
):
2832
if quant_method != "fp8":
2933
raise ValueError("Only FP8 quantization is supported.")
@@ -34,4 +38,5 @@ def __init__(
3438
self.quant_method = quant_method
3539
self.activation_scheme = activation_scheme
3640
self.ignore_patterns = ignore_patterns
41+
self.kv_cache_quant_targets = kv_cache_quant_targets
3742
self.ignored_layers = []

auto_fp8/modeling.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import List
2+
from typing import List, Optional, Tuple
33

44
import torch
55
from transformers import AutoModelForCausalLM
@@ -27,6 +27,16 @@ def __init__(
2727
self.model, quantize_config.ignore_patterns
2828
)
2929

30+
if quantize_config.kv_cache_quant_targets:
31+
kv_cache_quant_layers = get_kv_cache_quant_layer(
32+
self.model, quantize_config.kv_cache_quant_targets
33+
)
34+
if len(kv_cache_quant_layers) == 0:
35+
raise ValueError(
36+
f"Could not find any kv cache layers using kv_cache_quant_targets={quantize_config.kv_cache_quant_targets}, please fix your argument."
37+
)
38+
quantize_config.kv_cache_quant_layers = kv_cache_quant_layers
39+
3040
self.quantize_config = quantize_config
3141

3242
@classmethod
@@ -97,7 +107,7 @@ def skip(*args, **kwargs):
97107

98108
return cls(model, quantize_config)
99109

100-
def quantize(self, calibration_tokens):
110+
def quantize(self, calibration_tokens: Optional[torch.Tensor] = None):
101111
def _prepare_calibration_data(calibration_tokens):
102112
if hasattr(calibration_tokens, "input_ids"):
103113
return calibration_tokens.input_ids
@@ -107,6 +117,9 @@ def _prepare_calibration_data(calibration_tokens):
107117
quantize_weights(self.model, self.quantize_config)
108118

109119
if self.quantize_config.activation_scheme == "static":
120+
assert (
121+
calibration_tokens is not None
122+
), "Calibration tokens required for activation quantization"
110123
quantize_activations(
111124
self.model,
112125
self.quantize_config,
@@ -128,9 +141,6 @@ def save_quantized(self, save_dir):
128141
def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
129142
ignored_layers = set()
130143

131-
# TODO: don't always ignore lm_head
132-
ignore_patterns.append("re:.*lm_head")
133-
134144
for name, linear in model.named_modules():
135145
if not isinstance(linear, torch.nn.Linear):
136146
continue
@@ -148,3 +158,17 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
148158
ignored_layers.add(name)
149159

150160
return list(ignored_layers)
161+
162+
163+
def get_kv_cache_quant_layer(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
164+
kv_cache_quant_layers = set()
165+
166+
for name, linear in model.named_modules():
167+
if not isinstance(linear, torch.nn.Linear):
168+
continue
169+
170+
for output_quant_target in kv_cache_quant_targets:
171+
if name.endswith(output_quant_target):
172+
kv_cache_quant_layers.add(name)
173+
174+
return list(kv_cache_quant_layers)

auto_fp8/quantize.py

Lines changed: 88 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import gc
22
import re
3-
from typing import List, Tuple
3+
from typing import Optional, Tuple
44
import copy
55

66
import torch
@@ -61,13 +61,21 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
6161
return qweight, scale
6262

6363

64+
def static_per_tensor_quantize(tensor: torch.Tensor, inv_scale: float) -> torch.Tensor:
65+
finfo = torch.finfo(torch.float8_e4m3fn)
66+
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
67+
return qweight.to(torch.float8_e4m3fn)
68+
69+
6470
def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
6571
if A.numel() == 0:
6672
# Deal with empty tensors (triggered by empty MoE experts)
6773
return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
68-
74+
6975
native_fp8_support = (
70-
torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
76+
torch.cuda.is_available()
77+
and torch.cuda.get_device_capability() >= (8, 9)
78+
and False
7179
)
7280
if native_fp8_support:
7381
need_reshape = A.dim() == 3
@@ -98,84 +106,108 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):
98106
return output
99107

100108

101-
class FP8StaticLinearQuantizer(torch.nn.Module):
109+
# Class responsible for quantizing weights
110+
class FP8DynamicLinear(torch.nn.Module):
102111
def __init__(
103-
self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
112+
self,
113+
qweight: torch.Tensor,
114+
weight_scale: torch.Tensor,
115+
bias: torch.nn.Parameter,
104116
):
105117
super().__init__()
106-
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
118+
self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
107119
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
108-
self.input_scale = None
109120
self.bias = bias
110121

111122
def forward(self, x):
112-
qinput, x_input_scale = per_tensor_quantize(x)
113-
if self.input_scale is None:
114-
self.input_scale = torch.nn.Parameter(x_input_scale)
115-
elif x_input_scale > self.input_scale:
116-
self.input_scale = torch.nn.Parameter(x_input_scale)
123+
qinput, x_scale = per_tensor_quantize(x)
117124
output = fp8_gemm(
118125
A=qinput,
119-
A_scale=self.input_scale,
120-
B=self.weight,
126+
A_scale=x_scale,
127+
B=self.qweight,
121128
B_scale=self.weight_scale,
122129
bias=self.bias,
123130
out_dtype=x.dtype,
124131
)
125132
return output
126133

127134

128-
class FP8StaticLinear(torch.nn.Module):
135+
# Module responsible for taking already quantized weights, and recording input scales (and possibly output scales) using an activation observer
136+
class FP8StaticLinearQuantizer(torch.nn.Module):
129137
def __init__(
130138
self,
131139
qweight: torch.Tensor,
132140
weight_scale: torch.Tensor,
133-
bias: torch.Tensor,
134-
input_scale: float = 1.0,
141+
bias: torch.nn.Parameter,
142+
quantize_output: bool = False,
135143
):
136144
super().__init__()
137-
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
145+
self.qweight = torch.nn.Parameter(qweight, requires_grad=False)
138146
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
139-
self.input_scale = torch.nn.Parameter(input_scale, requires_grad=False)
140147
self.bias = bias
141-
142-
def per_tensor_quantize(
143-
self, tensor: torch.Tensor, inv_scale: float
144-
) -> torch.Tensor:
145-
finfo = torch.finfo(torch.float8_e4m3fn)
146-
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
147-
return qweight.to(torch.float8_e4m3fn)
148+
self.input_scale = None
149+
self.output_scale = None
150+
self.quantize_output = quantize_output
148151

149152
def forward(self, x):
150-
qinput = self.per_tensor_quantize(x, inv_scale=self.input_scale)
153+
qinput, x_input_scale = per_tensor_quantize(x)
154+
if self.input_scale is None:
155+
self.input_scale = torch.nn.Parameter(x_input_scale)
156+
elif x_input_scale > self.input_scale:
157+
self.input_scale = torch.nn.Parameter(x_input_scale)
151158
output = fp8_gemm(
152159
A=qinput,
153160
A_scale=self.input_scale,
154-
B=self.weight,
161+
B=self.qweight,
155162
B_scale=self.weight_scale,
156163
bias=self.bias,
157164
out_dtype=x.dtype,
158165
)
166+
167+
# Optionally, quantize output and record scale
168+
if self.quantize_output:
169+
qoutput, output_scale = per_tensor_quantize(output)
170+
if self.output_scale is None:
171+
self.output_scale = torch.nn.Parameter(output_scale)
172+
elif output_scale > self.output_scale:
173+
self.output_scale = torch.nn.Parameter(output_scale)
174+
output = qoutput.to(output.dtype) * output_scale
175+
159176
return output
160177

161178

162-
class FP8DynamicLinear(torch.nn.Module):
163-
def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
179+
# Module responsible for representing the final checkpoint representation
180+
class FP8StaticLinear(torch.nn.Module):
181+
def __init__(
182+
self,
183+
qweight: torch.nn.Parameter,
184+
weight_scale: torch.nn.Parameter,
185+
bias: torch.nn.Parameter,
186+
input_scale: torch.nn.Parameter,
187+
output_scale: Optional[torch.nn.Parameter] = None,
188+
):
164189
super().__init__()
165-
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
166-
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
190+
self.qweight = qweight
191+
self.weight_scale = weight_scale
167192
self.bias = bias
193+
self.input_scale = input_scale
194+
self.output_scale = output_scale
168195

169196
def forward(self, x):
170-
qinput, x_scale = per_tensor_quantize(x)
197+
qinput = static_per_tensor_quantize(x, self.input_scale)
171198
output = fp8_gemm(
172199
A=qinput,
173-
A_scale=x_scale,
174-
B=self.weight,
200+
A_scale=self.input_scale,
201+
B=self.qweight,
175202
B_scale=self.weight_scale,
176203
bias=self.bias,
177204
out_dtype=x.dtype,
178205
)
206+
207+
if self.output_scale:
208+
qoutput = static_per_tensor_quantize(output, self.output_scale)
209+
output = qoutput.to(output.dtype) * self.output_scale
210+
179211
return output
180212

181213

@@ -194,7 +226,6 @@ def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.
194226
def quantize_weights(
195227
model: AutoModelForCausalLM,
196228
quantize_config: BaseQuantizeConfig,
197-
ignored_layers: List[str] = [],
198229
):
199230
named_modules = list(model.named_modules())
200231
for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
@@ -203,9 +234,11 @@ def quantize_weights(
203234
or name in quantize_config.ignored_layers
204235
):
205236
continue
206-
quant_weight, quant_scale = per_tensor_quantize(linear.weight)
237+
quant_weight, weight_scale = per_tensor_quantize(linear.weight)
207238
bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
208-
quant_linear = FP8DynamicLinear(quant_weight, quant_scale, bias)
239+
quant_linear = FP8DynamicLinear(
240+
qweight=quant_weight, weight_scale=weight_scale, bias=bias
241+
)
209242
replace_module(model, name, quant_linear)
210243
del linear.weight
211244
del linear.bias
@@ -217,7 +250,6 @@ def quantize_activations(
217250
model: AutoModelForCausalLM,
218251
quantize_config: BaseQuantizeConfig,
219252
calibration_tokens,
220-
ignored_layers: List[str] = [],
221253
):
222254
# Replace weight quantizer with a dynamic activation quantizer observer
223255
for name, dynamic_quant_linear in model.named_modules():
@@ -227,16 +259,22 @@ def quantize_activations(
227259
):
228260
continue
229261
quantizer = FP8StaticLinearQuantizer(
230-
dynamic_quant_linear.weight,
231-
dynamic_quant_linear.weight_scale,
232-
dynamic_quant_linear.bias,
262+
qweight=dynamic_quant_linear.qweight,
263+
weight_scale=dynamic_quant_linear.weight_scale,
264+
bias=dynamic_quant_linear.bias,
265+
quantize_output=(
266+
hasattr(quantize_config, "kv_cache_quant_layers")
267+
and name in quantize_config.kv_cache_quant_layers
268+
),
233269
)
234270
replace_module(model, name, quantizer)
235271
del dynamic_quant_linear
236272
cleanup_memory()
237273

238274
# Pass through calibration data to measure activation scales
239-
with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
275+
with tqdm.tqdm(
276+
total=calibration_tokens.shape[0], desc="Calibrating activation scales"
277+
) as pbar:
240278
for row_idx in range(calibration_tokens.shape[0]):
241279
model(calibration_tokens[row_idx].reshape(1, -1))
242280
cleanup_memory()
@@ -250,10 +288,11 @@ def quantize_activations(
250288
):
251289
continue
252290
static_proj = FP8StaticLinear(
253-
quantizer.weight,
254-
quantizer.weight_scale,
255-
quantizer.bias,
256-
quantizer.input_scale,
291+
qweight=quantizer.qweight,
292+
weight_scale=quantizer.weight_scale,
293+
bias=quantizer.bias,
294+
input_scale=quantizer.input_scale,
295+
output_scale=quantizer.output_scale,
257296
)
258297
replace_module(model, name, static_proj)
259298
del quantizer
@@ -264,7 +303,6 @@ def save_quantized_model(
264303
model: AutoModelForCausalLM,
265304
quant_config: BaseQuantizeConfig,
266305
save_dir: str,
267-
ignored_layers: List[str] = [],
268306
):
269307
print(model)
270308
print(f"Saving the model to {save_dir}")
@@ -275,6 +313,8 @@ def save_quantized_model(
275313
"ignored_layers": quant_config.ignored_layers,
276314
}
277315
}
316+
if hasattr(quant_config, "kv_cache_quant_layers"):
317+
static_q_dict["quantization_config"]["kv_cache_scheme"] = "static"
278318
model.config.update(static_q_dict)
279319
model.save_pretrained(save_dir)
280320
tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)

0 commit comments

Comments
 (0)