[Quantization] Support more than one quant-compressor #415
base: main
Changes from 9 commits
@@ -169,7 +169,7 @@ def from_pretrained_model(
         cls,
         model: Module,
         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
-        quantization_format: Optional[str] = None,
+        quantization_format: Optional[Union[str, List[str]]] = None,
     ) -> Optional["ModelCompressor"]:
         """
         Given a pytorch model and optional sparsity and/or quantization configs,
@@ -182,7 +182,21 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
         # reconstruct config from schemes attached to modules
+
+        if quantization_format is not None:
+            # llmcompressor incorrectly passes in a CompressionFormat when
+            # the value string is expected - handle both cases
+            if isinstance(quantization_format, (str, CompressionFormat)):
+                quantization_format = [quantization_format]
+
+            compression_formats = quantization_format
+            # assume multiple compression formats means mixed-precision
+            # as we currently only support one compressor per precision type and scheme
+            if len(quantization_format) > 1:
+                quantization_format = CompressionFormat.mixed_precision.value
+            else:
+                quantization_format = quantization_format[0]
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
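For clarity, here is a minimal standalone sketch of the normalization added above; the helper name `normalize_format` is hypothetical, and the real logic lives inline in `from_pretrained_model`:

```python
from typing import List, Optional, Union

from compressed_tensors.config import CompressionFormat


def normalize_format(
    quantization_format: Optional[Union[str, CompressionFormat, List[str]]],
):
    """Illustrative mirror of the branch above: returns the global config
    format plus the list of per-module compression formats."""
    if quantization_format is None:
        return None, None
    # a bare string or CompressionFormat becomes a single-element list
    if isinstance(quantization_format, (str, CompressionFormat)):
        quantization_format = [quantization_format]
    compression_formats = quantization_format
    # more than one format implies the new mixed-precision global tag
    if len(quantization_format) > 1:
        return CompressionFormat.mixed_precision.value, compression_formats
    return quantization_format[0], compression_formats


# e.g. two per-module formats collapse to the "mixed-precision" global tag:
# normalize_format(["pack-quantized", "float-quantized"])
# -> ("mixed-precision", ["pack-quantized", "float-quantized"])
```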
@@ -203,6 +217,7 @@ def from_pretrained_model(
             sparsity_config=sparsity_config,
             quantization_config=quantization_config,
             transform_config=transform_config,
+            compression_formats=compression_formats,
         )

     @staticmethod
@@ -263,30 +278,55 @@ def parse_quantization_config(

         return quantization_config

+    def _fetch_unique_quantization_formats(self) -> List[str]:
+        """
+        Get all unique compression formats present in a model
+        :return: list of quantization formats
+        """
+        quantization_formats = []
+        for _, scheme in self.quantization_config.config_groups.items():
+            if scheme.format is not None and scheme.format not in quantization_formats:
+                quantization_formats.append(scheme.format)
+
+        # If empty list, fallback to using the global format
+        if len(quantization_formats) == 0:
+            quantization_formats.append(self.quantization_config.format)
+
+        return quantization_formats
+
     def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
         transform_config: Optional[TransformConfig] = None,
+        compression_formats: Optional[List[str]] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
         self.transform_config = transform_config
+        self.compression_formats = compression_formats

         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
-            Union[BaseQuantizationCompressor, DenseCompressor]
+            Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
         ] = None
Review comment on lines 310 to 312: should we rename to …
         # no transform compressor is required

         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config
             )

         if quantization_config is not None:
-            self.quantization_compressor = BaseCompressor.load_from_registry(
-                quantization_config.format, config=quantization_config
-            )
+            if not self.compression_formats:
+                self.compression_formats = self._fetch_unique_quantization_formats()
+
+            self.quantization_compressor = {}
+            for format in self.compression_formats:
+                self.quantization_compressor[
+                    format
+                ] = BaseCompressor.load_from_registry(
+                    format, config=quantization_config
+                )

     # ----- used by hf quantizer ----- #

Review comment (truncated): Now that a compressor is local to a …
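To make the constructor behaviour concrete, a hedged usage sketch follows. It assumes this branch's `QuantizationScheme.format` field and the existing `pack-quantized` / `float-quantized` registry entries; the scheme values themselves are illustrative.

```python
from compressed_tensors.compressors import ModelCompressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
)

# two config groups, each pinned to a different per-scheme format
config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=4, type="int"),
            format=CompressionFormat.pack_quantized.value,
        ),
        "group_1": QuantizationScheme(
            targets=["Embedding"],
            weights=QuantizationArgs(num_bits=8, type="float"),
            format=CompressionFormat.float_quantized.value,
        ),
    },
    format=CompressionFormat.mixed_precision.value,
)

compressor = ModelCompressor(quantization_config=config)
# compression_formats was not passed, so it is inferred from the schemes
print(compressor.compression_formats)
# expected: ["pack-quantized", "float-quantized"]
print(sorted(compressor.quantization_compressor))
# expected: one registered compressor instance per format
```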
@@ -381,12 +421,13 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
                 targets=scheme.targets,
                 ignore=self.quantization_config.ignore,
             )
-            unexpected_keys.update(
-                merge_names(target, param)
-                for target in quant_targets
-                for param in self.quantization_compressor.compression_param_names
-                if param != "weight"
-            )
+            for quant_compressor in self.quantization_compressor.values():
+                unexpected_keys.update(
+                    merge_names(target, param)
+                    for target in quant_targets
+                    for param in quant_compressor.compression_param_names
+                    if param != "weight"
+                )

         return list(unexpected_keys)
@@ -424,7 +465,25 @@ def compress_model(self, model: Module):

                 # quantization first
                 if prefix in module_to_scheme:
-                    state_dict = self.quantization_compressor.compress(
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if (
+                            self.quantization_config.format
+                            == CompressionFormat.mixed_precision.value
+                        ):
+                            raise ValueError(
+                                "Compressing mixed-precision models without defining "
+                                "per module quantization_scheme.format is currently "
+                                "not supported"
+                            )
+                        format = self.quantization_config.format
+                    else:
+                        format = module.quantization_scheme.format
+
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.compress(
                         state_dict,
                         names_to_scheme=module_to_scheme,
                         show_progress=False,
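Factored out as a hypothetical helper (not part of this PR), the per-module format resolution above amounts to the following; `decompress_model` below applies the same rule before looking up the compressor in the dict.

```python
from compressed_tensors.config import CompressionFormat


def resolve_module_format(module, quantization_config) -> str:
    """Hypothetical helper mirroring the branch above: prefer the per-module
    scheme format, fall back to the global config format, and reject the
    ambiguous mixed-precision global tag."""
    scheme_format = getattr(module.quantization_scheme, "format", None)
    if scheme_format is not None:
        return scheme_format
    if quantization_config.format == CompressionFormat.mixed_precision.value:
        raise ValueError(
            "mixed-precision models must define quantization_scheme.format per module"
        )
    return quantization_config.format


# usage inside the loop would then be roughly:
# quant_compressor = self.quantization_compressor.get(
#     resolve_module_format(module, self.quantization_config)
# )
```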
@@ -495,12 +554,28 @@ def decompress_model(self, model: Module):

                 # quantization second
                 if prefix in module_to_scheme:
-                    state_dict = (
-                        self.quantization_compressor.decompress_module_from_state_dict(
-                            prefix,
-                            state_dict,
-                            scheme=module_to_scheme[prefix],
-                        )
-                    )
+                    if (
+                        not hasattr(module.quantization_scheme, "format")
+                        or module.quantization_scheme.format is None
+                    ):
+                        if (
+                            self.quantization_config.format
+                            == CompressionFormat.mixed_precision.value
+                        ):
+                            raise ValueError(
+                                "Decompressing mixed-precision models without defining "
+                                "per module quantization_scheme.format is currently not "
+                                "supported"
+                            )
+                        format = self.quantization_config.format
+                    else:
+                        format = module.quantization_scheme.format
+                    quant_compressor = self.quantization_compressor.get(format)
+                    state_dict = quant_compressor.decompress_module_from_state_dict(
+                        prefix,
+                        state_dict,
+                        scheme=module_to_scheme[prefix],
+                    )

                 # remove any existing parameters
@@ -539,7 +614,9 @@ def compress(

         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
-            state_dict = self.quantization_compressor.compress(
+            # Note - compress only supports one compression format atm
+            quant_compressor = next(iter(self.quantization_compressor.values()))
+            state_dict = quant_compressor.compress(
                 state_dict,
                 names_to_scheme=module_to_scheme,
                 show_progress=show_progress,
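Because `compress()` still selects a single compressor via `next(iter(...))`, every module goes through the first registered format on this path. A caller-side sketch of the resulting routing (an assumption about usage, not something the PR adds):

```python
# route mixed-precision models through compress_model(), which respects the
# per-module scheme formats; compress() applies only one compressor
formats = compressor.compression_formats or []
if len(formats) > 1:
    compressor.compress_model(model)
else:
    compressed_state_dict = compressor.compress(model, show_progress=True)
```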
@@ -588,14 +665,20 @@ def decompress(self, model_path: str, model: Module):
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
+        quant_compressor = (
+            next(iter(self.quantization_compressor.values()))
+            if self.quantization_compressor is not None
+            else None
+        )

         if (
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            # note - decompress only supports one compressor atm
             params_to_ignore = None
-            if self.quantization_compressor is not None:
-                params_to_ignore = self.quantization_compressor.compression_param_names
+            if quant_compressor is not None:
+                params_to_ignore = quant_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
             # The compressor will try and load any quantization parameters as well
             # params_to_skip_load will skip over quantization params from being loaded
@@ -606,7 +689,7 @@ def decompress(self, model_path: str, model: Module):
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True

-        if self.quantization_compressor is not None:
+        if quant_compressor is not None:
             # Temporarily set quantization status to FROZEN to prevent
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
@@ -629,15 +712,15 @@ def decompress(self, model_path: str, model: Module):
                 # including initialization
                 load_weight_quantization=(
                     sparse_decompressed
-                    or isinstance(self.quantization_compressor, DenseCompressor)
+                    or isinstance(quant_compressor, DenseCompressor)
                 ),
             )

             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
             )

-            dense_gen = self.quantization_compressor.decompress(
+            dense_gen = quant_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             # TODO: all weight quantization params will be moved to the compressor
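For orientation, here is a hand-written sketch (not an actual dump) of how the serialized quantization config could look for a two-format model under this design; all values are illustrative. The global `format` carries the mixed-precision tag, while each config group records its own per-scheme `format`, which is what the per-module compress/decompress paths above key on.

```python
# illustrative shape of the quantization_config written to config.json
quantization_config = {
    "format": "mixed-precision",
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "format": "pack-quantized",
            "weights": {"num_bits": 4, "type": "int", "symmetric": True},
        },
        "group_1": {
            "targets": ["Embedding"],
            "format": "float-quantized",
            "weights": {"num_bits": 8, "type": "float", "symmetric": True},
        },
    },
    "ignore": ["lm_head"],
}
```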
Review comment: afaict this is the only entrypoint for this function. Why not just adjust the upstream function infer_quantization_format to infer the mixed value, rather than supporting an extra data type (List[str]) which ideally should never actually appear?
Review comment: I agree with @kylesayrs on this. Also, if a list of quantization formats is passed in, we override it to the mixed-precision format and then infer the formats again downstream?
Reply: I disagree; separation of concerns. infer_quantization_format is responsible for inferring the formats in the model, but what gets written to the config should be determined by the ModelCompressor class, which is ultimately responsible for writing the quantization config.
We don't infer again - we use the per-module format attached to each scheme to compress each module.
See the updated llmcompressor functionality: vllm-project/llm-compressor#1713
Review comment: Afaict the only reason why we would need to infer the list of quantization formats used in a model is to write them to the config. Since model_compressor is responsible for writing to the config, I would argue that the "infer the global quantization tag for the purposes of writing to the config" logic should exist in model compressor.
Review comment: If we are going to pass all available formats, why are we then re-inferring them afterwards via _fetch_unique_quantization_formats? This seems like a potential conflict in source of truth. Ideally scheme.format should be the source of truth for formats.