Skip to content

Commit 0435986

Browse files
authored
Support quantizing only kv cache (#135)
1 parent ef77dca commit 0435986

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

src/compressed_tensors/compressors/model_compressor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ def compress(
271271
v_proj_has_quant_output = 0
272272
for name, module in model.named_modules():
273273
if not hasattr(module, "quantization_scheme"):
274+
# We still want to count non-quantized q_proj
275+
if name.endswith(".q_proj"):
276+
q_proj_has_no_quant_output += 1
274277
continue
275278
out_act = module.quantization_scheme.output_activations
276279
if name.endswith(".q_proj") and out_act is None:

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def is_preset_scheme(name: str) -> bool:
110110
"""
111111
return name.upper() in PRESET_SCHEMES
112112

113+
UNQUANTIZED = dict()
113114

114115
# 8 bit integer weights and 8 bit activations quantization
115116
W8A8 = dict(
@@ -208,6 +209,8 @@ def is_preset_scheme(name: str) -> bool:
208209
)
209210

210211
PRESET_SCHEMES = {
212+
# Unquantized (no-op)
213+
"UNQUANTIZED": UNQUANTIZED,
211214
# Integer weight only schemes
212215
"W8A16": W8A16,
213216
"W4A16": W4A16,

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def calculate_compression_ratio(model: Module) -> float:
181181
for parameter in model.parameters():
182182
uncompressed_bits = get_torch_bit_depth(parameter)
183183
compressed_bits = uncompressed_bits
184-
if is_module_quantized(submodule):
184+
if is_module_quantized(submodule) and submodule.quantization_scheme.weights:
185185
compressed_bits = submodule.quantization_scheme.weights.num_bits
186186

187187
num_weights = parameter.numel()

0 commit comments

Comments
 (0)