Skip to content

Commit b885229

Browse files
authored (author name missing from page extraction)
Skip writing empty g_idx to disk, fix compress_quantized_weights (#143)
1 parent d8a717c commit b885229

File tree

2 files changed: +4 −2 lines changed

src/compressed_tensors/compressors/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def compress(
125125
else:
126126
compressed_dict[name] = value.to("cpu")
127127
elif name.endswith("zero_point") and torch.all(value == 0):
128-
# all zero_points are 0, no need to include in
129-
# compressed state_dict
128+
continue
129+
elif name.endswith("g_idx") and torch.any(value <= -1):
130130
continue
131131
else:
132132
compressed_dict[name] = value.to("cpu")

src/compressed_tensors/quantization/lifecycle/compressed.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def compress_quantized_weights(module: Module):
4949
weight = getattr(module, "weight", None)
5050
scale = getattr(module, "weight_scale", None)
5151
zero_point = getattr(module, "weight_zero_point", None)
52+
g_idx = getattr(module, "weight_g_idx", None)
5253

5354
if weight is None or scale is None:
5455
# no weight, scale, or ZP, nothing to do
@@ -62,6 +63,7 @@ def compress_quantized_weights(module: Module):
6263
x=weight,
6364
scale=scale,
6465
zero_point=zero_point,
66+
g_idx=g_idx,
6567
args=scheme.weights,
6668
dtype=torch.int8,
6769
)

Commit comments (0)