
Commit c3acdee

Switch from output_scale to kv_scale
1 parent 57c31bb commit c3acdee

File tree: 3 files changed, +40 -16 lines


auto_fp8/modeling.py

Lines changed: 5 additions & 5 deletions
@@ -28,7 +28,7 @@ def __init__(
         )
 
         if quantize_config.kv_cache_quant_targets:
-            kv_cache_quant_layers = get_kv_cache_quant_layer(
+            kv_cache_quant_layers = get_kv_cache_quant_layers(
                 self.model, quantize_config.kv_cache_quant_targets
             )
             if len(kv_cache_quant_layers) == 0:
@@ -159,15 +159,15 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
     return list(ignored_layers)
 
 
-def get_kv_cache_quant_layer(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
-    kv_cache_quant_layers = set()
+def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
+    kv_cache_quant_layers = []
 
     for name, linear in model.named_modules():
         if not isinstance(linear, torch.nn.Linear):
            continue
 
        for output_quant_target in kv_cache_quant_targets:
            if name.endswith(output_quant_target):
-                kv_cache_quant_layers.add(name)
+                kv_cache_quant_layers.append(name)
 
-    return list(kv_cache_quant_layers)
+    return kv_cache_quant_layers
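
Why the return type changed from set to list: named_modules() walks submodules in registration order, so collecting matches into a list keeps each layer's k_proj immediately followed by its v_proj, which the kv_scale post-processing in quantize.py relies on; a set would lose that ordering. A minimal sketch with a made-up toy model (the Attention/ToyModel classes below are illustrative, not from this repo):

import torch

class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registration order matters: named_modules() yields these in this order
        self.q_proj = torch.nn.Linear(8, 8)
        self.k_proj = torch.nn.Linear(8, 8)
        self.v_proj = torch.nn.Linear(8, 8)

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList([Attention() for _ in range(2)])

model = ToyModel()
targets = ("k_proj", "v_proj")

# Same filtering idea as get_kv_cache_quant_layers: keep Linear modules whose
# qualified name ends with a target, preserving traversal order
ordered = [
    name
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear) and name.endswith(targets)
]
print(ordered)
# ['layers.0.k_proj', 'layers.0.v_proj', 'layers.1.k_proj', 'layers.1.v_proj']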

auto_fp8/quantize.py

Lines changed: 28 additions & 4 deletions
@@ -152,9 +152,9 @@ def __init__(
     def forward(self, x):
         qinput, x_input_scale = per_tensor_quantize(x)
         if self.input_scale is None:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
+            self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
         elif x_input_scale > self.input_scale:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
+            self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
@@ -168,9 +168,9 @@ def forward(self, x):
         if self.quantize_output:
             qoutput, output_scale = per_tensor_quantize(output)
             if self.output_scale is None:
-                self.output_scale = torch.nn.Parameter(output_scale)
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
             elif output_scale > self.output_scale:
-                self.output_scale = torch.nn.Parameter(output_scale)
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
             output = qoutput.to(output.dtype) * output_scale
 
         return output
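
A likely motivation for the requires_grad=False additions: these scales are calibration statistics, not trainable weights, so wrapping them as frozen Parameters keeps them in the module's state_dict (and therefore in the exported checkpoint) without asking autograd to track them. A small sketch of the same running-max pattern, using a stand-in scale formula rather than the repo's per_tensor_quantize helper:

import torch

class ScaleObserver(torch.nn.Module):
    # Toy running-max observer mirroring the forward() pattern above
    def __init__(self):
        super().__init__()
        self.input_scale = None

    def forward(self, x):
        # Stand-in scale: largest magnitude seen in this batch (illustrative only)
        observed = x.abs().max()
        if self.input_scale is None or observed > self.input_scale:
            # Frozen Parameter: saved via state_dict, ignored by autograd
            self.input_scale = torch.nn.Parameter(observed, requires_grad=False)
        return x

obs = ScaleObserver()
obs(torch.randn(4, 8))
print(obs.input_scale.requires_grad)      # False
print("input_scale" in obs.state_dict())  # True
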
@@ -307,6 +307,30 @@ def quantize_activations(
         del quantizer
     cleanup_memory()
 
+    # Post-process step for kv cache scales to take the k/v module
+    # `output_scale` parameters, take the max of them, and store them in
+    # the parent attention module as `kv_scale`
+    # NOTE: if we want to switch to the `output_scale` representation, we can simply remove this block
+    if hasattr(quantize_config, "kv_cache_quant_layers"):
+        # Assumes that list is ordered such that [layer0.k_proj, layer0.v_proj, layer1.k_proj, layer1.v_proj, ...]
+        # so we make a list of tuples [(layer0.k_proj, layer0.v_proj), (layer1.k_proj, layer1.v_proj), ...]
+        kv_proj_pairs = zip(*[iter(quantize_config.kv_cache_quant_layers)]*2)
+        for k_proj_name, v_proj_name in kv_proj_pairs:
+            parent_module_name = ".".join(k_proj_name.split(".")[:-1])
+            assert parent_module_name == ".".join(v_proj_name.split(".")[:-1])
+            parent_module = dict(model.named_modules())[parent_module_name]
+
+            k_proj = dict(model.named_modules())[k_proj_name]
+            v_proj = dict(model.named_modules())[v_proj_name]
+
+            kv_scale = max(k_proj.output_scale, v_proj.output_scale)
+            parent_module.kv_scale = torch.nn.Parameter(kv_scale, requires_grad=False)
+
+            # Remove output_scale from k_proj and v_proj
+            k_proj.output_scale = None
+            v_proj.output_scale = None
+    cleanup_memory()
+
 
 def save_quantized_model(
     model: AutoModelForCausalLM,
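
The pairing idiom in the new post-processing block is the only subtle part: zip(*[iter(seq)]*2) advances a single iterator twice per step, turning the flat ordered layer list into consecutive (k_proj, v_proj) tuples. A standalone sketch of that idiom and the max-of-scales fold, using made-up layer names and scale values rather than real modules:

kv_cache_quant_layers = [
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.1.self_attn.k_proj",
    "model.layers.1.self_attn.v_proj",
]
# Illustrative per-module output scales observed during calibration
output_scales = {
    "model.layers.0.self_attn.k_proj": 0.021,
    "model.layers.0.self_attn.v_proj": 0.034,
    "model.layers.1.self_attn.k_proj": 0.018,
    "model.layers.1.self_attn.v_proj": 0.012,
}

# zip(*[iter(seq)]*2) consumes the same iterator twice per step,
# yielding consecutive (k_proj, v_proj) pairs from the flat list
for k_name, v_name in zip(*[iter(kv_cache_quant_layers)]*2):
    parent = k_name.rsplit(".", 1)[0]
    kv_scale = max(output_scales[k_name], output_scales[v_name])
    print(f"{parent}.kv_scale = {kv_scale}")
# model.layers.0.self_attn.kv_scale = 0.034
# model.layers.1.self_attn.kv_scale = 0.018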

tests/test_auto_fp8.py

Lines changed: 7 additions & 7 deletions
@@ -30,7 +30,7 @@ def test_dynamic_quantization(model_id, target_size):
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
-    # We expect the model to be a certain size
+    # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
 
@@ -55,7 +55,7 @@ def test_static_quantization(model_id, target_size):
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
-    # We expect the model to be < 160MB
+    # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
 
@@ -81,18 +81,18 @@ def test_kv_cache_static_quantization(model_id, target_size):
 
     tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors")
     proj_linear_count = 0
-    output_scale_count = 0
+    kv_scale_count = 0
     for name, _ in tensors.items():
         if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
             proj_linear_count += 1
-        if name.endswith("k_proj.output_scale") or name.endswith("v_proj.output_scale"):
-            output_scale_count += 1
-    assert proj_linear_count == output_scale_count
+        if name.endswith("kv_scale"):
+            kv_scale_count += 1
+    assert proj_linear_count // 2 == kv_scale_count
 
     # Measure checkpoint size and cleanup
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
-    # We expect the model to be < 160MB
+    # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
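
The updated assertion encodes the new checkpoint layout: every attention block still contributes two quantized projection weights (k_proj.weight and v_proj.weight) but now only one shared kv_scale on the parent module, hence the expected 2:1 ratio. The same counting over a hypothetical tensor-name list:

# Hypothetical tensor names for one decoder layer of a kv-cache-quantized checkpoint
tensor_names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.0.self_attn.kv_scale",
]

proj_linear_count = sum(
    name.endswith(("k_proj.weight", "v_proj.weight")) for name in tensor_names
)
kv_scale_count = sum(name.endswith("kv_scale") for name in tensor_names)

# Two projections share one kv_scale, so the ratio is 2:1
assert proj_linear_count // 2 == kv_scale_count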
