
Commit 4571265

Merge branch 'kylesayrs/serialize-tconfig' into kylesayrs/transform-merge
2 parents: a3cd59d + 9d6a127

12 files changed: +204 −126 lines


src/compressed_tensors/base.py

Lines changed: 8 additions & 3 deletions

@@ -12,9 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-SPARSITY_CONFIG_NAME = "sparsity_config"
+# configs
 QUANTIZATION_CONFIG_NAME = "quantization_config"
-COMPRESSION_CONFIG_NAME = "compression_config"
-KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
+SPARSITY_CONFIG_NAME = "sparsity_config"
+TRANSFORM_CONFIG_NAME = "transform_config"
+
+# required fields
 COMPRESSION_VERSION_NAME = "version"
 QUANTIZATION_METHOD_NAME = "quant_method"
+
+# auxiliary configs
+KV_CACHE_SCHEME_NAME = "kv_cache_scheme"

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 55 additions & 33 deletions

@@ -29,6 +29,7 @@
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
+    TRANSFORM_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors import DenseCompressor
@@ -43,6 +44,7 @@
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
+from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
@@ -105,6 +107,7 @@ class ModelCompressor:
 
     sparsity_config: Optional[SparsityCompressionConfig] = None
     quantization_config: Optional[QuantizationConfig] = None
+    transform_config: Optional[TransformConfig] = None
 
     @classmethod
     def from_pretrained(
@@ -144,6 +147,8 @@ def from_compression_config(
 
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
+        # NOTE: transform config is not supported by ctconfig yet
+
         if sparsity_config is None and quantization_config is None:
             return None
 
@@ -177,20 +182,27 @@ def from_pretrained_model(
             algorithm
         :return: compressor for the configs, or None if model is not compressed
         """
+        # reconstruct config from schemes attached to modules
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
         )
 
+        # use config passed as argument
         if isinstance(sparsity_config, str):  # we passed in a sparsity format
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 sparsity_config
             )
 
-        if sparsity_config is None and quantization_config is None:
+        # use config attached to model
+        transform_config = getattr(model, TRANSFORM_CONFIG_NAME, None)
+
+        if not any((quantization_config, sparsity_config, transform_config)):
             return None
 
         return cls(
-            sparsity_config=sparsity_config, quantization_config=quantization_config
+            sparsity_config=sparsity_config,
+            quantization_config=quantization_config,
+            transform_config=transform_config,
         )
 
     @staticmethod
@@ -254,13 +266,17 @@ def __init__(
         self,
         sparsity_config: Optional[SparsityCompressionConfig] = None,
         quantization_config: Optional[QuantizationConfig] = None,
+        transform_config: Optional[TransformConfig] = None,
     ):
         self.sparsity_config = sparsity_config
         self.quantization_config = quantization_config
+        self.transform_config = transform_config
+
         self.sparsity_compressor = None
         self.quantization_compressor: Optional[
             Union[BaseQuantizationCompressor, DenseCompressor]
         ] = None
+        # no transform compressor is required
 
         if sparsity_config is not None:
             self.sparsity_compressor = BaseCompressor.load_from_registry(
@@ -640,43 +656,49 @@ def update_config(self, save_directory: str):
 
         :param save_directory: path to a folder containing a HF model config
         """
-        if self.quantization_config is None and self.sparsity_config is None:
+        # this check is also done in `from_pretrained_model`,
+        # but not in `from_pretrained` or `from_compression_config`
+        if not any(
+            (self.quantization_config, self.sparsity_config, self.transform_config)
+        ):
             return
 
+        # write to config.json file, regardless of whether it exists already
+        # overwrite previous config and version if already existing
         config_file_path = os.path.join(save_directory, CONFIG_NAME)
-        if not os.path.exists(config_file_path):
-            _LOGGER.warning(
-                f"Could not find a valid model config file in "
-                f"{save_directory}. Compression config will not be saved."
-            )
-            return
+        if os.path.exists(config_file_path):
+            with open(config_file_path, "r") as file:
+                config_data = json.load(file)
+        else:
+            config_data = {}
 
-        with open(config_file_path, "r") as config_file:
-            config_data = json.load(config_file)
+        # serialize configs into json
+        qconfig_data = (
+            self.quantization_config.model_dump(exclude=["quant_method", "format"])
+            if self.quantization_config is not None
+            else {}
+        )
+        sconfig_data = (
+            self.sparsity_config.model_dump()
+            if self.sparsity_config is not None
+            else {}
+        )
+        tconfig_data = (
+            self.transform_config.model_dump()
+            if self.transform_config is not None
+            else {}
+        )
 
-        # required metadata whenever a quantization or sparsity config is present
-        # overwrite previous config and version if already existing
-        config_data[QUANTIZATION_CONFIG_NAME] = {}
-        config_data[QUANTIZATION_CONFIG_NAME][
-            COMPRESSION_VERSION_NAME
-        ] = compressed_tensors.__version__
-        if self.quantization_config is not None:
-            self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
-        else:
-            config_data[QUANTIZATION_CONFIG_NAME][
-                QUANTIZATION_METHOD_NAME
-            ] = DEFAULT_QUANTIZATION_METHOD
-
-        # quantization and sparsity configs
-        if self.quantization_config is not None:
-            quant_config_data = self.quantization_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
-        if self.sparsity_config is not None:
-            sparsity_config_data = self.sparsity_config.model_dump()
-            config_data[QUANTIZATION_CONFIG_NAME][
-                SPARSITY_CONFIG_NAME
-            ] = sparsity_config_data
+        # construct compression (quantization) config
+        config_data[QUANTIZATION_CONFIG_NAME] = {
+            COMPRESSION_VERSION_NAME: compressed_tensors.__version__,
+            QUANTIZATION_METHOD_NAME: DEFAULT_QUANTIZATION_METHOD,
+            SPARSITY_CONFIG_NAME: sconfig_data,
+            TRANSFORM_CONFIG_NAME: tconfig_data,
+            **qconfig_data,
+        }
 
+        # write results to config.json file
         with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)

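For reference, the rewritten update_config always nests the sparsity and transform configs under the quantization config key. A sketch of the resulting config.json layout, with placeholder values: the version string comes from compressed_tensors.__version__ at save time, DEFAULT_QUANTIZATION_METHOD is assumed to serialize as "compressed-tensors", and any remaining keys come from QuantizationConfig.model_dump when a quantization config is present.

{
  "quantization_config": {
    "version": "<compressed_tensors.__version__>",
    "quant_method": "compressed-tensors",
    "sparsity_config": {},
    "transform_config": {},
    "...": "remaining QuantizationConfig fields, if any"
  }
}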
src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 21 additions & 0 deletions

@@ -61,6 +61,27 @@ def compression_param_names(self) -> Tuple[str]:
             "weight_global_scale",
         )
 
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        output = {
+            "weight_packed": (
+                torch.Size((weight_shape[0], weight_shape[1] // 2)),
+                torch.uint8,
+            ),
+        }
+        return output
+
     def compress_weight(
         self,
         weight: Tensor,

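The new compression_param_info mirrors the FP4 packing scheme: two 4-bit values share one uint8 byte, so the packed weight keeps its row count and halves its column count. A minimal sketch of the shape math, using an illustrative weight shape:

import torch

weight_shape = torch.Size((4096, 11008))  # illustrative uncompressed shape
packed_shape = torch.Size((weight_shape[0], weight_shape[1] // 2))

# two FP4 values per uint8 byte => last dim halves, dtype becomes uint8
assert packed_shape == torch.Size((4096, 5504))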
src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 7 additions & 19 deletions

@@ -257,13 +257,10 @@ def _process_quantization(
         QuantizationStrategy.GROUP,
         QuantizationStrategy.TENSOR_GROUP,
     ):
-        n_dims = x.shape
-        if len(n_dims) > 2:
-            x = x.squeeze(0)
 
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-        columns = output.shape[1]
+        columns = output.shape[-1]
 
         # TODO: make validation step for inputs
 
@@ -293,14 +290,12 @@ def _process_quantization(
             perm = torch.argsort(g_idx)
             x = safe_permute(x, perm, dim=1)
 
-        x = torch.reshape(
-            x,
-            (
-                x.shape[0],
-                ceil(x.shape[1] / group_size),
-                group_size,
-            ),
+        # Maintain all dimensions apart from the last dim, which is divided by the group_size
+        reshaped_dims = (
+            ceil(x.shape[-1] / group_size),
+            group_size,
         )
+        x = x.unflatten(-1, reshaped_dims)
 
         if do_quantize:
             output = _quantize(
@@ -323,19 +318,12 @@ def _process_quantization(
                 global_scale=global_scale,
             )
 
-        output = torch.reshape(
-            output,
-            (output.shape[0], output.shape[1] * output.shape[2]),
-        )
-
+        output = output.flatten(start_dim=-2)
         output = output.to(output_dtype)
 
         if not is_column_order:
             output = safe_permute(output, torch.argsort(perm), dim=1)
 
-        if len(n_dims) > 2:
-            output = output.unsqueeze(0)
-
     else:  # covers channel, token and tensor strategies
         if do_quantize:
             output = _quantize(

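The squeeze/unsqueeze bookkeeping can go because unflatten and flatten act only on the last dimension, so any number of leading dims passes through untouched. A standalone sketch of the reshape round-trip (illustrative shapes, not the library's code):

from math import ceil

import torch

group_size = 128
x = torch.randn(2, 32, 512)  # e.g. (batch, tokens, hidden); 3D now works directly

# split only the last dim into (num_groups, group_size)
grouped = x.unflatten(-1, (ceil(x.shape[-1] / group_size), group_size))
assert grouped.shape == (2, 32, 4, 128)

# flatten(start_dim=-2) restores the original layout exactly
assert torch.equal(grouped.flatten(start_dim=-2), x)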
src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 7 additions & 11 deletions

@@ -175,20 +175,16 @@ def compute_dynamic_scales_and_zp(
         QuantizationStrategy.TENSOR_GROUP,
         QuantizationStrategy.GROUP,
     ):
-        if len(value.shape) > 2:
-            value = value.squeeze(0)
 
-        dim = {0, 1}
-        reduce_dims = tuple(idx for idx in range(3) if idx not in dim)
+        reduce_dims = -1
         keep_dims = False
-        value = torch.reshape(
-            value,
-            (
-                value.shape[0],
-                math.ceil(value.shape[1] / args.group_size),
-                args.group_size,
-            ),
+
+        reshaped_dims = (
+            math.ceil(value.shape[-1] / args.group_size),
+            args.group_size,
         )
+        value = value.unflatten(-1, reshaped_dims)
+
     else:
         supported_strategies = (
             QuantizationStrategy.TOKEN,

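Same idea for dynamic scales: with the group dimension isolated at the end, reduce_dims=-1 yields per-group statistics regardless of rank. A sketch of the pattern (the downstream min/max reduction is outside this diff, so the amin/amax calls here are an assumption about how reduce_dims is consumed):

import math

import torch

group_size = 64
value = torch.randn(2, 16, 256)  # illustrative 3D activation

reshaped = value.unflatten(-1, (math.ceil(value.shape[-1] / group_size), group_size))
min_vals = torch.amin(reshaped, dim=-1)  # reduce only the trailing group dim
max_vals = torch.amax(reshaped, dim=-1)
assert min_vals.shape == (2, 16, 4)  # one statistic per group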
src/compressed_tensors/transform/apply.py

Lines changed: 4 additions & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import torch
+from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory
 
 
@@ -30,3 +31,6 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):
     for name, scheme in config.config_groups.items():
         factory = TransformFactory.from_scheme(scheme, name=name)
         factory.apply_to_model(model)
+
+    # attach config to model for compression/serialization
+    setattr(model, TRANSFORM_CONFIG_NAME, config)

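A minimal usage sketch of the new attachment behavior, assuming apply_transform_config is importable from compressed_tensors.transform and that an empty config_groups dict validates:

import torch
from compressed_tensors import TRANSFORM_CONFIG_NAME
from compressed_tensors.transform import TransformConfig, apply_transform_config

model = torch.nn.Linear(64, 64)             # stand-in model
config = TransformConfig(config_groups={})  # empty config, just to show attachment

apply_transform_config(model, config)

# the config now rides along with the model, so ModelCompressor can later
# recover it via getattr(model, TRANSFORM_CONFIG_NAME)
assert getattr(model, TRANSFORM_CONFIG_NAME) is config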
src/compressed_tensors/transform/factory/base.py

Lines changed: 5 additions & 9 deletions

@@ -14,11 +14,11 @@
 
 from abc import ABC, abstractmethod
 from collections import defaultdict
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Set
 
 import torch
 import torch.nn.utils.parametrize as P
-from compressed_tensors import InternalModule, match_named_modules
+from compressed_tensors.utils.internal import InternalModule
 from compressed_tensors.registry.registry import RegistryMixin, T
 from compressed_tensors.transform import (
     TransformArgs,
@@ -164,10 +164,6 @@ def _update_tied_weights(self):
         which is used by transformers to detect and remove shared pointers
         during saving
         """
-        # avoid issues with this method being called twice
-        for transform in self.transforms:
-            transform._dynamic_tied_weights_keys = list()
-
         # map from data_ptrs to keys
         ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
         for transform in self.transforms:
@@ -184,7 +180,7 @@ def _update_tied_weights(self):
         tensor = getattr(shared_keys[0][0], shared_keys[0][1])
 
         for transform, name in shared_keys:
-            transform._dynamic_tied_weights_keys.append(name)
+            transform._dynamic_tied_weights_keys.add(name)
             setattr(transform, name, tensor)
 
 
@@ -195,11 +191,11 @@ class TransformBase(InternalModule, ABC):
 
     args: TransformArgs
     weight: Parameter
-    _dynamic_tied_weights_keys: List[str]
+    _dynamic_tied_weights_keys: Set[str]
 
     def __init__(self):
         super().__init__()
-        self._dynamic_tied_weights_keys = list()
+        self._dynamic_tied_weights_keys = set()
 
     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:

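Switching _dynamic_tied_weights_keys from a list to a set is what lets the reset loop go: set.add is idempotent, so running _update_tied_weights twice can no longer accumulate duplicate key names. A toy illustration:

keys = set()            # previously a list that had to be reset before each run
for _ in range(2):      # simulate _update_tied_weights being called twice
    keys.add("weight")  # .add replaces the old list .append

assert keys == {"weight"}  # idempotent: no duplicates, no manual reset needed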
src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
