
Commit dd87f23

merge main
Signed-off-by: Brian Dellabetta <[email protected]>
2 parents 14a359f + d2daa9a commit dd87f23

12 files changed: +110 additions, -249 deletions


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 18 additions & 120 deletions
@@ -49,7 +49,6 @@
     get_offloaded_device,
     get_safetensors_folder,
     has_offloaded_params,
-    merge_names,
     patch_attr,
     register_offload_parameter,
     update_parameter_data,
@@ -226,7 +225,8 @@ def parse_sparsity_config(
             s_config = compression_config.sparsity_config
             return s_config.model_dump() if s_config is not None else None
 
-        return compression_config.get(SPARSITY_CONFIG_NAME, None)
+        # explicitly return None if {} in config
+        return compression_config.get(SPARSITY_CONFIG_NAME, None) or None
 
     @staticmethod
     def parse_quantization_config(
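For context, a minimal standalone sketch (illustrative only, not part of the commit) of why the trailing `or None` matters: an empty dict is not None, so callers checking `is not None` would treat a serialized `"sparsity_config": {}` entry as a real sparsity config; `or None` collapses the falsy empty dict to an explicit None.

    # Illustrative only; SPARSITY_CONFIG_NAME mirrors the constant used in the diff.
    SPARSITY_CONFIG_NAME = "sparsity_config"

    compression_config = {SPARSITY_CONFIG_NAME: {}}

    without_or = compression_config.get(SPARSITY_CONFIG_NAME, None)       # -> {}
    with_or = compression_config.get(SPARSITY_CONFIG_NAME, None) or None  # -> None
    print(without_or, with_or)  # {} None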
@@ -316,117 +316,11 @@ def __init__(
 
         self.quantization_compressor = {}
         for format in self.compression_formats:
-            self.quantization_compressor[
-                format
-            ] = BaseCompressor.load_from_registry(
-                format, config=quantization_config
-            )
-
-    # ----- used by hf quantizer ----- #
-
-    def get_missing_module_keys(self, model: Module) -> List[str]:
-        """
-        Identifies the expected missing weight keys in the compressed state_dict.
-
-        When a model undergoes sparsity or quantization compression, certain
-        weight tensors may be absent from the checkpoint by virtue of compression.
-        This function determines which weight keys are missing based on the
-        applied compression techniques.
-
-        :param model: The PyTorch model to check for missing keys.
-        :return: A list of missing keys expected in the compressed state_dict.
-        """
-        missing_keys = set()
-
-        # Determine missing keys due to sparsity compression
-        if (
-            self.sparsity_compressor
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            sparse_targets = match_named_modules(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
-
-            missing_keys.update(
-                merge_names(target_name, "weight")
-                for target_name, _module in sparse_targets
-            )
-
-        # Determine missing keys due to pack quantization
-        if (
-            self.quantization_compressor
-            and self.quantization_config.format
-            == CompressionFormat.pack_quantized.value
-        ):
-            for scheme in self.quantization_config.config_groups.values():
-                quant_targets = match_named_modules(
-                    model=model,
-                    targets=scheme.targets,
-                    ignore=self.quantization_config.ignore,
-                )
-                missing_keys.update(
-                    merge_names(target_name, "weight")
-                    for target_name, _module in quant_targets
-                )
-
-        return list(missing_keys)
-
-    def get_unexpected_file_keys(self, model: Module) -> List[str]:
-        """
-        Identifies extra keys introduced by the compression process in the
-        compressed state_dict that are not expected by the model graph.
-
-        During sparsity or quantization compression, additional metadata or
-        auxiliary parameters may be stored in the checkpoint, which do not
-        correspond to any parameter in the original model. These keys are
-        typically introduced to support the reconstruction of compressed weights.
-
-        For example, Sparse24Bitmask compression may introduce keys such as
-        'compressed', 'bitmask', and 'shape' in the checkpoint, which are
-        not part of the original model parameters.
-
-        :param model: The PyTorch model to check for unexpected keys.
-        :return: A list of extra keys introduced by the compression process
-            that are not expected by the model.
-        """
-
-        unexpected_keys = set()
-
-        # Identify unexpected keys from sparsity compression
-        if (
-            self.sparsity_compressor
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            sparse_targets = match_named_modules(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
-            unexpected_keys.update(
-                merge_names(target_name, param)
-                for target_name, _module in sparse_targets
-                for param in self.sparsity_compressor.compression_param_names
-            )
-
-        # Identify unexpected keys from quantization compression
-        if self.quantization_compressor:
-            for scheme in self.quantization_config.config_groups.values():
-                quant_targets = match_named_modules(
-                    model=model,
-                    targets=scheme.targets,
-                    ignore=self.quantization_config.ignore,
-                )
-                for quant_compressor in self.quantization_compressor.values():
-                    unexpected_keys.update(
-                        merge_names(target_name, param)
-                        for target_name, _module in quant_targets
-                        for param in quant_compressor.compression_param_names
-                        if param != "weight"
+            self.quantization_compressor[format] = (
+                BaseCompressor.load_from_registry(
+                    format, config=quantization_config
                 )
-
-        return list(unexpected_keys)
+            )
 
     # ----- model memory compression/decompression pathways ----- #
 
@@ -716,17 +610,16 @@ def decompress(self, model_path: str, model: Module):
             # Load activation scales/zp or any other quantization parameters
             # Conditionally load the weight quantization parameters if we have a
             # dense compressor or if a sparsity compressor has already been applied
+            load_weight_qparams = sparse_decompressed or isinstance(
+                quant_compressor, DenseCompressor
+            )
             load_pretrained_quantization_parameters(
                 model,
                 model_path,
                 # TODO: all weight quantization params will be moved to the
                 # compressor in a follow-up including initialization
-                load_weight_quantization=(
-                    sparse_decompressed
-                    or isinstance(quant_compressor, DenseCompressor)
-                ),
+                load_weight_qparams=load_weight_qparams,
             )
-
             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
             )

@@ -736,7 +629,9 @@ def decompress(self, model_path: str, model: Module):
             )
             # TODO: all weight quantization params will be moved to the compressor
             # to prevent duplicate parameter updates in update_parameter_data
-            self._replace_weights(dense_gen, model)
+            self._replace_weights(
+                dense_gen, model, load_weight_qparams=not load_weight_qparams
+            )
 
             def freeze_quantization_status(module):
                 module.quantization_status = QuantizationStatus.FROZEN
@@ -823,7 +718,9 @@ def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
             param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
             register_offload_parameter(module, param_name, param)
 
-    def _replace_weights(self, dense_weight_generator, model: Module):
+    def _replace_weights(
+        self, dense_weight_generator, model: Module, load_weight_qparams: bool = True
+    ):
         """
         Replace the weights of the model with the
         provided dense weights.

@@ -851,6 +748,7 @@ def _replace_weights(self, dense_weight_generator, model: Module):
                 # decompression in init to be consistent with loading which happens
                 # later as well however, update_data does a good shape check -
                 # should be moved to the compressor
+
                 if param_name == "weight":
                     delattr(module, param_name)
                     requires_grad = param_data.dtype in (

@@ -862,7 +760,7 @@ def _replace_weights(self, dense_weight_generator, model: Module):
                         param_data.to(device), requires_grad=requires_grad
                     )
                     register_offload_parameter(module, param_name, param)
-                else:
+                elif load_weight_qparams:
                     # Should already be registered to the correct device for
                     # for scales/zero-points
                     update_parameter_data(module, param_data, param_name)
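Taken together, the decompress changes route weight scales and zero-points through exactly one of two sources: the checkpoint (when a dense compressor is used or sparsity decompression already ran) or the decompressed weight generator. A standalone sketch of that decision, using placeholder booleans instead of the real compressor objects:

    # Illustrative helper, not library code: which source provides weight qparams?
    def weight_qparam_source(sparse_decompressed: bool, is_dense_compressor: bool) -> str:
        load_weight_qparams = sparse_decompressed or is_dense_compressor
        # load_pretrained_quantization_parameters(..., load_weight_qparams=load_weight_qparams)
        # self._replace_weights(..., load_weight_qparams=not load_weight_qparams)
        return "checkpoint" if load_weight_qparams else "decompressed weight generator"

    print(weight_qparam_source(sparse_decompressed=False, is_dense_compressor=False))
    # -> decompressed weight generator
    print(weight_qparam_source(sparse_decompressed=True, is_dense_compressor=False))
    # -> checkpoint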

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 5 additions & 4 deletions
@@ -140,6 +140,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     m, n = x.shape
     device = x.device
 
+    if n % 2 != 0:
+        raise ValueError(
+            "tensor must have an even number of columns for nvfp4 compression"
+        )
+
     # Create lookup table for FP4 values to indices
     # Map the absolute values to 0-7 indices
     kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)
@@ -155,10 +160,6 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     # Reshape to prepare for packing pairs of values
     indices = indices.reshape(-1)
 
-    # Handle odd length by padding if necessary
-    if indices.numel() % 2 != 0:
-        indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])
-
     # Reshape to pair consecutive elements
     indices = indices.reshape(-1, 2)
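Each packed byte stores two 4-bit E2M1 codes, so an odd number of columns now fails fast instead of being silently padded. A standalone toy packer (not the library routine; the nibble order here is an assumption) shows the mechanics:

    import torch

    def pack_pairs(codes: torch.Tensor) -> torch.Tensor:
        # codes: integer tensor of 4-bit values (0..15); the last dim must be even
        if codes.shape[-1] % 2 != 0:
            raise ValueError("even number of columns required: each byte holds two 4-bit codes")
        pairs = codes.reshape(-1, 2).to(torch.uint8)
        # second code goes in the high nibble, first code in the low nibble
        return (pairs[:, 1] << 4) | pairs[:, 0]

    print(pack_pairs(torch.tensor([[1, 7, 3, 12]])))  # tensor([113, 195], dtype=torch.uint8)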

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 4 additions & 4 deletions
@@ -61,19 +61,19 @@
 def load_pretrained_quantization_parameters(
     model: Module,
     model_name_or_path: Optional[str] = None,
-    load_weight_quantization: Optional[bool] = False,
+    load_weight_qparams: Optional[bool] = False,
 ):
     """
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
     a model that has already been initialized with a quantization config.
 
     NOTE: Will always load inputs/output parameters. Will conditioanlly load weight
-    parameters, if load_weight_quantization is set to True.
+    parameters, if load_weight_qparams is set to True.
 
     :param model: model to load pretrained quantization parameters to
     :param model_name_or_path: Hugging Face stub or local folder containing a quantized
         model, which is used to load quantization parameters
-    :param load_weight_quantization: whether or not the weight quantization parameters
+    :param load_weight_qparams: whether or not the weight quantization parameters
         should be loaded
     """
     model_path = get_safetensors_folder(model_name_or_path)
@@ -99,7 +99,7 @@ def load_pretrained_quantization_parameters(
                 mapping=mapping,
             )
 
-            if load_weight_quantization and submodule.quantization_scheme.weights:
+            if load_weight_qparams and submodule.quantization_scheme.weights:
                 base_name = "weight"
                 _load_quant_args_from_mapping(
                     base_name=base_name,

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 3 additions & 9 deletions
@@ -59,7 +59,6 @@ def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
     force_zero_point: bool = True,
-    scale_dtype: Optional[torch.dtype] = None,
 ):
     """
     attaches appropriate scales, zero points, and observers to a layer

@@ -73,8 +72,6 @@ def initialize_module_for_quantization(
         if not provided, the layer will be skipped
     :param force_zero_point: whether to force initialization of a zero point for
         symmetric quantization
-    :param scale_dtype: dtype to used for the scales, if overriding the
-        weight dtype as the scale dtype
     """
     # TODO: don't initialize parameters when running decompression
     scheme = scheme or getattr(module, "quantization_scheme", None)

@@ -93,7 +90,6 @@ def initialize_module_for_quantization(
            "input",
            scheme.input_activations,
            force_zero_point=force_zero_point,
-           scale_dtype=scale_dtype,
        )
 
    if scheme.weights is not None:

@@ -107,7 +103,6 @@ def initialize_module_for_quantization(
                scheme.weights,
                weight_shape=weight_shape,
                force_zero_point=force_zero_point,
-               scale_dtype=scale_dtype,
            )
        else:
            _LOGGER.warning(

@@ -119,7 +114,7 @@ def initialize_module_for_quantization(
    if scheme.output_activations is not None:
        if not is_kv_cache_quant_scheme(scheme):
            _initialize_scale_zero_point(
-               module, "output", scheme.output_activations, scale_dtype=scale_dtype
+               module, "output", scheme.output_activations
            )
 
    module.quantization_scheme = scheme

@@ -145,7 +140,6 @@ def _initialize_scale_zero_point(
    quantization_args: QuantizationArgs,
    weight_shape: Optional[torch.Size] = None,
    force_zero_point: bool = True,
-   scale_dtype: Optional[torch.dtype] = None,
 ):
    if quantization_args.dynamic is True:
        return

@@ -213,7 +207,7 @@ def _initialize_scale_zero_point(
    expected_shape = 1
 
    # 3. Identify quantization scale and zp dtype
-   scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
+   scale_dtype = module.weight.dtype
 
    if is_fp4(quantization_args=quantization_args):
        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype

@@ -226,7 +220,7 @@ def _initialize_scale_zero_point(
        torch.float32,
        torch.float64,
    ]:
-       scale_dtype = torch.float16
+       scale_dtype = torch.bfloat16
        zp_dtype = quantization_args.pytorch_dtype()
 
    # 4. Initializes empty scale, zero point, and g_idx parameters for the module
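With the scale_dtype override removed, the scale dtype is derived from the weight dtype, with FP4 schemes pinned to FP8-E4M3 scales and otherwise-unsupported weight dtypes promoted to bfloat16 rather than float16. A simplified standalone sketch of that selection (the dtype allow-list is only partially visible in the hunk above, so the tuple below is an assumption):

    import torch

    def select_scale_dtype(weight_dtype: torch.dtype, is_fp4: bool) -> torch.dtype:
        if is_fp4:
            return torch.float8_e4m3fn  # FP4 schemes keep their scales in FP8-E4M3
        if weight_dtype not in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
            return torch.bfloat16  # promoted fallback dtype (previously float16)
        return weight_dtype

    print(select_scale_dtype(torch.bfloat16, is_fp4=False))  # torch.bfloat16
    print(select_scale_dtype(torch.int8, is_fp4=False))      # torch.bfloat16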

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 20 additions & 0 deletions
@@ -60,6 +60,26 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
         format = model.format
 
         if inputs is not None:
+            if inputs.strategy not in (
+                QuantizationStrategy.TOKEN,
+                QuantizationStrategy.TENSOR,
+                QuantizationStrategy.GROUP,
+                QuantizationStrategy.TENSOR_GROUP,
+            ):
+                if (
+                    inputs.strategy == QuantizationStrategy.GROUP
+                    and inputs.dynamic is True
+                ):
+                    raise NotImplementedError(
+                        "Static and local group-wise activation "
+                        "quantization is not supported"
+                    )
+
+                raise NotImplementedError(
+                    f"Using {inputs.strategy} strategy is not supported for "
+                    "activation quantization"
+                )
+
             if inputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to input activations")
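The new check restricts input-activation quantization to the token, tensor, group, and tensor-group strategies and raises NotImplementedError for anything else. A standalone sketch of the accept/reject rule, using a stand-in enum rather than the library's QuantizationStrategy:

    from enum import Enum

    class Strategy(str, Enum):
        TENSOR = "tensor"
        TOKEN = "token"
        GROUP = "group"
        TENSOR_GROUP = "tensor_group"
        CHANNEL = "channel"
        BLOCK = "block"

    SUPPORTED_ACTIVATION_STRATEGIES = {
        Strategy.TOKEN, Strategy.TENSOR, Strategy.GROUP, Strategy.TENSOR_GROUP
    }

    def validate_activation_strategy(strategy: Strategy) -> None:
        if strategy not in SUPPORTED_ACTIVATION_STRATEGIES:
            raise NotImplementedError(
                f"Using {strategy} strategy is not supported for activation quantization"
            )

    validate_activation_strategy(Strategy.TOKEN)    # passes
    # validate_activation_strategy(Strategy.BLOCK)  # would raise NotImplementedError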

src/compressed_tensors/transform/factory/base.py

Lines changed: 11 additions & 4 deletions
@@ -18,6 +18,7 @@
 
 import torch
 import torch.nn.utils.parametrize as P
+import tqdm
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -84,15 +85,21 @@ def create_transform(self, module: Module, args: TransformArgs) -> "TransformBas
         """
         raise NotImplementedError()
 
-    def apply_to_model(self, model: Module):
+    def apply_to_model(self, model: Module, use_tqdm=True):
         """
         Create transforms and apply them to the model
 
         :param model: module to apply transforms to
         """
-        for arg in self.scheme.apply:
-            for _, module in match_named_modules(model, arg.targets, arg.ignore):
-                self._apply_to_module(module, arg)
+        modules_args = [
+            (module, arg)
+            for arg in self.scheme.apply
+            for _, module in match_named_modules(model, arg.targets, arg.ignore)
+        ]
+
+        desc = f"Applying {self.name} transforms"
+        for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
+            self._apply_to_module(module, arg)
 
         self._update_tied_weights()
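Materializing the (module, arg) pairs into a list before iterating lets tqdm report a total and percentage instead of an open-ended count, and the use_tqdm flag silences the bar entirely. A self-contained sketch of the same pattern with generic work items in place of modules and transform args:

    import tqdm

    def apply_all(items, apply_fn, desc="Applying transforms", use_progress_bar=True):
        work = list(items)  # collect first so the progress bar knows the total up front
        for item in tqdm.tqdm(work, desc=desc, disable=(not use_progress_bar)):
            apply_fn(item)

    apply_all(range(5), lambda i: None, desc="Applying example transforms")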
