Skip to content

Commit a26c03a

Browse files
authored
Allow ModelCompressor.from_pretrained to load from quantization_config, not compression config (#207)
1 parent 7103a27 commit a26c03a

File tree

4 files changed

+24
-9
lines changed

4 files changed

+24
-9
lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import torch
2525
import transformers
2626
from compressed_tensors.base import (
27-
COMPRESSION_CONFIG_NAME,
2827
COMPRESSION_VERSION_NAME,
2928
QUANTIZATION_CONFIG_NAME,
3029
QUANTIZATION_METHOD_NAME,
@@ -39,6 +38,7 @@
3938
apply_quantization_config,
4039
load_pretrained_quantization,
4140
)
41+
from compressed_tensors.quantization.quant_args import QuantizationArgs
4242
from compressed_tensors.quantization.utils import (
4343
is_module_quantized,
4444
iter_named_leaf_modules,
@@ -103,12 +103,14 @@ def from_pretrained(
103103
:return: compressor for the configs, or None if model is not compressed
104104
"""
105105
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
106-
compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
106+
compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
107+
107108
return cls.from_compression_config(compression_config)
108109

109110
@classmethod
110111
def from_compression_config(
111-
cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
112+
cls,
113+
compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
112114
):
113115
"""
114116
:param compression_config:
@@ -265,7 +267,11 @@ def compress(
265267
state_dict = model.state_dict()
266268

267269
compressed_state_dict = state_dict
268-
quantized_modules_to_args = map_modules_to_quant_args(model)
270+
271+
quantized_modules_to_args: Dict[
272+
str, QuantizationArgs
273+
] = map_modules_to_quant_args(model)
274+
269275
if self.quantization_compressor is not None:
270276
compressed_state_dict = self.quantization_compressor.compress(
271277
state_dict, names_to_scheme=quantized_modules_to_args
@@ -369,7 +375,13 @@ def _replace_weights(self, dense_weight_generator, model):
369375
update_parameter_data(module, data, param_name)
370376

371377

372-
def map_modules_to_quant_args(model: Module) -> Dict:
378+
def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
379+
"""
380+
Given a pytorch model, map out the submodule name (usually linear layers)
381+
to the QuantizationArgs
382+
383+
:param model: pytorch model
384+
"""
373385
quantized_modules_to_args = {}
374386
for name, submodule in iter_named_leaf_modules(model):
375387
if is_module_quantized(submodule):

src/compressed_tensors/linear/compressed_linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import Dict, Tuple
16+
1517
import torch
1618
from compressed_tensors.compressors.base import BaseCompressor
1719
from compressed_tensors.quantization import (
@@ -53,7 +55,7 @@ def from_linear(
5355
)
5456

5557
# get the shape and dtype of compressed parameters
56-
compression_params = module.compressor.compression_param_info(
58+
compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
5759
module.weight.shape, quantization_scheme.weights
5860
)
5961

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def apply_quantization_config(
106106
model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
107107
) -> OrderedDict:
108108
"""
109-
Initializes the model for quantization in-place based on the given config
109+
Initializes the model for quantization in-place based on the given config.
110+
Optionally converts quantizable modules to compressed_linear modules
110111
111112
:param model: model to apply quantization config to
112113
:param config: quantization config

src/compressed_tensors/quantization/quant_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,9 @@ class QuantizationConfig(BaseModel):
132132
`k_proj` and `v_proj` in their names. If this is not the case
133133
and kv_cache_scheme != None, the quantization of kv cache will fail
134134
:global_compression_ratio: optional informational config to report the model
135-
compression ratio acheived by the quantization config
135+
compression ratio achieved by the quantization config
136136
:ignore: optional list of layers to ignore from config_groups. Layers in this list
137-
are not quantized even if they match up with a target in config_groups
137+
are not quantized even if they match up with a target in config_groups
138138
"""
139139

140140
config_groups: Dict[str, Union[QuantizationScheme, List[str]]]

0 commit comments

Comments (0)