Commit 0249168

Switch from output_scale to kv_scale
1 parent e6c2225 commit 0249168

3 files changed: +31 -9 lines changed

auto_fp8/modeling.py

Lines changed: 15 additions & 0 deletions
@@ -28,11 +28,15 @@ def __init__(
         )
 
         if quantize_config.kv_cache_quant_targets:
+<<<<<<< HEAD
 <<<<<<< HEAD
             kv_cache_quant_layers = get_kv_cache_quant_layers(
 =======
             kv_cache_quant_layers = get_kv_cache_quant_layer(
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+            kv_cache_quant_layers = get_kv_cache_quant_layers(
+>>>>>>> c3acdee (Switch from output_scale to kv_scale)
                 self.model, quantize_config.kv_cache_quant_targets
             )
             if len(kv_cache_quant_layers) == 0:
@@ -182,20 +186,26 @@ def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
     return list(ignored_layers)
 
 
+<<<<<<< HEAD
 <<<<<<< HEAD
 def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
     kv_cache_quant_layers = []
 =======
 def get_kv_cache_quant_layer(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
     kv_cache_quant_layers = set()
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
+    kv_cache_quant_layers = []
+>>>>>>> c3acdee (Switch from output_scale to kv_scale)
 
     for name, linear in model.named_modules():
         if not isinstance(linear, torch.nn.Linear):
             continue
 
         for output_quant_target in kv_cache_quant_targets:
             if name.endswith(output_quant_target):
+<<<<<<< HEAD
 <<<<<<< HEAD
                 kv_cache_quant_layers.append(name)
 
@@ -205,3 +215,8 @@ def get_kv_cache_quant_layer(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
 
     return list(kv_cache_quant_layers)
 >>>>>>> 3ee9283 (Support calibrating kv cache scales)
+=======
+                kv_cache_quant_layers.append(name)
+
+    return kv_cache_quant_layers
+>>>>>>> c3acdee (Switch from output_scale to kv_scale)
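
Resolved in favor of this commit's side (the c3acdee lines above), the helper reads as below. This is a sketch for readability only, with the imports it needs added; the committed file itself still carries the unresolved markers. The helper walks every torch.nn.Linear in the model and collects, in traversal order, the module names ending in one of the kv_cache_quant_targets suffixes, returning a plain list rather than the earlier set.

from typing import List, Tuple

import torch


def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List[str]:
    # This commit returns a list (ordered) instead of a set.
    kv_cache_quant_layers = []

    for name, linear in model.named_modules():
        # Only Linear projections (e.g. k_proj / v_proj) are candidates.
        if not isinstance(linear, torch.nn.Linear):
            continue

        for output_quant_target in kv_cache_quant_targets:
            if name.endswith(output_quant_target):
                kv_cache_quant_layers.append(name)

    return kv_cache_quant_layers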

auto_fp8/quantize.py

Lines changed: 4 additions & 4 deletions
@@ -185,9 +185,9 @@ def __init__(
     def forward(self, x):
         qinput, x_input_scale = per_tensor_quantize(x)
         if self.input_scale is None:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
+            self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
         elif x_input_scale > self.input_scale:
-            self.input_scale = torch.nn.Parameter(x_input_scale)
+            self.input_scale = torch.nn.Parameter(x_input_scale, requires_grad=False)
         output = fp8_gemm(
             A=qinput,
             A_scale=self.input_scale,
@@ -201,9 +201,9 @@ def forward(self, x):
         if self.quantize_output:
             qoutput, output_scale = per_tensor_quantize(output)
             if self.output_scale is None:
-                self.output_scale = torch.nn.Parameter(output_scale)
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
             elif output_scale > self.output_scale:
-                self.output_scale = torch.nn.Parameter(output_scale)
+                self.output_scale = torch.nn.Parameter(output_scale, requires_grad=False)
             output = qoutput.to(output.dtype) * output_scale
 
         return output
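
The only functional change in this file is that the input and output scales are rewrapped with requires_grad=False. Below is a minimal, self-contained sketch of the same running-max pattern; RunningMaxScale and its observe method are illustrative stand-ins, not the repo's FP8 linear module, and the constant assumes a PyTorch build that provides float8_e4m3fn. Marking the scale as a non-trainable Parameter keeps it out of autograd while still registering it on the module, so it is saved in state_dict() and follows .to()/.cuda().

import torch


class RunningMaxScale(torch.nn.Module):
    """Illustrative calibration observer that tracks a per-tensor max scale."""

    def __init__(self):
        super().__init__()
        self.input_scale = None

    def observe(self, x: torch.Tensor):
        # Stand-in for per_tensor_quantize(): scale = amax / FP8-E4M3 max.
        scale = x.abs().max() / torch.finfo(torch.float8_e4m3fn).max
        if self.input_scale is None or scale > self.input_scale:
            # requires_grad=False: the scale is a calibration statistic, not a
            # trainable weight, but it still lands in the module's state_dict.
            self.input_scale = torch.nn.Parameter(scale, requires_grad=False)


obs = RunningMaxScale()
obs.observe(torch.randn(4, 8))
assert obs.input_scale.requires_grad is False
assert "input_scale" in obs.state_dict()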

tests/test_auto_fp8.py

Lines changed: 12 additions & 5 deletions
@@ -64,6 +64,9 @@ def test_dynamic_quantization(model_id, target_size):
 
 <<<<<<< HEAD
 <<<<<<< HEAD
+<<<<<<< HEAD
+=======
+>>>>>>> c3acdee (Switch from output_scale to kv_scale)
     # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
@@ -114,6 +117,7 @@ def test_static_quantization(model_id, target_size):
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
+<<<<<<< HEAD
 <<<<<<< HEAD
     # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
@@ -157,6 +161,9 @@ def test_kv_cache_static_quantization(model_id, target_size):
 =======
     # We expect the model to be < 160MB
 >>>>>>> 415c0b7 (Add fixed target sizes)
+=======
+    # We expect the quantized model to be a certain size
+>>>>>>> c3acdee (Switch from output_scale to kv_scale)
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
 
@@ -182,18 +189,18 @@ def test_kv_cache_static_quantization(model_id, target_size):
 
     tensors = safetensors.torch.load_file(f"{quantized_model_dir}/model.safetensors")
     proj_linear_count = 0
-    output_scale_count = 0
+    kv_scale_count = 0
     for name, _ in tensors.items():
         if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
             proj_linear_count += 1
-        if name.endswith("k_proj.output_scale") or name.endswith("v_proj.output_scale"):
-            output_scale_count += 1
-    assert proj_linear_count == output_scale_count
+        if name.endswith("kv_scale"):
+            kv_scale_count += 1
+    assert proj_linear_count // 2 == kv_scale_count
 
     # Measure checkpoint size and cleanup
     model_size = os.path.getsize(f"{quantized_model_dir}/model.safetensors")
     shutil.rmtree(quantized_model_dir)
 
-    # We expect the model to be < 160MB
+    # We expect the quantized model to be a certain size
     target_size = target_size * (1024 * 1024)
     assert model_size < target_size
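
The rewritten check reflects the new checkpoint layout: each attention block still stores one k_proj.weight and one v_proj.weight, but instead of a separate output_scale per projection there is a single shared kv_scale, hence the // 2 in the assertion. A self-contained sketch of that counting logic over made-up tensor names (the names below are illustrative, not read from a real checkpoint):

# Hypothetical safetensors keys for a two-layer model with a shared kv_scale.
tensor_names = [
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.0.self_attn.kv_scale",
    "model.layers.1.self_attn.k_proj.weight",
    "model.layers.1.self_attn.v_proj.weight",
    "model.layers.1.self_attn.kv_scale",
]

proj_linear_count = 0
kv_scale_count = 0
for name in tensor_names:
    if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
        proj_linear_count += 1
    if name.endswith("kv_scale"):
        kv_scale_count += 1

# k_proj and v_proj share one kv_scale per attention layer.
assert proj_linear_count // 2 == kv_scale_count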
