[Quantization] Support more than one quant-compressor #415

base: main
```diff
@@ -169,7 +169,7 @@ def from_pretrained_model(
         cls,
         model: Module,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[str] = None,
+        quantization_format: Optional[Union[str, List[str]]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
```
```diff
@@ -182,7 +182,22 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
         # reconstruct config from schemes attached to modules
+        compression_formats = None
+        if quantization_format is not None:
+            # llmcompressor incorrectly passes in a CompressionFormat when
+            # the value string is expected - handle both cases
+            if isinstance(quantization_format, (str, CompressionFormat)):
+                quantization_format = [quantization_format]
+
+            compression_formats = quantization_format
+
+            # assume multiple compression formats means mixed-precision
+            # as we currently only support one compressor per precision type and scheme
+            if len(quantization_format) > 1:
+                quantization_format = CompressionFormat.mixed_precision.value
+            else:
+                quantization_format = quantization_format[0]
+
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
```

Review comment (on `compression_formats = quantization_format`): FYI this parsing logic is duplicated in …
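Read in isolation, the normalization added above behaves like the following self-contained sketch. The `CompressionFormat` stub and the format strings in the asserts are illustrative stand-ins, not the library's definitions:

```python
from enum import Enum
from typing import List, Optional, Tuple, Union


class CompressionFormat(Enum):
    # stand-in for compressed-tensors' CompressionFormat; only the member
    # this sketch needs is included
    mixed_precision = "mixed-precision"


def normalize_formats(
    quantization_format: Optional[Union[str, List[str]]],
) -> Tuple[Optional[List[str]], Optional[str]]:
    """Return (compression_formats, global_format) the way the new code does."""
    if quantization_format is None:
        return None, None
    # a bare string (or enum member) is treated as a single-element list
    if isinstance(quantization_format, (str, CompressionFormat)):
        quantization_format = [quantization_format]
    compression_formats = quantization_format
    # more than one format collapses the global format to the
    # mixed-precision sentinel
    if len(quantization_format) > 1:
        return compression_formats, CompressionFormat.mixed_precision.value
    return compression_formats, quantization_format[0]


assert normalize_formats("pack-quantized") == (["pack-quantized"], "pack-quantized")
assert normalize_formats(["a", "b"]) == (["a", "b"], "mixed-precision")
```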
```diff
@@ -203,6 +218,7 @@ def from_pretrained_model(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
+            compression_formats=compression_formats,
         )

     @staticmethod
```
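With that plumbing in place, callers can hand `from_pretrained_model` either shape. A hypothetical, untested call site (the placeholder `model` stands in for a real quantized model with schemes attached):

```python
import torch
from compressed_tensors.compressors import ModelCompressor

model = torch.nn.Linear(4, 4)  # placeholder only; a real quantized model is assumed

# single format, as before
compressor = ModelCompressor.from_pretrained_model(
    model, quantization_format="pack-quantized"
)

# NEW: one format per precision type; the global config format becomes
# mixed-precision and one compressor is built per entry
compressor = ModelCompressor.from_pretrained_model(
    model, quantization_format=["pack-quantized", "nvfp4-pack-quantized"]
)
```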
```diff
@@ -263,30 +279,55 @@ def parse_quantization_config(

         return quantization_config

+    def _fetch_unique_quantization_formats(self) -> List[str]:
+        """
+        Get all unique compression formats present in a model
+        :return: list of quantization formats
+        """
+        quantization_formats = []
+        for _, scheme in self.quantization_config.config_groups.items():
+            if scheme.format is not None and scheme.format not in quantization_formats:
+                quantization_formats.append(scheme.format)
+
+        # If empty list, fallback to using the global format
+        if len(quantization_formats) == 0:
+            quantization_formats.append(self.quantization_config.format)
+
+        return quantization_formats
+
     def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
         transform_config: Optional[TransformConfig] = None,
+        compression_formats: Optional[List[str]] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.transform_config = transform_config
+        self.compression_formats = compression_formats

         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
-            Union[BaseQuantizationCompressor, DenseCompressor]
+            Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
         ] = None

         # no transform compressor is required

         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config
             )

         if quantization_config is not None:
-            self.quantization_compressor = BaseCompressor.load_from_registry(
-                quantization_config.format, config=quantization_config
-            )
+            if not self.compression_formats:
+                self.compression_formats = self._fetch_unique_quantization_formats()
+
+            self.quantization_compressor = {}
+            for format in self.compression_formats:
+                self.quantization_compressor[
+                    format
+                ] = BaseCompressor.load_from_registry(
+                    format, config=quantization_config
+                )

     # ----- used by hf quantizer ----- #
```

Review comment (on the `Dict[str, ...]` annotation): is there a reason this can't be renamed to indicate it is a map of compressors instead of a single compressor?
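The helper's dedup-with-fallback behavior is easy to pin down with a toy version (names and format strings here are illustrative only):

```python
from typing import List, Optional


def unique_formats(scheme_formats: List[Optional[str]], global_format: str) -> List[str]:
    """Collect each distinct per-scheme format, preserving first-seen order;
    fall back to the config-level format when no scheme declares one."""
    formats: List[str] = []
    for fmt in scheme_formats:
        if fmt is not None and fmt not in formats:
            formats.append(fmt)
    return formats or [global_format]


assert unique_formats(["pack-quantized", None, "pack-quantized"], "dense") == ["pack-quantized"]
assert unique_formats([None, None], "float-quantized") == ["float-quantized"]
```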
```diff
@@ -381,12 +422,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
                 targets=scheme.targets,
                 ignore=self.quantization_config.ignore,
             )
-            unexpected_keys.update(
-                merge_names(target, param)
-                for target in quant_targets
-                for param in self.quantization_compressor.compression_param_names
-                if param != "weight"
-            )
+            for quant_compressor in self.quantization_compressor.values():
+                unexpected_keys.update(
+                    merge_names(target, param)
+                    for target in quant_targets
+                    for param in quant_compressor.compression_param_names
+                    if param != "weight"
+                )

         return list(unexpected_keys)
```
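The loop change just widens the union to every registered compressor's auxiliary parameter names. Schematically, with made-up format keys and parameter names:

```python
# each compressor contributes its non-"weight" compression parameter names
compression_param_names = {
    "fmt-a": ["weight", "weight_scale", "weight_zero_point"],
    "fmt-b": ["weight", "weight_scale", "weight_global_scale"],
}

unexpected = set()
for param_names in compression_param_names.values():
    unexpected.update(p for p in param_names if p != "weight")

assert unexpected == {"weight_scale", "weight_zero_point", "weight_global_scale"}
```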
```diff
@@ -424,7 +466,25 @@ def compress_model(self, model: Module):

                 # quantization first
                 if prefix in module_to_scheme:
-                    state_dict = self.quantization_compressor.compress(
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if (
+                            self.quantization_config.format
+                            == CompressionFormat.mixed_precision.value
+                        ):
+                            raise ValueError(
+                                "Compressing mixed-precision models without defining "
+                                "per module quantization_scheme.format is currently "
+                                "not supported"
+                            )
+                        format = self.quantization_config.format
+                    else:
+                        format = module.quantization_scheme.format
+
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.compress(
                         state_dict,
                         names_to_scheme=module_to_scheme,
                         show_progress=False,
```
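The per-module lookup boils down to a small precedence rule, sketched here standalone with a plain string in place of the enum value:

```python
from typing import Optional

MIXED_PRECISION = "mixed-precision"  # stand-in for CompressionFormat.mixed_precision.value


def resolve_format(module_format: Optional[str], global_format: str) -> str:
    """Per-module scheme format wins; otherwise fall back to the global
    format, which must not be the mixed-precision sentinel since it
    names no concrete compressor."""
    if module_format is not None:
        return module_format
    if global_format == MIXED_PRECISION:
        raise ValueError(
            "mixed-precision models must define quantization_scheme.format per module"
        )
    return global_format


assert resolve_format("pack-quantized", MIXED_PRECISION) == "pack-quantized"
assert resolve_format(None, "float-quantized") == "float-quantized"
```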
```diff
@@ -495,12 +555,28 @@ def decompress_model(self, model: Module):

                 # quantization second
                 if prefix in module_to_scheme:
-                    state_dict = (
-                        self.quantization_compressor.decompress_module_from_state_dict(
-                            prefix,
-                            state_dict,
-                            scheme=module_to_scheme[prefix],
-                        )
-                    )
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if (
+                            self.quantization_config.format
+                            == CompressionFormat.mixed_precision.value
+                        ):
+                            raise ValueError(
+                                "Decompressing mixed-precision models without defining "
+                                "per module quantization_scheme.format is currently not "
+                                "supported"
+                            )
+                        format = self.quantization_config.format
+                    else:
+                        format = module.quantization_scheme.format
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.decompress_module_from_state_dict(
+                        prefix,
+                        state_dict,
+                        scheme=module_to_scheme[prefix],
+                    )

                 # remove any existing parameters
```
```diff
@@ -539,7 +615,9 @@ def compress(

         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
-            state_dict = self.quantization_compressor.compress(
+            # Note - compress only supports one compression format atm
+            quant_compressor = next(iter(self.quantization_compressor.values()))
+            state_dict = quant_compressor.compress(
                 state_dict,
                 names_to_scheme=module_to_scheme,
                 show_progress=show_progress,
```
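The `next(iter(...))` idiom selects the first registered compressor without building a list; since Python 3.7, dicts preserve insertion order, so the choice is deterministic. A quick illustration with placeholder values:

```python
compressors = {"fmt-a": "compressor-a", "fmt-b": "compressor-b"}

# first value in insertion order
first = next(iter(compressors.values()))
assert first == "compressor-a"
```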
```diff
@@ -588,14 +666,20 @@ def decompress(self, model_path: str, model: Module):
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
+        quant_compressor = (
+            next(iter(self.quantization_compressor.values()))
+            if self.quantization_compressor is not None
+            else None
+        )

         if (
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            # note - decompress only supports one compressor atm
             params_to_ignore = None
-            if self.quantization_compressor is not None:
-                params_to_ignore = self.quantization_compressor.compression_param_names
+            if quant_compressor is not None:
+                params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well
             # params_to_skip_load will skip over quantization params from being loaded
```
```diff
@@ -606,7 +690,7 @@ def decompress(self, model_path: str, model: Module):
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True

-        if self.quantization_compressor is not None:
+        if quant_compressor is not None:
             # Temporarily set quantization status to FROZEN to prevent
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
```
|
@@ -629,15 +713,15 @@ def decompress(self, model_path: str, model: Module): | |
# including initialization | ||
load_weight_quantization=( | ||
sparse_decompressed | ||
or isinstance(self.quantization_compressor, DenseCompressor) | ||
or isinstance(quant_compressor, DenseCompressor) | ||
), | ||
) | ||
|
||
model_path_or_state_dict = ( | ||
model.state_dict() if sparse_decompressed else model_path | ||
) | ||
|
||
dense_gen = self.quantization_compressor.decompress( | ||
dense_gen = quant_compressor.decompress( | ||
model_path_or_state_dict, names_to_scheme=names_to_scheme | ||
) | ||
# TODO: all weight quantization params will be moved to the compressor | ||
|