From df873fb430df303225cd8f7f60acd25bda3d3514 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 21 Aug 2025 19:18:05 +0000 Subject: [PATCH 01/35] squashed/rebased Signed-off-by: Brian Dellabetta --- .../quantization/lifecycle/apply.py | 78 +++++++++++-------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 828f51ec8..94626055b 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -38,7 +38,11 @@ infer_quantization_status, is_kv_cache_quant_scheme, ) -from compressed_tensors.utils.helpers import deprecated, replace_module +from compressed_tensors.utils.helpers import ( + fix_fsdp_module_name, + deprecated, + replace_module, +) from compressed_tensors.utils.match import match_named_modules, match_targets from compressed_tensors.utils.offload import update_parameter_data from compressed_tensors.utils.safetensors_load import get_safetensors_folder @@ -142,36 +146,48 @@ def apply_quantization_config( for target in scheme.targets: target_to_scheme[target] = scheme - if run_compressed: - from compressed_tensors.linear.compressed_linear import CompressedLinear - - # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in match_named_modules( - model, target_to_scheme, config.ignore, warn_on_fail=True - ): - # mark modules to be quantized by adding - # quant scheme to the matching layers - matched_targets = match_targets(name, submodule, target_to_scheme) - scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) - if run_compressed: - format = config.format - if format != CompressionFormat.dense.value: - if isinstance(submodule, torch.nn.Linear): - # TODO: expand to more module types - compressed_linear = CompressedLinear.from_linear( - submodule, - quantization_scheme=scheme, - quantization_format=format, - ) - replace_module(model, name, compressed_linear) - - # target matched - add layer and scheme to target list - submodule.quantization_scheme = scheme - - names_to_scheme[name] = submodule.quantization_scheme - - # apply current quantization status across all targeted layers - apply_quantization_status(model, config.quantization_status) + # mark appropriate layers for quantization by setting their quantization schemes + for name, submodule in match_named_modules( + model, scheme.targets, config.ignore, warn_on_fail=True + ): + # potentially fix module name to remove FSDP wrapper prefix + name = fix_fsdp_module_name(name) + + # mark modules to be quantized by adding + # quant scheme to the matching layers + scheme = _scheme_from_targets(target_to_scheme, scheme.targets, name) + if run_compressed: + format = config.format + if format != CompressionFormat.dense.value: + if isinstance(submodule, torch.nn.Linear): + from compressed_tensors.linear.compressed_linear import ( + CompressedLinear, + ) + + compressed_linear = CompressedLinear.from_linear( + submodule, + quantization_scheme=scheme, + quantization_format=format, + ) + replace_module(model, name, compressed_linear) + + # target matched - add layer and scheme to target list + submodule.quantization_scheme = scheme + + names_to_scheme[name] = submodule.quantization_scheme + + # apply current quantization status to each targeted submodule + apply_quantization_status(submodule, config.quantization_status) + + # TODO warn on ignore not being found, this is useful in 
debugging + # if config.ignore is not None and ignored_submodules is not None: + # if set(config.ignore) - set(ignored_submodules): + # _LOGGER.warning( + # "Some layers that were to be ignored were " + # "not found in the model: " + # f"{set(config.ignore) - set(ignored_submodules)}" + # ) + return names_to_scheme From 2a58648d08deab6b8b17aa18b9663f59a01d97b1 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 21 Aug 2025 19:24:14 +0000 Subject: [PATCH 02/35] cleanup Signed-off-by: Brian Dellabetta --- .../quantization/lifecycle/apply.py | 63 ++++++++++--------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 94626055b..529e79648 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -146,38 +146,39 @@ def apply_quantization_config( for target in scheme.targets: target_to_scheme[target] = scheme - # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in match_named_modules( - model, scheme.targets, config.ignore, warn_on_fail=True + # mark appropriate layers for quantization by setting their quantization schemes + for name, submodule in match_named_modules( + model, target_to_scheme, config.ignore, warn_on_fail=True + ): + # potentially fix module name to remove FSDP wrapper prefix + name = fix_fsdp_module_name(name) + + # mark modules to be quantized by adding + # quant scheme to the matching layers + scheme = _scheme_from_targets(target_to_scheme, scheme.targets, name) + if ( + run_compressed + and config.format != CompressionFormat.dense.value + and isinstance(submodule, torch.nn.Linear) ): - # potentially fix module name to remove FSDP wrapper prefix - name = fix_fsdp_module_name(name) - - # mark modules to be quantized by adding - # quant scheme to the matching layers - scheme = _scheme_from_targets(target_to_scheme, scheme.targets, name) - if run_compressed: - format = config.format - if format != CompressionFormat.dense.value: - if isinstance(submodule, torch.nn.Linear): - from compressed_tensors.linear.compressed_linear import ( - CompressedLinear, - ) - - compressed_linear = CompressedLinear.from_linear( - submodule, - quantization_scheme=scheme, - quantization_format=format, - ) - replace_module(model, name, compressed_linear) - - # target matched - add layer and scheme to target list - submodule.quantization_scheme = scheme - - names_to_scheme[name] = submodule.quantization_scheme - - # apply current quantization status to each targeted submodule - apply_quantization_status(submodule, config.quantization_status) + from compressed_tensors.linear.compressed_linear import ( + CompressedLinear, + ) + + compressed_linear = CompressedLinear.from_linear( + submodule, + quantization_scheme=scheme, + quantization_format=config.format, + ) + replace_module(model, name, compressed_linear) + + # target matched - add layer and scheme to target list + submodule.quantization_scheme = scheme + + names_to_scheme[name] = submodule.quantization_scheme + + # apply current quantization status to each targeted submodule + apply_quantization_status(submodule, config.quantization_status) # TODO warn on ignore not being found, this is useful in debugging # if config.ignore is not None and ignored_submodules is not None: From 7cdd1cd3d57b9412972de2d7ef2369ad0bbce1ad Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 21 Aug 2025 
19:25:03 +0000
Subject: [PATCH 03/35] remove TODO

Signed-off-by: Brian Dellabetta
---
 src/compressed_tensors/quantization/lifecycle/apply.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 529e79648..e9ad51783 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -180,15 +180,6 @@ def apply_quantization_config(
         # apply current quantization status to each targeted submodule
         apply_quantization_status(submodule, config.quantization_status)

-    # TODO warn on ignore not being found, this is useful in debugging
-    # if config.ignore is not None and ignored_submodules is not None:
-    #     if set(config.ignore) - set(ignored_submodules):
-    #         _LOGGER.warning(
-    #             "Some layers that were to be ignored were "
-    #             "not found in the model: "
-    #             f"{set(config.ignore) - set(ignored_submodules)}"
-    #         )
-
     return names_to_scheme

From 78d274d971b29cfe2c70dcd6e8d41ac61fcea597 Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Thu, 21 Aug 2025 19:26:10 +0000
Subject: [PATCH 04/35] more cleanup

Signed-off-by: Brian Dellabetta
---
 .../quantization/lifecycle/apply.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index e9ad51783..d61a57823 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -38,11 +38,7 @@
     infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
-from compressed_tensors.utils.helpers import (
-    fix_fsdp_module_name,
-    deprecated,
-    replace_module,
-)
+from compressed_tensors.utils.helpers import deprecated, replace_module
 from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
@@ -150,9 +146,6 @@ def apply_quantization_config(
     for name, submodule in match_named_modules(
         model, target_to_scheme, config.ignore, warn_on_fail=True
     ):
-        # potentially fix module name to remove FSDP wrapper prefix
-        name = fix_fsdp_module_name(name)
-
         # mark modules to be quantized by adding
         # quant scheme to the matching layers
         scheme = _scheme_from_targets(target_to_scheme, scheme.targets, name)
         if (
             run_compressed
             and config.format != CompressionFormat.dense.value
             and isinstance(submodule, torch.nn.Linear)
         ):
-            from compressed_tensors.linear.compressed_linear import (
-                CompressedLinear,
-            )
+            from compressed_tensors.linear.compressed_linear import CompressedLinear

             compressed_linear = CompressedLinear.from_linear(

From 606f1779fb7c169f28ab8a76a287b2cd5d7ed22c Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Thu, 21 Aug 2025 19:37:09 +0000
Subject: [PATCH 05/35] cleanup

Signed-off-by: Brian Dellabetta
---
 src/compressed_tensors/quantization/lifecycle/apply.py | 3 ++-
 src/compressed_tensors/utils/match.py | 6 +-----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index d61a57823..e506deac4 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -148,7 +148,8 @@ def apply_quantization_config(
): # mark modules to be quantized by adding # quant scheme to the matching layers - scheme = _scheme_from_targets(target_to_scheme, scheme.targets, name) + matched_targets = match_targets(name, submodule, target_to_scheme) + scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) if ( run_compressed and config.format != CompressionFormat.dense.value diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 11e2a2a1c..ee0110673 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -146,11 +146,7 @@ def match_targets( targets = sorted(targets, key=lambda x: ("re:" in x, x)) matched_targets = [] for target in targets: - if _match_name(name, target): - matched_targets.append(target) - - for target in targets: - if _match_class(module, target) and target not in matched_targets: + if _match_name(name, target) or _match_class(module, target): matched_targets.append(target) return matched_targets From 829b7cbdad0705a938ffefd6245749d0f7396f4b Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 21 Aug 2025 19:38:51 +0000 Subject: [PATCH 06/35] formatting Signed-off-by: Brian Dellabetta --- .../model_compressors/model_compressor.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 2d88a3b02..f1d3f94fc 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -201,9 +201,11 @@ def from_pretrained_model( sparsity_config=sparsity_config, quantization_config=quantization_config, transform_config=transform_config, - compression_formats=[quantization_format] - if isinstance(quantization_format, str) - else quantization_format, + compression_formats=( + [quantization_format] + if isinstance(quantization_format, str) + else quantization_format + ), ) @staticmethod @@ -314,10 +316,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[ - format - ] = BaseCompressor.load_from_registry( - format, config=quantization_config + self.quantization_compressor[format] = ( + BaseCompressor.load_from_registry( + format, config=quantization_config + ) ) # ----- used by hf quantizer ----- # From 7f2c5deac0bdbc289de8d7cb8af807f0e6c3436b Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 21 Aug 2025 19:43:18 +0000 Subject: [PATCH 07/35] formatting Signed-off-by: Brian Dellabetta --- .../compressors/model_compressors/model_compressor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index f1d3f94fc..fc4a15f3e 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -316,10 +316,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[format] = ( - BaseCompressor.load_from_registry( - format, config=quantization_config - ) + self.quantization_compressor[ + format + ] = BaseCompressor.load_from_registry( + format, config=quantization_config ) # ----- used by hf quantizer ----- # From 
b515c1bdd58d970036e6c499ceea1dbe6d235ce0 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 28 Aug 2025 12:59:15 -0400 Subject: [PATCH 08/35] resolve redundant merge code Signed-off-by: Brian Dellabetta --- src/compressed_tensors/quantization/lifecycle/apply.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index d83da2abf..8bf8d7e8d 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -125,6 +125,7 @@ def apply_quantization_config( :param run_compressed: Whether the model will be run in compressed mode or decompressed fully on load """ + from compressed_tensors.linear.compressed_linear import CompressedLinear config = deepcopy(config) if config is None: # see PR #180 @@ -148,7 +149,6 @@ def apply_quantization_config( # quant scheme to the matching layers matched_targets = match_targets(name, submodule, target_to_scheme) scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) - # target matched - add layer and scheme to target list submodule.quantization_scheme = scheme @@ -159,8 +159,6 @@ def apply_quantization_config( and isinstance(submodule, torch.nn.Linear) and config.format != CompressionFormat.dense.value ): - from compressed_tensors.linear.compressed_linear import CompressedLinear - # TODO: expand to more module types compressed_linear = CompressedLinear.from_linear( submodule, @@ -169,9 +167,6 @@ def apply_quantization_config( ) replace_module(model, name, compressed_linear) - # target matched - add layer and scheme to target list - submodule.quantization_scheme = scheme - # apply current quantization status to each targeted submodule apply_quantization_status(submodule, config.quantization_status) From 0e11f93bc8c14e55ed8d02055280904510ff3f6b Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 28 Aug 2025 13:10:19 -0400 Subject: [PATCH 09/35] style fixes Signed-off-by: Brian Dellabetta --- .../compressors/model_compressors/model_compressor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 74a0d3944..10caee7b9 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -316,10 +316,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[format] = ( - BaseCompressor.load_from_registry( - format, config=quantization_config - ) + self.quantization_compressor[ + format + ] = BaseCompressor.load_from_registry( + format, config=quantization_config ) # ----- used by hf quantizer ----- # From 80844ea7c37978961a6d32630b1384a4d17e88f9 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 3 Sep 2025 16:22:56 +0000 Subject: [PATCH 10/35] cleanup / test fixes Signed-off-by: Brian Dellabetta --- .../quantization/lifecycle/apply.py | 2 +- .../quantization/lifecycle/compressed.py | 3 +- .../test_model_compressor.py | 31 +++++++++++-------- 3 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 8bf8d7e8d..27a5d8d32 100644 --- 
a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -235,7 +235,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): ) ) - if current_status < status >= QuantizationStatus.COMPRESSED > current_status: + if status >= QuantizationStatus.COMPRESSED > current_status: model.apply(compress_quantized_weights) diff --git a/src/compressed_tensors/quantization/lifecycle/compressed.py b/src/compressed_tensors/quantization/lifecycle/compressed.py index 00f707920..ee717e399 100644 --- a/src/compressed_tensors/quantization/lifecycle/compressed.py +++ b/src/compressed_tensors/quantization/lifecycle/compressed.py @@ -42,7 +42,8 @@ def compress_quantized_weights(module: Module): # no quantization scheme or weights not quantized, nothing to do return - if scheme is QuantizationStatus.COMPRESSED: + status = getattr(module, "quantization_status", None) + if status is QuantizationStatus.COMPRESSED: # module is already compressed, nothing to do return diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index dc48870b3..8646f6394 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -118,16 +118,14 @@ def __init__(self, weights, weight_scale=None, weight_zero_point=None): self.linear = nn.Linear(in_features, out_features, bias=False) # Set the weights of the linear layer - self.linear.weight = nn.Parameter(weights, requires_grad=False) + self.linear.weight = nn.Parameter(weights.detach().clone()) # Attach weight_scale and weight_zero_point as parameters if weight_scale is not None: - self.linear.weight_scale = nn.Parameter( - torch.tensor(weight_scale), requires_grad=False - ) + self.linear.weight_scale = nn.Parameter(weight_scale.detach().clone()) if weight_zero_point is not None: self.linear.weight_zero_point = nn.Parameter( - torch.tensor(weight_zero_point), requires_grad=False + weight_zero_point.detach().clone() ) def forward(self, x): @@ -443,9 +441,7 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir): ) def test_compress_model_meta(model_stub, q_format, s_config): # Load model on CPU to get expected compressed state_dict - cpu_model = AutoModelForCausalLM.from_pretrained( - model_stub, torch_dtype=torch.float32 - ) + cpu_model = AutoModelForCausalLM.from_pretrained(model_stub) reference_compressor = ModelCompressor.from_pretrained_model( cpu_model, s_config, [q_format] ) @@ -455,7 +451,6 @@ def test_compress_model_meta(model_stub, q_format, s_config): # Load model on meta device meta_model = AutoModelForCausalLM.from_pretrained( model_stub, - torch_dtype=torch.float32, low_cpu_mem_usage=True, ) for module in meta_model.modules(): @@ -542,7 +537,6 @@ def test_decompress_model(model_stub, comp_stub): true_decompressed_model = AutoModelForCausalLM.from_pretrained( comp_stub, quantization_config=CompressedTensorsConfig(run_compressed=False), - torch_dtype=torch.float32, ) true_decompressed = dict(true_decompressed_model.state_dict()) true_decompressed = remove_empty_weight_zero_points(true_decompressed) # see above @@ -551,7 +545,7 @@ def test_decompress_model(model_stub, comp_stub): # NOTE there is no other way to load a compressed model into memory, since # there is no way to turn off decompression for sparse models # 
https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_compressed_tensors.py#L133 - model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32) + model = AutoModelForCausalLM.from_pretrained(model_stub) compressor = ModelCompressor.from_pretrained(comp_stub) compressor.compress_model(model) compressor.decompress_model(model) @@ -566,8 +560,12 @@ def test_decompress_model(model_stub, comp_stub): # equivalent to decompressing from disk assert decompressed.keys() == true_decompressed.keys() for key in decompressed.keys(): - assert decompressed[key].dtype == true_decompressed[key].dtype - assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}" + assert ( + decompressed[key].dtype == true_decompressed[key].dtype + ), f"{key} dtypes not equal" + assert torch.all( + decompressed[key] == true_decompressed[key] + ), f"{key} values not equal" def remove_empty_weight_zero_points(state_dict): @@ -576,3 +574,10 @@ def remove_empty_weight_zero_points(state_dict): for name, value in state_dict.items() if not (name.endswith("weight_zero_point") and torch.all(value == 0)) } + + +if __name__ == "__main__": + test_decompress_model( + "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed", + "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed", + ) From f62b70cece470eab3e2611969cfc0e6f547205a7 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 4 Sep 2025 21:44:56 +0000 Subject: [PATCH 11/35] test fixes Signed-off-by: Brian Dellabetta --- .../model_compressors/test_model_compressor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index 8646f6394..cf670d5c7 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -537,6 +537,7 @@ def test_decompress_model(model_stub, comp_stub): true_decompressed_model = AutoModelForCausalLM.from_pretrained( comp_stub, quantization_config=CompressedTensorsConfig(run_compressed=False), + torch_dtype=torch.float32, ) true_decompressed = dict(true_decompressed_model.state_dict()) true_decompressed = remove_empty_weight_zero_points(true_decompressed) # see above @@ -545,7 +546,7 @@ def test_decompress_model(model_stub, comp_stub): # NOTE there is no other way to load a compressed model into memory, since # there is no way to turn off decompression for sparse models # https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_compressed_tensors.py#L133 - model = AutoModelForCausalLM.from_pretrained(model_stub) + model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32) compressor = ModelCompressor.from_pretrained(comp_stub) compressor.compress_model(model) compressor.decompress_model(model) From ac1ce1cc859162583726d03d79553c8f353393ea Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 9 Sep 2025 18:32:57 +0000 Subject: [PATCH 12/35] formatting/touchups Signed-off-by: Brian Dellabetta --- .../model_compressors/model_compressor.py | 27 ++++--------------- .../quantization/utils/helpers.py | 12 ++++----- 2 files changed, 11 insertions(+), 28 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 10caee7b9..921658360 100644 --- 
a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -51,6 +51,7 @@ get_safetensors_folder, has_offloaded_params, merge_names, + patch_attr, register_offload_parameter, update_parameter_data, ) @@ -702,8 +703,10 @@ def decompress(self, model_path: str, model: Module): # that the dtypes of the weights are not unintentionally updated. # The status is restored after quantization params are loaded. - with override_quantization_status( - self.quantization_config, QuantizationStatus.FROZEN + with patch_attr( + self.quantization_config, + "quantization_status", + QuantizationStatus.FROZEN, ): apply_quantization_config(model, self.quantization_config) names_to_scheme: Set[QuantizationScheme] = { @@ -891,23 +894,3 @@ def new_dtype_byte_size(dtype): raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0]) return bit_size // 8 - - -@contextmanager -def override_quantization_status( - config: QuantizationConfig, status: QuantizationStatus -): - """ - Within this context, the quantization status will be set to the - supplied status. After the context exits, the original status - will be restored. - - :param config: the quantization config to override - :param status: the status to temporarily set - """ - original_status = config.quantization_status - config.quantization_status = status - try: - yield - finally: - config.quantization_status = original_status diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 73d545193..68954b5b9 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -17,6 +17,7 @@ from typing import Generator, List, Optional, Tuple import torch +from compressed_tensors.quantization import QuantizationScheme, QuantizationStatus from compressed_tensors.quantization.quant_args import ( FP4_E2M1_DATA, FP8_E4M3_DATA, @@ -25,7 +26,6 @@ QuantizationStrategy, QuantizationType, ) -from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.utils import deprecated from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module @@ -234,16 +234,16 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: return q_min, q_max -def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]: # noqa +def infer_quantization_status(module: Module) -> Optional["QuantizationStatus"]: # noqa """ - Checks the quantization status of a model. Assumes all modules in the model have + Checks the quantization status of a module. Assumes all modules in the model have the same status, so only the first quantized model is checked. 
- :param model: model to check quantization status for + :param module: module to check quantization status for :return: quantization status if the model is quantized, otherwise None """ - for module in model.modules(): - status = getattr(module, "quantization_status", None) + for submodule in module.modules(): + status = getattr(submodule, "quantization_status", None) if status is not None: return status return None From 0da87300b420fd7c673836eaf28ff44193463dd3 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 9 Sep 2025 19:00:27 +0000 Subject: [PATCH 13/35] stylefix Signed-off-by: Brian Dellabetta --- .../compressors/model_compressors/model_compressor.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index 921658360..eb039ebf7 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -17,7 +17,6 @@ import operator import os import re -from contextlib import contextmanager from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union @@ -317,10 +316,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[ - format - ] = BaseCompressor.load_from_registry( - format, config=quantization_config + self.quantization_compressor[format] = ( + BaseCompressor.load_from_registry( + format, config=quantization_config + ) ) # ----- used by hf quantizer ----- # From 5744d73e6a8d2e8ad0020c64a9e5a8ce40a7cb72 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 9 Sep 2025 20:45:38 +0000 Subject: [PATCH 14/35] stylefixes Signed-off-by: Brian Dellabetta --- .../compressors/model_compressors/model_compressor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index eb039ebf7..b3e880d6a 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -316,10 +316,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[format] = ( - BaseCompressor.load_from_registry( - format, config=quantization_config - ) + self.quantization_compressor[ + format + ] = BaseCompressor.load_from_registry( + format, config=quantization_config ) # ----- used by hf quantizer ----- # From 304615e2d0fe0e088c76f4a66da07a11df87c431 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 9 Sep 2025 20:53:28 +0000 Subject: [PATCH 15/35] stylefixes Signed-off-by: Brian Dellabetta --- src/compressed_tensors/quantization/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 68954b5b9..e14344d46 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -17,7 +17,6 @@ from typing import Generator, List, Optional, Tuple import torch -from compressed_tensors.quantization import QuantizationScheme, QuantizationStatus from compressed_tensors.quantization.quant_args import ( FP4_E2M1_DATA, 
FP8_E4M3_DATA, @@ -26,6 +25,7 @@ QuantizationStrategy, QuantizationType, ) +from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.utils import deprecated from torch import FloatTensor, IntTensor, Tensor from torch.nn import Module From 76f81b9fa3cde886c0765c2ad965f03386508a9f Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 9 Sep 2025 21:30:45 +0000 Subject: [PATCH 16/35] remaining test fixes Signed-off-by: Brian Dellabetta --- .../quantization/utils/helpers.py | 16 ---------------- src/compressed_tensors/utils/match.py | 10 +++++++--- tests/test_quantization/lifecycle/test_apply.py | 4 ++-- 3 files changed, 9 insertions(+), 21 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index e14344d46..1b6937d47 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -32,7 +32,6 @@ __all__ = [ - "infer_quantization_status", "is_module_quantized", "is_model_quantized", "module_type", @@ -234,21 +233,6 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: return q_min, q_max -def infer_quantization_status(module: Module) -> Optional["QuantizationStatus"]: # noqa - """ - Checks the quantization status of a module. Assumes all modules in the model have - the same status, so only the first quantized model is checked. - - :param module: module to check quantization status for - :return: quantization status if the model is quantized, otherwise None - """ - for submodule in module.modules(): - status = getattr(submodule, "quantization_status", None) - if status is not None: - return status - return None - - def is_module_quantized(module: Module) -> bool: """ Check if a module is quantized, based on the existence of a non-empty quantization diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index ee0110673..b96e83d04 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -136,17 +136,21 @@ def match_targets( if isinstance(module, InternalModule): return [] - # The order of the output `matches` list matters, the are arranged from most + # The order of the output `matches` list matters, they are arranged from most # specific to least specific, and this order will be used when merging configs. # The entries are sorted in the following order: # 1. matches on exact strings # 2. matches on regex patterns - # 3. matches on module names + # 3. matches on module names (e.g. 
"Linear") targets = sorted(targets, key=lambda x: ("re:" in x, x)) matched_targets = [] for target in targets: - if _match_name(name, target) or _match_class(module, target): + if _match_name(name, target): + matched_targets.append(target) + + for target in targets: + if _match_class(module, target) and target not in matched_targets: matched_targets.append(target) return matched_targets diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index d5fd6c2cd..e6352dad4 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -66,7 +66,7 @@ def test_target_prioritization(mock_frozen): "config_groups": { "group_1": { "weights": { - "num_bits": 8, + "num_bits": 6, }, "targets": ["Linear"], }, @@ -101,7 +101,7 @@ def test_target_prioritization(mock_frozen): elif re.match(".*down_proj", name): assert module.quantization_scheme.weights.num_bits == 4 elif isinstance(module, torch.nn.Linear): - assert module.quantization_scheme.weights.num_bits == 8 + assert module.quantization_scheme.weights.num_bits == 6 def test_apply_quantization_config_tinyllama(): From 5bf957f5ef73eb5a4f135631794b7858ffd9b0a2 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 9 Sep 2025 21:32:24 +0000 Subject: [PATCH 17/35] revert extraneous test change Signed-off-by: Brian Dellabetta --- tests/test_quantization/lifecycle/test_apply.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index e6352dad4..d5fd6c2cd 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -66,7 +66,7 @@ def test_target_prioritization(mock_frozen): "config_groups": { "group_1": { "weights": { - "num_bits": 6, + "num_bits": 8, }, "targets": ["Linear"], }, @@ -101,7 +101,7 @@ def test_target_prioritization(mock_frozen): elif re.match(".*down_proj", name): assert module.quantization_scheme.weights.num_bits == 4 elif isinstance(module, torch.nn.Linear): - assert module.quantization_scheme.weights.num_bits == 6 + assert module.quantization_scheme.weights.num_bits == 8 def test_apply_quantization_config_tinyllama(): From 360a9fbc536507bbcdf9e2be97b9f316b7dcdd2a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 9 Sep 2025 21:41:14 +0000 Subject: [PATCH 18/35] remove test running code Signed-off-by: Brian Dellabetta --- .../model_compressors/test_model_compressor.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/test_compressors/model_compressors/test_model_compressor.py b/tests/test_compressors/model_compressors/test_model_compressor.py index cf670d5c7..f833b8290 100644 --- a/tests/test_compressors/model_compressors/test_model_compressor.py +++ b/tests/test_compressors/model_compressors/test_model_compressor.py @@ -575,10 +575,3 @@ def remove_empty_weight_zero_points(state_dict): for name, value in state_dict.items() if not (name.endswith("weight_zero_point") and torch.all(value == 0)) } - - -if __name__ == "__main__": - test_decompress_model( - "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed", - "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed", - ) From cc275687317a4ebc493c4f035141ce7d670393e9 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 9 Sep 2025 21:47:11 +0000 Subject: [PATCH 19/35] remove infer_quantization_status Signed-off-by: Brian Dellabetta --- 
src/compressed_tensors/quantization/lifecycle/apply.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 27a5d8d32..96d3e0ede 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -35,7 +35,6 @@ from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import ( KV_CACHE_TARGETS, - infer_quantization_status, is_kv_cache_quant_scheme, ) from compressed_tensors.utils.helpers import deprecated, replace_module @@ -215,9 +214,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): :param status: status to update the module to """ - current_status = infer_quantization_status(model) - - if status >= QuantizationStatus.INITIALIZED > current_status: + if status >= QuantizationStatus.INITIALIZED: force_zero_point_init = status != QuantizationStatus.COMPRESSED # When decompressing, we set the scale_dtype as the model's dtype @@ -235,7 +232,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): ) ) - if status >= QuantizationStatus.COMPRESSED > current_status: + if status >= QuantizationStatus.COMPRESSED: model.apply(compress_quantized_weights) From fc2e1024d457dbbab1cce0152504a9d5a884ecf7 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 10 Sep 2025 19:31:41 +0000 Subject: [PATCH 20/35] lifecycle updates for overwriting config Signed-off-by: Brian Dellabetta --- .../quantization/lifecycle/apply.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 96d3e0ede..8c1c53f7f 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -206,34 +206,33 @@ def process_kv_cache_config( return config -def apply_quantization_status(model: Module, status: QuantizationStatus): +def apply_quantization_status(module: Module, status: QuantizationStatus): """ Applies in place the quantization lifecycle up to the given status - :param model: model to apply quantization to + :param module: module to apply quantization to :param status: status to update the module to """ - if status >= QuantizationStatus.INITIALIZED: - force_zero_point_init = status != QuantizationStatus.COMPRESSED - - # When decompressing, we set the scale_dtype as the model's dtype - # This is because the normal workflow of using the weight's dtype - # will be incorrect as the model weight will be compressed - # Therfore, use the dtype set by the user using the PretrainedModel - scale_dtype = None - if status == QuantizationStatus.FROZEN: - if hasattr(model, "dtype"): - scale_dtype = model.dtype - - model.apply( - lambda module: initialize_module_for_quantization( - module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype - ) + force_zero_point_init = status != QuantizationStatus.COMPRESSED + + # When decompressing, we set the scale_dtype as the model's dtype + # This is because the normal workflow of using the weight's dtype + # will be incorrect as the model weight will be compressed + # Therfore, use the dtype set by the user using the PretrainedModel + scale_dtype = None + if status == QuantizationStatus.FROZEN: + if hasattr(module, "dtype"): + scale_dtype = module.dtype + + 
module.apply( + lambda module: initialize_module_for_quantization( + module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype ) + ) if status >= QuantizationStatus.COMPRESSED: - model.apply(compress_quantized_weights) + module.apply(compress_quantized_weights) @deprecated( From 14a359f78a851702af3d607b5a31658b2e31cfbf Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 10 Sep 2025 22:51:12 +0000 Subject: [PATCH 21/35] remove compress_quantized_weight, test fixes, remove sparseml references Signed-off-by: Brian Dellabetta --- examples/quantize_and_pack_int4.ipynb | 2 +- .../quantization/lifecycle/apply.py | 14 ++++---------- .../quantization/quant_config.py | 4 ++-- src/compressed_tensors/utils/helpers.py | 3 --- .../test_quantization/lifecycle/test_apply.py | 19 ++++++++++--------- 5 files changed, 17 insertions(+), 25 deletions(-) diff --git a/examples/quantize_and_pack_int4.ipynb b/examples/quantize_and_pack_int4.ipynb index 8cd58f2f2..e4d654685 100644 --- a/examples/quantize_and_pack_int4.ipynb +++ b/examples/quantize_and_pack_int4.ipynb @@ -144,7 +144,7 @@ "outputs": [], "source": [ "quantization_config_dict = {\n", - "\t\"quant_method\": \"sparseml\",\n", + "\t\"quant_method\": \"compressed-tensors\",\n", "\t\"format\": \"pack-quantized\",\n", "\t\"global_compression_ratio\": None,\n", "\t\"config_groups\": {\n", diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 8c1c53f7f..32dfa8cd1 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -21,9 +21,6 @@ import torch from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization.lifecycle.compressed import ( - compress_quantized_weights, -) from compressed_tensors.quantization.lifecycle.initialize import ( initialize_module_for_quantization, ) @@ -219,20 +216,17 @@ def apply_quantization_status(module: Module, status: QuantizationStatus): # When decompressing, we set the scale_dtype as the model's dtype # This is because the normal workflow of using the weight's dtype # will be incorrect as the model weight will be compressed - # Therfore, use the dtype set by the user using the PretrainedModel + # Therefore, use the dtype set by the user using the PretrainedModel scale_dtype = None if status == QuantizationStatus.FROZEN: if hasattr(module, "dtype"): scale_dtype = module.dtype - module.apply( - lambda module: initialize_module_for_quantization( - module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype - ) + initialize_module_for_quantization( + module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype ) - if status >= QuantizationStatus.COMPRESSED: - module.apply(compress_quantized_weights) + module.quantization_status = status @deprecated( diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 42df3a337..994af336f 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -113,8 +113,8 @@ class QuantizationConfig(BaseModel): :param config_groups: dict of QuantizationSchemes specifying the quantization settings for each quantized layer. 
A group could also be a reference to a predefined scheme name, mapped to a list of its target layers/classes - :param quant_method: a constant used to differentiate sparseML quantization from - other quantization configs + :param quant_method: a constant used to differentiate compressed-tensors + quantization from other quantization configs :param format: specifies how the quantized model is stored on disk :quantization_status: specifies the current status of all quantized layers. It is assumed all layers are in the same state. diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index 712c4f837..1c02fb521 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -71,9 +71,6 @@ def infer_compressor_from_model_config( return compressor -# TODO: There is already the same function in -# SparseML, should be moved to a shared location -# in the future def fix_fsdp_module_name(name: str) -> str: """ Remove FSDP wrapper prefixes from a module name diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index d5fd6c2cd..8d2ca4989 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -25,10 +25,7 @@ QuantizationConfig, QuantizationStatus, ) -from compressed_tensors.quantization.lifecycle import ( - apply_quantization_config, - apply_quantization_status, -) +from compressed_tensors.quantization.lifecycle import apply_quantization_config from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM @@ -105,7 +102,9 @@ def test_target_prioritization(mock_frozen): def test_apply_quantization_config_tinyllama(): - quant_config = get_sample_tinyllama_quant_config(status="calibration") + quant_config = get_sample_tinyllama_quant_config( + status=QuantizationStatus.CALIBRATION + ) model = get_tinyllama_model() # check that model is not already quantized @@ -146,7 +145,8 @@ def test_apply_quantization_config_tinyllama(): # test quantization compression # sample forward pass to fill scales, zps model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int)) - apply_quantization_status(model, QuantizationStatus.COMPRESSED) + quant_config.quantization_status = QuantizationStatus.COMPRESSED + apply_quantization_config(model, quant_config) for name, module in model.named_modules(): if name in quant_config.ignore: continue @@ -157,7 +157,6 @@ def test_apply_quantization_config_tinyllama(): inputs=True, weights=True, expected_status=QuantizationStatus.COMPRESSED, - expected_dtype=torch.int8, ) @@ -218,7 +217,9 @@ def get_tinyllama_model(): ) -def get_sample_tinyllama_quant_config(status: str = "frozen"): +def get_sample_tinyllama_quant_config( + status: QuantizationStatus = QuantizationStatus.FROZEN, +): config_dict = { "quant_method": "compressed-tensors", "format": "fakequant", @@ -270,7 +271,7 @@ def test_apply_quantization_status(caplog, target, should_raise_warning): # load a dense, unquantized tiny llama model model = get_tinyllama_model() quantization_config_dict = { - "quant_method": "sparseml", + "quant_method": "compressed-tensors", "format": "pack-quantized", "global_compression_ratio": None, "config_groups": { From 595228cc775e2471f0051bbb479d51cf8b3f7c5a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 10 Sep 2025 23:01:02 +0000 Subject: [PATCH 22/35] drop frozen scale_dtype post-merge Signed-off-by: Brian Dellabetta --- 
.../quantization/lifecycle/apply.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index a35cff543..96f4b8ad4 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -213,18 +213,7 @@ def apply_quantization_status(module: Module, status: QuantizationStatus): force_zero_point_init = status != QuantizationStatus.COMPRESSED - # When decompressing, we set the scale_dtype as the model's dtype - # This is because the normal workflow of using the weight's dtype - # will be incorrect as the model weight will be compressed - # Therefore, use the dtype set by the user using the PretrainedModel - scale_dtype = None - if status == QuantizationStatus.FROZEN: - if hasattr(module, "dtype"): - scale_dtype = module.dtype - - initialize_module_for_quantization( - module, force_zero_point=force_zero_point_init, scale_dtype=scale_dtype - ) + initialize_module_for_quantization(module, force_zero_point=force_zero_point_init) module.quantization_status = status From 49e4e92ede49839161bac7329d9ad7573eb17009 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 10 Sep 2025 23:05:00 +0000 Subject: [PATCH 23/35] formatting Signed-off-by: Brian Dellabetta --- .../compressors/model_compressors/model_compressor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index e49539f76..d8fc42078 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -316,10 +316,10 @@ def __init__( self.quantization_compressor = {} for format in self.compression_formats: - self.quantization_compressor[format] = ( - BaseCompressor.load_from_registry( - format, config=quantization_config - ) + self.quantization_compressor[ + format + ] = BaseCompressor.load_from_registry( + format, config=quantization_config ) # ----- model memory compression/decompression pathways ----- # From 6ba47e5db0877038b1e5209088801b79f5723288 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 11 Sep 2025 19:52:14 +0000 Subject: [PATCH 24/35] clear previously initialized qparams Signed-off-by: Brian Dellabetta --- .../quantization/lifecycle/apply.py | 8 ----- .../quantization/lifecycle/initialize.py | 33 +++++++++++++++++-- .../test_quantization/lifecycle/test_apply.py | 2 +- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 96f4b8ad4..2ed649ed1 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -243,14 +243,6 @@ def find_name_or_class_matches( return match_targets(name, module, targets) -def _infer_status(model: Module) -> Optional[QuantizationStatus]: - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def _load_quant_args_from_mapping( base_name: str, module_name: str, module: Module, mapping: Dict ): diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 
5350b4a2c..4a4e4c40b 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -33,6 +33,7 @@ from compressed_tensors.quantization.quant_scheme import QuantizationScheme from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme from compressed_tensors.utils import ( + delete_offload_parameter, disable_hf_hook, get_execution_device, register_offload_parameter, @@ -61,10 +62,11 @@ def initialize_module_for_quantization( force_zero_point: bool = True, ): """ - attaches appropriate scales, zero points, and observers to a layer - given its target quantization scheme + Attaches appropriate scales, zero points, and observers to a layer + given its target quantization scheme. - apply to full model with `model.apply(initialize_module_for_quantization)` + Previously initialized scales and zero points will be removed from + module if they no longer apply to the scheme :param module: module to set for calibration :param scheme: scheme to use for quantization. if None is provided, @@ -73,6 +75,8 @@ def initialize_module_for_quantization( :param force_zero_point: whether to force initialization of a zero point for symmetric quantization """ + _clear_all_qparams(module) + # TODO: don't initialize parameters when running decompression scheme = scheme or getattr(module, "quantization_scheme", None) if scheme is None: @@ -134,6 +138,29 @@ def is_attention_module(module: Module): ) +def _clear_all_qparams( + module: Module, +): + """ + Clear all previously registered quantization parameters from module + + :param module: module to clear qparams from + """ + keys = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [ + f"{base_name}_{suffix}" + for base_name in ("input", "weight", "output") + for suffix in ( + "global_scale", + "scale", + "zero_point", + "g_idx", + ) + ] + for key in keys: + if hasattr(module, key): + delete_offload_parameter(module, key) + + def _initialize_scale_zero_point( module: Module, base_name: str, diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 8d2ca4989..722795457 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -265,7 +265,7 @@ def get_sample_tinyllama_quant_config( [("Linear", "re:.*foobarbaz"), True], ], ) -def test_apply_quantization_status(caplog, target, should_raise_warning): +def test_apply_quantization_config(caplog, target, should_raise_warning): import logging # load a dense, unquantized tiny llama model From 72e5f3de525e6b9d86a296a7d28b11f19a7923f1 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 11 Sep 2025 21:26:18 +0000 Subject: [PATCH 25/35] remove apply_quantization_status Signed-off-by: Brian Dellabetta --- .../quantization/lifecycle/apply.py | 26 +++------- .../quantization/lifecycle/initialize.py | 4 +- .../test_quantization/lifecycle/test_apply.py | 50 ++++--------------- 3 files changed, 19 insertions(+), 61 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 2ed649ed1..89ac7a887 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -45,7 +45,6 @@ __all__ = [ "load_pretrained_quantization_parameters", "apply_quantization_config", - "apply_quantization_status", "find_name_or_class_matches", ] @@ -163,8 +162,14 @@ def 
apply_quantization_config( ) replace_module(model, name, compressed_linear) - # apply current quantization status to each targeted submodule - apply_quantization_status(submodule, config.quantization_status) + else: + initialize_module_for_quantization( + submodule, + force_zero_point=config.quantization_status + != QuantizationStatus.COMPRESSED, + ) + + submodule.quantization_status = config.quantization_status def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: @@ -203,21 +208,6 @@ def process_kv_cache_config( return config -def apply_quantization_status(module: Module, status: QuantizationStatus): - """ - Applies in place the quantization lifecycle up to the given status - - :param module: module to apply quantization to - :param status: status to update the module to - """ - - force_zero_point_init = status != QuantizationStatus.COMPRESSED - - initialize_module_for_quantization(module, force_zero_point=force_zero_point_init) - - module.quantization_status = status - - @deprecated( message="This function is deprecated and will be removed in a future release." "Please use `match_targets` from `compressed_tensors.utils.match` instead." diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 4a4e4c40b..0e81e4203 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -75,14 +75,14 @@ def initialize_module_for_quantization( :param force_zero_point: whether to force initialization of a zero point for symmetric quantization """ - _clear_all_qparams(module) - # TODO: don't initialize parameters when running decompression scheme = scheme or getattr(module, "quantization_scheme", None) if scheme is None: # no scheme passed and layer not targeted for quantization - skip return + _clear_all_qparams(module) + if is_attention_module(module): # quantized actions based on calltime status _initialize_attn_scales(module) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 722795457..473e78d6f 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -26,6 +26,7 @@ QuantizationStatus, ) from compressed_tensors.quantization.lifecycle import apply_quantization_config +from compressed_tensors.utils import match_named_modules from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM @@ -103,7 +104,7 @@ def test_target_prioritization(mock_frozen): def test_apply_quantization_config_tinyllama(): quant_config = get_sample_tinyllama_quant_config( - status=QuantizationStatus.CALIBRATION + status=QuantizationStatus.INITIALIZED ) model = get_tinyllama_model() @@ -111,52 +112,19 @@ def test_apply_quantization_config_tinyllama(): for module in model.modules(): _test_layer_quantization_status(module, inputs=False, weights=False) - count_layer_names = ("Linear", "Embeddidng", "LlamaRotaryEmbedding") - count_layer_num = defaultdict(int) - - for name, module in model.named_modules(): - if name in quant_config.ignore: - continue - module_type = module.__class__.__name__ - if module_type in count_layer_names: - count_layer_num[module_type] += 1 - - assert len(count_layer_num) > 0, f"None of {count_layer_names} found in model" - assert all(value > 0 for value in count_layer_num.values()) - # apply quant config to model apply_quantization_config(model, 
quant_config) # check for correct application of quant config - for name, module in model.named_modules(): - if name in quant_config.ignore: - continue - module_type = module.__class__.__name__ - if module_type in count_layer_names: - count_layer_num[module_type] -= 1 - _inputs = module_type == "Linear" - _weights = not module_type == "LlamaRotaryEmbedding" - _test_layer_quantization_status(module, inputs=_inputs, weights=_weights) - - assert all( - value == 0 for value in count_layer_num.values() - ), "Not all values are zero" - - # test quantization compression - # sample forward pass to fill scales, zps - model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int)) - quant_config.quantization_status = QuantizationStatus.COMPRESSED - apply_quantization_config(model, quant_config) - for name, module in model.named_modules(): - if name in quant_config.ignore: - continue - module_type = module.__class__.__name__ - if module_type == "Linear": + for quant_scheme in quant_config.config_groups.values(): + for name, module in match_named_modules( + model, quant_scheme.targets, quant_config.ignore + ): _test_layer_quantization_status( module, - inputs=True, - weights=True, - expected_status=QuantizationStatus.COMPRESSED, + inputs=quant_scheme.input_activations is not None, + weights=quant_scheme.weights is not None, + expected_status=QuantizationStatus.INITIALIZED, ) From 88b8865836e104fd774e606e5074e7d319700639 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 11 Sep 2025 21:29:46 +0000 Subject: [PATCH 26/35] stylefix Signed-off-by: Brian Dellabetta --- tests/test_quantization/lifecycle/test_apply.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 473e78d6f..157695470 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -13,7 +13,6 @@ # limitations under the License. 
import re -from collections import defaultdict from typing import Optional from unittest.mock import MagicMock From fb6aa9ad078d0d22a29115b02215e81543028bfa Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 11 Sep 2025 22:12:10 +0000 Subject: [PATCH 27/35] add ALL_QPARAM_KEYS var Signed-off-by: Brian Dellabetta --- .../quantization/lifecycle/initialize.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 0e81e4203..008d57e2c 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -45,6 +45,7 @@ "initialize_module_for_quantization", "is_attention_module", "KVCacheScaleType", + "ALL_QPARAM_KEYS", ] @@ -56,6 +57,18 @@ class KVCacheScaleType(Enum): VALUE = "v_scale" +ALL_QPARAM_KEYS = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [ + f"{base_name}_{suffix}" + for base_name in ("input", "weight", "output") + for suffix in ( + "global_scale", + "scale", + "zero_point", + "g_idx", + ) +] + + def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, @@ -146,17 +159,7 @@ def _clear_all_qparams( :param module: module to clear qparams from """ - keys = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [ - f"{base_name}_{suffix}" - for base_name in ("input", "weight", "output") - for suffix in ( - "global_scale", - "scale", - "zero_point", - "g_idx", - ) - ] - for key in keys: + for key in ALL_QPARAM_KEYS: if hasattr(module, key): delete_offload_parameter(module, key) From d2903a18447fa02f43c2efae4efc49bcce7426d6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 15 Sep 2025 17:16:15 +0000 Subject: [PATCH 28/35] multi-apply quantization config test Signed-off-by: Brian Dellabetta --- .../test_quantization/lifecycle/test_apply.py | 103 +++++++++++++++++- 1 file changed, 102 insertions(+), 1 deletion(-) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 157695470..4e6129409 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -21,11 +21,15 @@ from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization import ( DEFAULT_QUANTIZATION_METHOD, + QuantizationArgs, QuantizationConfig, + QuantizationScheme, QuantizationStatus, + QuantizationStrategy, + QuantizationType, ) from compressed_tensors.quantization.lifecycle import apply_quantization_config -from compressed_tensors.utils import match_named_modules +from compressed_tensors.utils import is_match, match_named_modules from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM @@ -265,3 +269,100 @@ def test_apply_quantization_config(caplog, target, should_raise_warning): assert len(caplog.text) > 0 else: assert len(caplog.text) == 0 + + +def test_multi_apply_quantization_config(): + """ + Ensure that multiple quantization configs are applied correctly + If quantization config was previously applied to a module, + those changes should be reset for newly applied quantization config + """ + model = get_tinyllama_model() + + # FP8 applied to mlp and self_attn.o_proj to validate overwriting + qconfig1 = QuantizationConfig( + config_groups={ + "group_0": QuantizationScheme( + targets=[ + 
r"re:.*model\.layers\.\d+\.mlp\.(down|gate|up)_proj$", + r"re:.*model\.layers\.\d+\.self_attn\.o_proj$", + ], + weights=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + ), + input_activations=QuantizationArgs( + num_bits=8, + type=QuantizationType.FLOAT, + strategy=QuantizationStrategy.TENSOR, + symmetric=True, + dynamic=False, + ), + ) + }, + ignore=["lm_head"], + ) + # W4A16_ASYM applied to self_attn + qconfig2 = QuantizationConfig( + config_groups={ + "group_0": QuantizationScheme( + targets=[ + r"re:.*model\.layers\.\d+\.self_attn\.(k|q|o|v)_proj$", + ], + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + strategy=QuantizationStrategy.GROUP, + group_size=128, + symmetric=False, + dynamic=False, + ), + ) + }, + ignore=["lm_head"], + ) + + apply_quantization_config(model, qconfig1) + apply_quantization_config(model, qconfig2) + for name, module in model.named_modules(): + if is_match( + name, module, qconfig2.config_groups["group_0"].targets, qconfig2.ignore + ): + # assert W4A16_ASYM parameters are present with correct shape + # and FP8 parameters have been removed + assert not hasattr(module, "input_scale") + assert not hasattr(module, "input_zero_point") + weight_scale = getattr(module, "weight_scale", None) + assert ( + weight_scale is not None + and weight_scale.shape[:-1] == module.weight.shape[:-1] + and weight_scale.shape[-1] == module.weight.shape[-1] / 128 + ) + weight_zero_point = getattr(module, "weight_zero_point", None) + assert ( + weight_zero_point is not None + and weight_zero_point.shape[:-1] == module.weight.shape[:-1] + and weight_zero_point.shape[-1] == module.weight.shape[-1] / 128 + ) + + elif is_match( + name, module, qconfig1.config_groups["group_0"].targets, qconfig1.ignore + ): + # assert FP8 scheme parameters are present with correct shape + input_scale = getattr(module, "input_scale", None) + assert input_scale is not None and input_scale.shape == torch.Size([1]) + input_zero_point = getattr(module, "input_zero_point", None) + assert ( + input_zero_point is not None + and input_zero_point.shape == torch.Size([1]) + ) + weight_scale = getattr(module, "weight_scale", None) + assert weight_scale is not None and weight_scale.shape == torch.Size([1]) + weight_zero_point = getattr(module, "weight_zero_point", None) + assert ( + weight_zero_point is not None + and weight_zero_point.shape == torch.Size([1]) + ) From 5776c86a4c7fbb20f8ec3760ba4fa79f58ec3f2f Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 15 Sep 2025 17:30:40 +0000 Subject: [PATCH 29/35] multi-apply test cleanup Signed-off-by: Brian Dellabetta --- tests/test_quantization/lifecycle/test_apply.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 4e6129409..ae8908202 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -279,13 +279,12 @@ def test_multi_apply_quantization_config(): """ model = get_tinyllama_model() - # FP8 applied to mlp and self_attn.o_proj to validate overwriting + # FP8 applied to self_attn qconfig1 = QuantizationConfig( config_groups={ "group_0": QuantizationScheme( targets=[ - r"re:.*model\.layers\.\d+\.mlp\.(down|gate|up)_proj$", - r"re:.*model\.layers\.\d+\.self_attn\.o_proj$", + r"re:.*self_attn\.(k|q|o|v)_proj$", ], weights=QuantizationArgs( num_bits=8, @@ -305,12 
+304,13 @@ def test_multi_apply_quantization_config(): }, ignore=["lm_head"], ) - # W4A16_ASYM applied to self_attn + # W4A16_ASYM applied to mlp and self_attn.o_proj to validate overwriting qconfig2 = QuantizationConfig( config_groups={ "group_0": QuantizationScheme( targets=[ - r"re:.*model\.layers\.\d+\.self_attn\.(k|q|o|v)_proj$", + r"re:.*mlp\.(down|gate|up)_proj$", + r"re:.*self_attn\.o_proj$", ], weights=QuantizationArgs( num_bits=4, From 98a97e573129f0e224d7da5bd2aae4ee96d12ca4 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 17 Sep 2025 20:34:47 +0000 Subject: [PATCH 30/35] ALL_QPARAM_NAMES Signed-off-by: Brian Dellabetta --- .../quantization/__init__.py | 1 + .../quantization/lifecycle/initialize.py | 36 ++++---------- .../quantization/quant_names.py | 48 +++++++++++++++++++ 3 files changed, 58 insertions(+), 27 deletions(-) create mode 100644 src/compressed_tensors/quantization/quant_names.py diff --git a/src/compressed_tensors/quantization/__init__.py b/src/compressed_tensors/quantization/__init__.py index 9fde69a35..10c8975a8 100644 --- a/src/compressed_tensors/quantization/__init__.py +++ b/src/compressed_tensors/quantization/__init__.py @@ -17,5 +17,6 @@ from .quant_args import * from .quant_config import * +from .quant_names import * from .quant_scheme import * from .lifecycle import * diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 008d57e2c..58dffcf43 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -16,21 +16,22 @@ import logging import math import warnings -from enum import Enum from typing import Optional import torch -from compressed_tensors.quantization.lifecycle.forward import ( - wrap_module_forward_quantized, -) -from compressed_tensors.quantization.quant_args import ( +from compressed_tensors.quantization import ( + ALL_QPARAM_NAMES, FP8_E4M3_DATA, ActivationOrdering, + KVCacheScaleType, QuantizationArgs, + QuantizationScheme, + QuantizationStatus, QuantizationStrategy, ) -from compressed_tensors.quantization.quant_config import QuantizationStatus -from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.lifecycle.forward import ( + wrap_module_forward_quantized, +) from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme from compressed_tensors.utils import ( delete_offload_parameter, @@ -44,31 +45,12 @@ __all__ = [ "initialize_module_for_quantization", "is_attention_module", - "KVCacheScaleType", - "ALL_QPARAM_KEYS", ] _LOGGER = logging.getLogger(__name__) -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" - - -ALL_QPARAM_KEYS = [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [ - f"{base_name}_{suffix}" - for base_name in ("input", "weight", "output") - for suffix in ( - "global_scale", - "scale", - "zero_point", - "g_idx", - ) -] - - def initialize_module_for_quantization( module: Module, scheme: Optional[QuantizationScheme] = None, @@ -159,7 +141,7 @@ def _clear_all_qparams( :param module: module to clear qparams from """ - for key in ALL_QPARAM_KEYS: + for key in ALL_QPARAM_NAMES: if hasattr(module, key): delete_offload_parameter(module, key) diff --git a/src/compressed_tensors/quantization/quant_names.py b/src/compressed_tensors/quantization/quant_names.py new file mode 100644 index 000000000..67e32345d --- /dev/null +++ 
b/src/compressed_tensors/quantization/quant_names.py @@ -0,0 +1,48 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + +__all__ = ["ALL_QPARAM_NAMES", "KVCacheScaleType"] + + +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + +ALL_QPARAM_NAMES = ( + [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + + [ + f"{base_name}_{suffix}" + for base_name in ("input", "weight", "output") + for suffix in ( + "global_scale", + "scale", + "zero_point", + "g_idx", + ) + ] + + [ + "weight_packed", + "weight_global_scale", + "weight_shape", + "scale_packed", + "meta", + "shape", + "compressed", + "bitmask", + "row_offsets", + ] +) From 02d5e782120da28475c6b289ed667f0599fde3cc Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 17 Sep 2025 20:38:49 +0000 Subject: [PATCH 31/35] stylefixes Signed-off-by: Brian Dellabetta --- src/compressed_tensors/quantization/quant_names.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compressed_tensors/quantization/quant_names.py b/src/compressed_tensors/quantization/quant_names.py index 67e32345d..1eb97c466 100644 --- a/src/compressed_tensors/quantization/quant_names.py +++ b/src/compressed_tensors/quantization/quant_names.py @@ -14,6 +14,7 @@ from enum import Enum + __all__ = ["ALL_QPARAM_NAMES", "KVCacheScaleType"] From 7d8c5a43c26984334c2b2eeec0c65963e9345fda Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 17 Sep 2025 22:14:23 +0000 Subject: [PATCH 32/35] exclude sparsity param names Signed-off-by: Brian Dellabetta --- src/compressed_tensors/quantization/quant_names.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_names.py b/src/compressed_tensors/quantization/quant_names.py index 1eb97c466..80fd49a91 100644 --- a/src/compressed_tensors/quantization/quant_names.py +++ b/src/compressed_tensors/quantization/quant_names.py @@ -39,11 +39,5 @@ class KVCacheScaleType(Enum): "weight_packed", "weight_global_scale", "weight_shape", - "scale_packed", - "meta", - "shape", - "compressed", - "bitmask", - "row_offsets", ] ) From 01af659a882989024cdbc5bf9dcff28fd82d3ba1 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 18 Sep 2025 16:45:36 +0000 Subject: [PATCH 33/35] QuantizationMetadata class Signed-off-by: Brian Dellabetta --- .../quantization/__init__.py | 2 +- .../quantization/lifecycle/initialize.py | 17 +---- .../quantization/quant_metadata.py | 62 +++++++++++++++++++ .../quantization/quant_names.py | 43 ------------- 4 files changed, 65 insertions(+), 59 deletions(-) create mode 100644 src/compressed_tensors/quantization/quant_metadata.py delete mode 100644 src/compressed_tensors/quantization/quant_names.py diff --git a/src/compressed_tensors/quantization/__init__.py b/src/compressed_tensors/quantization/__init__.py index 10c8975a8..04ccedf53 100644 --- a/src/compressed_tensors/quantization/__init__.py +++ b/src/compressed_tensors/quantization/__init__.py @@ -17,6 
+17,6 @@ from .quant_args import * from .quant_config import * -from .quant_names import * +from .quant_metadata import * from .quant_scheme import * from .lifecycle import * diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 58dffcf43..b4031ab49 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -20,11 +20,11 @@ import torch from compressed_tensors.quantization import ( - ALL_QPARAM_NAMES, FP8_E4M3_DATA, ActivationOrdering, KVCacheScaleType, QuantizationArgs, + QuantizationMetadata, QuantizationScheme, QuantizationStatus, QuantizationStrategy, @@ -76,7 +76,7 @@ def initialize_module_for_quantization( # no scheme passed and layer not targeted for quantization - skip return - _clear_all_qparams(module) + QuantizationMetadata.clear_all_qparams(module) if is_attention_module(module): # quantized actions based on calltime status @@ -133,19 +133,6 @@ def is_attention_module(module: Module): ) -def _clear_all_qparams( - module: Module, -): - """ - Clear all previously registered quantization parameters from module - - :param module: module to clear qparams from - """ - for key in ALL_QPARAM_NAMES: - if hasattr(module, key): - delete_offload_parameter(module, key) - - def _initialize_scale_zero_point( module: Module, base_name: str, diff --git a/src/compressed_tensors/quantization/quant_metadata.py b/src/compressed_tensors/quantization/quant_metadata.py new file mode 100644 index 000000000..e7567eabe --- /dev/null +++ b/src/compressed_tensors/quantization/quant_metadata.py @@ -0,0 +1,62 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from enum import Enum + +from compressed_tensors.utils import delete_offload_parameter +from torch.nn import Module + + +__all__ = ["QuantizationMetadata", "KVCacheScaleType"] + + +class KVCacheScaleType(Enum): + KEY = "k_scale" + VALUE = "v_scale" + + +class QuantizationMetadata: + """ + Container class for metadata related to quantization + """ + + @staticmethod + def all_qparam_names(): + """ + All quantization parameter names that might be registered + onto a module during lifecycle (excluding serialized parameters) + """ + return [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [ + f"{base_name}_{suffix}" + for base_name in ("input", "weight", "output") + for suffix in ( + "global_scale", + "scale", + "zero_point", + "g_idx", + ) + ] + + @classmethod + def clear_all_qparams(cls, module: Module): + """ + Remove all parameters related to quantization that might have + been registered onto a module previously in lifecycle (excluding + serialized parameters) + + :param module: Module to clear + """ + for key in cls.all_qparam_names(): + if hasattr(module, key): + delete_offload_parameter(module, key) diff --git a/src/compressed_tensors/quantization/quant_names.py b/src/compressed_tensors/quantization/quant_names.py deleted file mode 100644 index 80fd49a91..000000000 --- a/src/compressed_tensors/quantization/quant_names.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from enum import Enum - - -__all__ = ["ALL_QPARAM_NAMES", "KVCacheScaleType"] - - -class KVCacheScaleType(Enum): - KEY = "k_scale" - VALUE = "v_scale" - - -ALL_QPARAM_NAMES = ( - [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] - + [ - f"{base_name}_{suffix}" - for base_name in ("input", "weight", "output") - for suffix in ( - "global_scale", - "scale", - "zero_point", - "g_idx", - ) - ] - + [ - "weight_packed", - "weight_global_scale", - "weight_shape", - ] -) From 3fdd125fcb84db2ce32f1a2b39e5f444571240b5 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 18 Sep 2025 16:53:24 +0000 Subject: [PATCH 34/35] stylefix Signed-off-by: Brian Dellabetta --- src/compressed_tensors/quantization/lifecycle/initialize.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index b4031ab49..9f852c74f 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -34,7 +34,6 @@ ) from compressed_tensors.quantization.utils import is_fp4, is_kv_cache_quant_scheme from compressed_tensors.utils import ( - delete_offload_parameter, disable_hf_hook, get_execution_device, register_offload_parameter, From b789adf78b41f7358b13807d49a44d2f08912fef Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 18 Sep 2025 18:32:26 +0000 Subject: [PATCH 35/35] llm-compressor test fix Signed-off-by: Brian Dellabetta --- src/compressed_tensors/quantization/quant_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 994af336f..4478a2ae5 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -185,7 +185,8 @@ def from_pretrained( ignore[layer_type] = [] ignore[layer_type].append(name) else: - quantization_status = submodule.quantization_status + if hasattr(submodule, "quantization_status"): + quantization_status = submodule.quantization_status scheme = submodule.quantization_scheme quantization_type_names.add(layer_type)
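
Taken together, the changes above move re-initialization into apply_quantization_config itself: every matched submodule is (re)initialized on each call, and QuantizationMetadata.clear_all_qparams removes any scales, zero points, g_idx, and k/v-cache scales left behind by a previously applied config before the new scheme's parameters are registered. A minimal sketch of the resulting usage, mirroring the new test_multi_apply_quantization_config (the checkpoint name below is illustrative and not part of these patches; attn_config is a local helper for this sketch):

    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationConfig,
        QuantizationScheme,
        QuantizationStrategy,
        QuantizationType,
    )
    from compressed_tensors.quantization.lifecycle import apply_quantization_config
    from compressed_tensors.utils import match_named_modules
    from transformers import AutoModelForCausalLM

    # illustrative checkpoint; any small HF causal LM is enough for a smoke test
    model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")


    def attn_config(weights: QuantizationArgs) -> QuantizationConfig:
        # helper local to this sketch: one config group over the attention
        # projections, with lm_head ignored
        return QuantizationConfig(
            config_groups={
                "group_0": QuantizationScheme(
                    targets=[r"re:.*self_attn\.(k|q|o|v)_proj$"],
                    weights=weights,
                )
            },
            ignore=["lm_head"],
        )


    fp8_config = attn_config(
        QuantizationArgs(
            num_bits=8,
            type=QuantizationType.FLOAT,
            strategy=QuantizationStrategy.TENSOR,
            symmetric=True,
            dynamic=False,
        )
    )
    w4a16_config = attn_config(
        QuantizationArgs(
            num_bits=4,
            type=QuantizationType.INT,
            strategy=QuantizationStrategy.GROUP,
            group_size=128,
            symmetric=False,
            dynamic=False,
        )
    )

    # first application registers per-tensor FP8 weight scales / zero points
    apply_quantization_config(model, fp8_config)

    # second application clears the stale FP8 qparams on the same modules and
    # registers group-wise W4A16 parameters in their place
    apply_quantization_config(model, w4a16_config)

    for _name, module in match_named_modules(
        model, w4a16_config.config_groups["group_0"].targets, w4a16_config.ignore
    ):
        # group-wise scale: one column per 128-wide group of the weight matrix
        assert module.weight_scale.shape[-1] == module.weight.shape[-1] // 128

Re-applying a config therefore no longer stacks parameters from earlier schemes, which is the behavior the new test_multi_apply_quantization_config asserts.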