Commit 9170fb3

[Decompression] Clean-up and some fixes (#461)
* clean-up
* update
* update
* update
* remove import
* Update initialize.py
1 parent 0b686cd commit 9170fb3

3 files changed: +22 -32 lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 14 additions & 9 deletions

@@ -224,7 +224,8 @@ def parse_sparsity_config(
             s_config = compression_config.sparsity_config
             return s_config.model_dump() if s_config is not None else None
 
-        return compression_config.get(SPARSITY_CONFIG_NAME, None)
+        # explicitly return None if {} in config
+        return compression_config.get(SPARSITY_CONFIG_NAME, None) or None
 
     @staticmethod
     def parse_quantization_config(
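
The `or None` guard matters because a serialized config can carry an empty dict under the sparsity key; a plain dict lookup would hand that `{}` on to callers that only check `is not None`. A minimal standalone sketch of the normalization, with the key name written out in place of SPARSITY_CONFIG_NAME:

    compression_config = {"sparsity_config": {}}

    plain = compression_config.get("sparsity_config", None)               # -> {} (falsy, but not None)
    normalized = compression_config.get("sparsity_config", None) or None  # -> None

    assert plain == {} and normalized is None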
@@ -712,17 +713,16 @@ def decompress(self, model_path: str, model: Module):
             # Load activation scales/zp or any other quantization parameters
             # Conditionally load the weight quantization parameters if we have a
             # dense compressor or if a sparsity compressor has already been applied
+            load_weight_qparams = sparse_decompressed or isinstance(
+                quant_compressor, DenseCompressor
+            )
             load_pretrained_quantization_parameters(
                 model,
                 model_path,
                 # TODO: all weight quantization params will be moved to the
                 # compressor in a follow-up including initialization
-                load_weight_quantization=(
-                    sparse_decompressed
-                    or isinstance(quant_compressor, DenseCompressor)
-                ),
+                load_weight_qparams=load_weight_qparams,
             )
-
             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
             )
@@ -732,7 +732,9 @@ def decompress(self, model_path: str, model: Module):
             )
             # TODO: all weight quantization params will be moved to the compressor
             # to prevent duplicate parameter updates in update_parameter_data
-            self._replace_weights(dense_gen, model)
+            self._replace_weights(
+                dense_gen, model, load_weight_qparams=not load_weight_qparams
+            )
 
         def freeze_quantization_status(module):
             module.quantization_status = QuantizationStatus.FROZEN
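
Taken together, the two hunks above make the weight-qparam loading paths mutually exclusive: either load_pretrained_quantization_parameters writes the weight scales/zero-points up front (dense compressor, or sparsity already decompressed), or _replace_weights writes them later from the decompressed generator, never both. A tiny standalone truth-table check of that invariant, with plain booleans standing in for the real sparse_decompressed flag and the DenseCompressor isinstance test:

    for sparse_decompressed in (False, True):
        for is_dense_compressor in (False, True):
            load_weight_qparams = sparse_decompressed or is_dense_compressor
            # load_pretrained_quantization_parameters writes weight qparams iff the flag is True;
            # _replace_weights(..., load_weight_qparams=not load_weight_qparams) writes iff it is False.
            writers = int(load_weight_qparams) + int(not load_weight_qparams)
            assert writers == 1  # exactly one path updates the weight quantization parameters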
@@ -819,7 +821,9 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
                 param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
                 register_offload_parameter(module, param_name, param)
 
-    def _replace_weights(self, dense_weight_generator, model: Module):
+    def _replace_weights(
+        self, dense_weight_generator, model: Module, load_weight_qparams: bool = True
+    ):
         """
         Replace the weights of the model with the
         provided dense weights.
@@ -847,6 +851,7 @@ def _replace_weights(self, dense_weight_generator, model: Module):
                 # decompression in init to be consistent with loading which happens
                 # later as well however, update_data does a good shape check -
                 # should be moved to the compressor
+
                 if param_name == "weight":
                     delattr(module, param_name)
                     requires_grad = param_data.dtype in (
@@ -858,7 +863,7 @@ def _replace_weights(self, dense_weight_generator, model: Module):
                         param_data.to(device), requires_grad=requires_grad
                     )
                     register_offload_parameter(module, param_name, param)
-                else:
+                elif load_weight_qparams:
                     # Should already be registered to the correct device for
                     # for scales/zero-points
                     update_parameter_data(module, param_data, param_name)

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 5 additions & 14 deletions

@@ -65,19 +65,19 @@
 def load_pretrained_quantization_parameters(
     model: Module,
     model_name_or_path: Optional[str] = None,
-    load_weight_quantization: Optional[bool] = False,
+    load_weight_qparams: Optional[bool] = False,
 ):
     """
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
     a model that has already been initialized with a quantization config.
 
     NOTE: Will always load inputs/output parameters. Will conditioanlly load weight
-    parameters, if load_weight_quantization is set to True.
+    parameters, if load_weight_qparams is set to True.
 
     :param model: model to load pretrained quantization parameters to
     :param model_name_or_path: Hugging Face stub or local folder containing a quantized
         model, which is used to load quantization parameters
-    :param load_weight_quantization: whether or not the weight quantization parameters
+    :param load_weight_qparams: whether or not the weight quantization parameters
         should be loaded
     """
     model_path = get_safetensors_folder(model_name_or_path)
@@ -103,7 +103,7 @@ def load_pretrained_quantization_parameters(
             mapping=mapping,
         )
 
-        if load_weight_quantization and submodule.quantization_scheme.weights:
+        if load_weight_qparams and submodule.quantization_scheme.weights:
             base_name = "weight"
             _load_quant_args_from_mapping(
                 base_name=base_name,
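
For callers outside this repository the rename is a small API break: the keyword is now load_weight_qparams, and passing the old load_weight_quantization keyword would raise a TypeError. A hedged usage sketch, where the model object and checkpoint path are placeholders and the import path simply follows the file location above:

    from compressed_tensors.quantization.lifecycle.apply import (
        load_pretrained_quantization_parameters,
    )

    load_pretrained_quantization_parameters(
        model,                        # an already-initialized quantized nn.Module (placeholder)
        "path/to/quantized-model",    # Hugging Face stub or local safetensors folder (placeholder)
        load_weight_qparams=True,     # renamed from load_weight_quantization
    )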
@@ -219,18 +219,9 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     if status >= QuantizationStatus.INITIALIZED > current_status:
         force_zero_point_init = status != QuantizationStatus.COMPRESSED
 
-        # When decompressing, we set the scale_dtype as the model's dtype
-        # This is because the normal workflow of using the weight's dtype
-        # will be incorrect as the model weight will be compressed
-        # Therfore, use the dtype set by the user using the PretrainedModel
-        scale_dtype = None
-        if status == QuantizationStatus.FROZEN:
-            if hasattr(model, "dtype"):
-                scale_dtype = model.dtype
-
         model.apply(
             lambda module: initialize_module_for_quantization(
-                module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype
+                module, force_zero_point=force_zero_point_init
             )
         )
 

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 3 additions & 9 deletions

@@ -59,7 +59,6 @@ def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
     force_zero_point: bool = True,
-    scale_dtype: Optional[torch.dtype] = None,
 ):
     """
     attaches appropriate scales, zero points, and observers to a layer
@@ -73,8 +72,6 @@ def initialize_module_for_quantization(
         if not provided, the layer will be skipped
     :param force_zero_point: whether to force initialization of a zero point for
         symmetric quantization
-    :param scale_dtype: dtype to used for the scales, if overriding the
-        weight dtype as the scale dtype
     """
     # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)
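
The same caution applies here: scale_dtype is gone from the public signature, so code that passed it explicitly must drop the keyword; the dtype is now derived inside _initialize_scale_zero_point (see the last hunk of this file). A minimal usage sketch under that assumption, where module is a placeholder layer that already carries a quantization_scheme:

    from compressed_tensors.quantization.lifecycle.initialize import (
        initialize_module_for_quantization,
    )

    # scale_dtype=... is no longer accepted; only the remaining keywords are valid.
    initialize_module_for_quantization(module, force_zero_point=True)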
@@ -93,7 +90,6 @@ def initialize_module_for_quantization(
             "input",
             scheme.input_activations,
             force_zero_point=force_zero_point,
-            scale_dtype=scale_dtype,
         )
 
     if scheme.weights is not None:
@@ -107,7 +103,6 @@ def initialize_module_for_quantization(
             scheme.weights,
             weight_shape=weight_shape,
             force_zero_point=force_zero_point,
-            scale_dtype=scale_dtype,
         )
     else:
         _LOGGER.warning(
@@ -119,7 +114,7 @@ def initialize_module_for_quantization(
     if scheme.output_activations is not None:
         if not is_kv_cache_quant_scheme(scheme):
             _initialize_scale_zero_point(
-                module, "output", scheme.output_activations, scale_dtype=scale_dtype
+                module, "output", scheme.output_activations
             )
 
     module.quantization_scheme = scheme
@@ -145,7 +140,6 @@ def _initialize_scale_zero_point(
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
-    scale_dtype: Optional[torch.dtype] = None,
 ):
     if quantization_args.dynamic is True:
         return
@@ -213,7 +207,7 @@ def _initialize_scale_zero_point(
     expected_shape = 1
 
     # 3. Identify quantization scale and zp dtype
-    scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
+    scale_dtype = module.weight.dtype
 
     if is_fp4(quantization_args=quantization_args):
         scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
@@ -226,7 +220,7 @@ def _initialize_scale_zero_point(
         torch.float32,
         torch.float64,
     ]:
-        scale_dtype = torch.float16
+        scale_dtype = torch.bfloat16
     zp_dtype = quantization_args.pytorch_dtype()
 
     # 4. Initializes empty scale, zero point, and g_idx parameters for the module
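
This is the change that makes the removed scale_dtype plumbing in apply.py unnecessary: scales now always start from the weight dtype, and weights stored in a non-float dtype (for example already-compressed integer weights encountered during decompression) fall back to bfloat16 rather than float16. A standalone sketch of that selection logic, with the FP4/FP8 branch omitted since FP8_E4M3_DATA lives elsewhere in the package:

    import torch

    def pick_scale_dtype(weight_dtype: torch.dtype) -> torch.dtype:
        # Scales now always derive from the weight dtype ...
        scale_dtype = weight_dtype
        # ... unless the weight is not a standard float dtype, in which case the
        # fallback is now bfloat16 (previously float16).
        if scale_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
            scale_dtype = torch.bfloat16
        return scale_dtype

    assert pick_scale_dtype(torch.float32) is torch.float32
    assert pick_scale_dtype(torch.int8) is torch.bfloat16  # compressed int weights -> bf16 scales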
