Commit 7801f00

Composability (#219)
* Add: Support for targets and ignore in SparseCompressors
  Enable: Operations on state_dict to allow composability
  Add: Composability for compress/decompress pathways
  Update: Typing for a few methods
  Add: Composability Test
  Add: Some testing utils
* Add: FP8 Test for composability
* Review Comments!
* More review comments from @dsikka
* Fix failing tests
* Rename is_target to is_sparse_target
  Update _replace_weight to work with updates from `85b473e`
  Add docstring to _replace_weights
  Update failing test
* review comments from @kylesayrs
1 parent fe4a442 commit 7801f00
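A minimal, hedged sketch of how the composed pathway introduced here might be driven end to end. The checkpoint name is a placeholder, and the exact import path and call signatures should be checked against the diffs below rather than taken from this sketch.

# Illustrative only: assumes a checkpoint whose config carries both a sparsity
# config and a quantization config; "org/sparse-quantized-model" is hypothetical.
from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

stub = "org/sparse-quantized-model"  # hypothetical checkpoint
model = AutoModelForCausalLM.from_pretrained(stub)
compressor = ModelCompressor.from_pretrained(stub)

# compress(): sparse targets are expanded from the sparsity config, and only those
# layers are routed through the sparsity compressor; other weights pass through.
compressed_state_dict = compressor.compress(model)

# decompress(): sparse weights are decompressed from disk first, then the
# quantization compressor decompresses from the now-dense in-memory state dict.
compressor.decompress(model_path=stub, model=model)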

File tree

14 files changed: +694 −54 lines changed


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 69 additions & 10 deletions
@@ -17,8 +17,9 @@
 import operator
 import os
 import re
+from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -38,6 +39,7 @@
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
@@ -104,7 +106,6 @@ def from_pretrained(
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
-
         return cls.from_compression_config(compression_config)
 
     @classmethod
@@ -282,8 +283,14 @@ def compress(
         )
 
         if self.sparsity_compressor is not None:
+            sparse_compression_targets: Set[str] = expand_sparse_target_names(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict,
+                compression_targets=sparse_compression_targets,
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
@@ -301,23 +308,41 @@ def decompress(self, model_path: str, model: Module):
         :param model: pytorch model to load decompressed weights into
         """
         model_path = get_safetensors_folder(model_path)
+        sparse_decompressed = False
+
         if self.sparsity_compressor is not None:
+            # Sparse decompression is applied on the model_path
             dense_gen = self.sparsity_compressor.decompress(model_path)
             self._replace_weights(dense_gen, model)
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+            sparse_decompressed = True
 
         if self.quantization_compressor is not None:
-            names_to_scheme = apply_quantization_config(model, self.quantization_config)
-            load_pretrained_quantization(model, model_path)
+            # Temporarily set quantization status to FROZEN to prevent
+            # quantization during apply_quantization_config. This ensures
+            # that the dtypes of the weights are not unintentionally updated.
+            # The status is restored after quantization params are loaded.
+            with override_quantization_status(
+                self.quantization_config, QuantizationStatus.FROZEN
+            ):
+                names_to_scheme = apply_quantization_config(
+                    model, self.quantization_config
+                )
+                load_pretrained_quantization(model, model_path)
+
+            model_path_or_state_dict = (
+                model.state_dict() if sparse_decompressed else model_path
+            )
+
             dense_gen = self.quantization_compressor.decompress(
-                model_path, names_to_scheme=names_to_scheme
+                model_path_or_state_dict, names_to_scheme=names_to_scheme
            )
             self._replace_weights(dense_gen, model)
 
-            def update_status(module):
+            def freeze_quantization_status(module):
                 module.quantization_status = QuantizationStatus.FROZEN
 
-            model.apply(update_status)
+            model.apply(freeze_quantization_status)
             setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
 
     def update_config(self, save_directory: str):
@@ -367,12 +392,26 @@ def update_config(self, save_directory: str):
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_weights(self, dense_weight_generator, model):
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the parameter name and
+            'data' is the updated param data
+        :param model: The model whose weights are to be updated.
+        """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            update_parameter_data(module, data, param_name)
+            if hasattr(module, param_name):
+                update_parameter_data(module, data, param_name)
 
 
 def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
@@ -402,3 +441,23 @@ def new_dtype_byte_size(dtype):
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
     return bit_size // 8
+
+
+@contextmanager
+def override_quantization_status(
+    config: QuantizationConfig, status: QuantizationStatus
+):
+    """
+    Within this context, the quantization status will be set to the
+    supplied status. After the context exits, the original status
+    will be restored.
+
+    :param config: the quantization config to override
+    :param status: the status to temporarily set
+    """
+    original_status = config.quantization_status
+    config.quantization_status = status
+    try:
+        yield
+    finally:
+        config.quantization_status = original_status
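A quick illustration of the new context manager in isolation. This is a hedged sketch: the QuantizationConfig construction (an empty config_groups dict) is illustrative, and the import path simply mirrors the file shown above.

# Sketch: the context manager swaps config.quantization_status and restores it on exit.
from compressed_tensors.compressors.model_compressors.model_compressor import (
    override_quantization_status,
)
from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus

config = QuantizationConfig(
    config_groups={},  # illustrative, minimal config
    quantization_status=QuantizationStatus.COMPRESSED,
)

with override_quantization_status(config, QuantizationStatus.FROZEN):
    # apply_quantization_config sees FROZEN here, so weight dtypes are not touched
    assert config.quantization_status == QuantizationStatus.FROZEN

assert config.quantization_status == QuantizationStatus.COMPRESSED  # restored on exit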

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 35 additions & 5 deletions
@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from pathlib import Path
+from typing import Any, Dict, Generator, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import QuantizationArgs
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -113,30 +118,55 @@ def compress(
 
     def decompress(
         self,
-        path_to_model_or_tensors: str,
+        path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
         names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
-
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
         :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
+        if isinstance(path_to_model_or_tensors, (str, Path)):
+            yield from self._decompress_from_path(
+                path_to_model_or_tensors, names_to_scheme, device
+            )
+
+        else:
+            yield from self._decompress_from_state_dict(
+                path_to_model_or_tensors, names_to_scheme
+            )
+
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
         weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+            path_to_model, self.COMPRESSION_PARAM_NAMES
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
             for param_name, safe_path in weight_mappings[weight_name].items():
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
+            if "weight_scale" in weight_data:
+                quant_args = names_to_scheme[weight_name]
+                decompressed = self.decompress_weight(
+                    compressed_data=weight_data, quantization_args=quant_args
+                )
+                yield merge_names(weight_name, "weight"), decompressed
+
+    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+        weight_mappings = get_nested_mappings_from_state_dict(
+            state_dict, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[weight_name].items():
+                weight_data[param_name] = param_value
 
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[weight_name]
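The reason decompress now accepts a state dict is the composed flow in ModelCompressor.decompress: after sparse decompression the dense weights already live on the model, so the quantized compressor reads them from memory instead of from disk. A simplified, hedged sketch of that flow follows; the function name and the `...` bodies are placeholders, not the actual implementation.

# Hedged sketch of the composed decompression order, not the real ModelCompressor code.
def composed_decompress(model, model_path, sparsity_compressor, quantization_compressor, names_to_scheme):
    sparse_decompressed = False

    if sparsity_compressor is not None:
        for name, tensor in sparsity_compressor.decompress(model_path):
            ...  # write the dense weight back onto `model`
        sparse_decompressed = True

    if quantization_compressor is not None:
        # Either a safetensors directory (str/Path) or an in-memory state dict;
        # decompress() dispatches on the input type, as shown in the diff above.
        source = model.state_dict() if sparse_decompressed else model_path
        for name, tensor in quantization_compressor.decompress(source, names_to_scheme=names_to_scheme):
            ...  # write the dequantized weight back onto `model`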

src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py

Lines changed: 2 additions & 2 deletions
@@ -68,19 +68,19 @@ def compress_weight(
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
         Compresses a single uncompressed weight
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 2 additions & 2 deletions
@@ -68,19 +68,19 @@ def compress_weight(
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
         Compresses a single uncompressed weight
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 45 additions & 7 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@
 class BaseSparseCompressor(BaseCompressor):
     """
     Base class representing a sparse compression algorithm. Each child class should
-    implement compression_param_info, compress_weight and decompress_weight.
+    implement compression_param_info, compress_weight and decompress_weight; child
+    classes should also define COMPRESSION_PARAM_NAMES.
 
     Compressors support compressing/decompressing a full module state dict or a single
     quantized PyTorch leaf module.
@@ -59,19 +60,32 @@ class BaseSparseCompressor(BaseCompressor):
     :param config: config specifying compression parameters
     """
 
-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        compression_targets: Optional[Set[str]] = None,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
 
         :param model_state: state dict of uncompressed model
+        :param compression_targets: optional set of layer prefixes to compress,
+            otherwise compress all layers (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
         _LOGGER.debug(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            compression_data = self.compress_weight(name, value)
+            if not self.should_compress(name, compression_targets):
+                compressed_dict[name] = value
+                continue
+            prefix = name
+            if prefix.endswith(".weight"):
+                prefix = prefix[: -(len(".weight"))]
+
+            compression_data = self.compress_weight(prefix, value)
             for key in compression_data.keys():
                 if key in compressed_dict:
                     _LOGGER.warn(
@@ -97,8 +111,10 @@ def decompress(
         :param device: device to load decompressed weights onto
         :return: iterator for generating decompressed weights
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        weight_mappings, ignored_params = get_nested_weight_mappings(
+            path_to_model_or_tensors,
+            self.COMPRESSION_PARAM_NAMES,
+            return_unmatched_params=True,
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -107,4 +123,26 @@
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             decompressed = self.decompress_weight(weight_data)
-            yield weight_name, decompressed
+            yield merge_names(weight_name, "weight"), decompressed
+
+        for ignored_param_name, safe_path in ignored_params.items():
+            with safe_open(safe_path, framework="pt", device=device) as f:
+                value = f.get_tensor(ignored_param_name)
+            yield ignored_param_name, value
+
+    @staticmethod
+    def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
+        """
+        Check if a parameter should be compressed.
+        Currently, this only returns True for weight parameters.
+
+        :param name: name of the parameter
+        :param expanded_targets: set of layer prefixes to compress
+        :return: whether or not the parameter should be compressed
+        """
+        if expanded_targets is None:
+            return name.endswith(".weight")
+
+        return (
+            name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
+        )
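Since should_compress is a pure static helper, its behavior with and without an expanded target set can be sketched directly. The module path below just mirrors the file header above, and the layer names are made up.

from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor

targets = {"model.layers.0.mlp.down_proj"}  # e.g. the output of expand_sparse_target_names

# Only ".weight" parameters whose module prefix is in the target set are compressed.
assert BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.weight", targets)
assert not BaseSparseCompressor.should_compress("model.layers.0.mlp.down_proj.bias", targets)
assert not BaseSparseCompressor.should_compress("model.layers.1.mlp.down_proj.weight", targets)

# With no target set, every ".weight" parameter is considered compressible (old behavior).
assert BaseSparseCompressor.should_compress("model.layers.1.mlp.down_proj.weight")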

src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py

Lines changed: 8 additions & 2 deletions
@@ -19,6 +19,7 @@
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import FP8_DTYPE
 from compressed_tensors.utils import merge_names
 from torch import Tensor
 
@@ -134,9 +135,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     bytemasks = tensor != 0
     row_counts = bytemasks.sum(dim=-1)
     row_offsets = torch.cumsum(row_counts, 0) - row_counts
-    values = tensor[bytemasks]
+    if tensor.dtype == FP8_DTYPE:
+        # acces raw bytes of the tensor
+        tensor_view = tensor.view(torch.int8)
+        values = tensor_view[bytemasks]
+        values = values.view(FP8_DTYPE)
+    else:
+        values = tensor[bytemasks]
     bitmasks_packed = pack_bitmasks(bytemasks)
-
     return values, bitmasks_packed, row_offsets
 
 
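The FP8 branch works around limited float8 operator support in torch by masking an int8 byte view of the tensor rather than the float8 values themselves. A small hedged sketch of the same trick; it assumes a torch build with float8_e4m3fn, and the mask here is computed on a float32 copy purely to keep the sketch portable.

import torch

fp8 = torch.float8_e4m3fn  # FP8_DTYPE in compressed-tensors
dense = torch.tensor([0.0, 1.5, 0.0, -2.0]).to(fp8)

bytemask = dense.to(torch.float32) != 0    # positions of nonzero values
values = dense.view(torch.int8)[bytemask]  # mask the raw bytes instead of the fp8 values
values = values.view(fp8)                  # reinterpret the kept bytes as fp8 again

print(values.to(torch.float32))            # tensor([ 1.5000, -2.0000])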