
Commit 5e5ffb5

Merge branch 'main' into bdellabe/scoped-quant-status
2 parents 5776c86 + 0e5df88 commit 5e5ffb5

File tree

18 files changed: +385 −117 lines


setup.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def _setup_packages() -> List:
     )

 def _setup_install_requires() -> List:
-    return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "frozendict", "loguru"]
+    return ["torch>=1.7.0", "transformers", "pydantic>=2.0", "loguru"]

 def _setup_extras() -> Dict:
     return {

src/compressed_tensors/config/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,5 +15,6 @@
 # flake8: noqa
 from .base import *
 from .dense import *
+from .format import *
 from .sparse_24_bitmask import *
 from .sparse_bitmask import *
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import torch
+from compressed_tensors.config import CompressionFormat, SparsityStructure
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
+from compressed_tensors.quantization.utils import is_module_quantized
+from loguru import logger
+
+
+__all__ = ["infer_and_set_per_module_quantization_format"]
+
+
+def _get_quant_compression_format(
+    input_args: Optional[QuantizationArgs],
+    weight_args: Optional[QuantizationArgs],
+    sparsity_structure: Optional[str] = None,
+) -> CompressionFormat:
+    """
+    Using the weight and input quantization args as well as an optional
+    sparsity structure, determine the compression format that should be
+    applied to a given module
+
+    :param input_args: input quantization parameters
+    :param weight_args: weight quantization parameters
+    :param sparsity_structure: optional (global) model sparsity
+        structure
+    :return: CompressionFormat for the module
+    """
+    is_24_structure = (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    )
+    is_weight_only = weight_args is not None and input_args is None
+
+    if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
+        return CompressionFormat.nvfp4_pack_quantized
+
+    if is_weight_only:  # w4a16 and w8a16
+        is_valid_pack = (
+            weight_args.num_bits in [4, 8]
+            and weight_args.type == QuantizationType.INT.value
+        )
+        if not is_valid_pack:  # packing only valid for int4 and int8
+            return CompressionFormat.naive_quantized
+
+        if is_24_structure and weight_args.strategy in (
+            QuantizationStrategy.CHANNEL.value,
+            QuantizationStrategy.GROUP.value,
+        ):
+            # marlin24 kernel only applicable for channel/group quantization
+            # Note: vLLM may only support group quant for marlin24
+            return CompressionFormat.marlin_24
+        return CompressionFormat.pack_quantized
+
+    else:  # w8a8 float and int
+        if (
+            weight_args.type == QuantizationType.FLOAT.value
+            and weight_args.num_bits == 8
+        ):
+            return CompressionFormat.float_quantized
+        if weight_args.type == QuantizationType.INT.value:
+            return CompressionFormat.int_quantized
+
+        return CompressionFormat.naive_quantized
+
+
+def set_per_module_format(
+    module: torch.nn.Module, sparsity_structure: Optional[str] = None
+):
+    """
+    Determine and set the per-module quantization format given quantization args
+    and sparsity structure.
+
+    :param module: module which has its format inferred
+    :param sparsity_structure: optional sparsity applied to the module
+    """
+    weight_scheme = module.quantization_scheme.weights
+    input_scheme = module.quantization_scheme.input_activations
+    if weight_scheme is None:
+        return  # no weight quant - nothing to compress
+    compression_format = _get_quant_compression_format(
+        input_scheme, weight_scheme, sparsity_structure
+    )
+
+    # If a format is already set, check whether it matches the inferred one
+    if module.quantization_scheme.format is not None:
+        # If it does not, warn the user
+        if module.quantization_scheme.format != compression_format.value:
+            logger.warning(
+                "The provided format for the module does not match the "
+                "inferred format. Compression may fail "
+            )
+    else:
+        # If not set, use the inferred format
+        module.quantization_scheme.format = compression_format.value
+
+
+def infer_and_set_per_module_quantization_format(
+    model: torch.nn.Module,
+    sparsity_structure: Optional[str] = None,
+) -> List[str]:
+    """
+    Infers the quantization format for a model based on its state and provided
+    compression arguments. Updates the quantization_scheme.format value
+    based on the inferred format. Returns the unique list of formats in the model,
+    or the dense format if no quantized modules are found.
+
+    For a summary of the formats, see `docs/guides/compression_formats.md`.
+
+    :param model: model to check for quantization
+    :param sparsity_structure: optional sparsity applied to the model
+    :return: compression formats appropriate for the model
+    """
+    unique_formats = []
+    for submodule in model.modules():
+        if is_module_quantized(submodule):
+            assert hasattr(submodule, "quantization_scheme")
+            set_per_module_format(submodule, sparsity_structure)
+            if submodule.quantization_scheme.format not in unique_formats:
+                unique_formats.append(submodule.quantization_scheme.format)
+
+    if len(unique_formats) > 0:
+        return unique_formats
+    return [CompressionFormat.dense.value]
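
A rough usage sketch of the format-inference helper added above. The QuantizationArgs keyword arguments are assumptions based on the fields referenced in this diff, not verified against the full class definition:

from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

# weight-only int4 group quantization (w4a16): no input activation quantization.
# Field names (num_bits, type, strategy, group_size) are assumed from the checks above.
w4a16 = QuantizationArgs(
    num_bits=4,
    type=QuantizationType.INT,
    strategy=QuantizationStrategy.GROUP,
    group_size=128,
)

# _get_quant_compression_format is the private helper defined in this new file
fmt = _get_quant_compression_format(input_args=None, weight_args=w4a16)
# with no sparsity structure this resolves to pack_quantized; passing
# sparsity_structure="2:4" with channel/group strategy would resolve to marlin_24
print(fmt)  # CompressionFormat.pack_quantized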

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,6 @@
     calculate_range,
     compute_dynamic_scales_and_zp,
 )
-from compressed_tensors.utils import safe_permute
 from torch.nn import Module


@@ -294,7 +293,7 @@ def _process_quantization(
            group_sizes = group_sizes[torch.argsort(group_indices)]

            perm = torch.argsort(g_idx)
-            x = safe_permute(x, perm, dim=1)
+            x = x.index_select(-1, perm)

        # Maintain all dimensions except the last dim, which is divided by group_size
        reshaped_dims = (
@@ -328,7 +327,8 @@ def _process_quantization(
        output = output.to(output_dtype)

        if not is_column_order:
-            output = safe_permute(output, torch.argsort(perm), dim=1)
+            inv_perm = torch.argsort(perm)
+            output = output.index_select(-1, inv_perm)

    else:  # covers channel, token and tensor strategies
        if do_quantize:
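
The safe_permute helper is dropped in favor of plain index_select along the last dimension; a small self-contained check of the permute/unpermute round trip used here:

import torch

x = torch.arange(12.0).reshape(3, 4)
perm = torch.randperm(4)

permuted = x.index_select(-1, perm)                        # reorder last-dim entries by perm
restored = permuted.index_select(-1, torch.argsort(perm))  # argsort(perm) is the inverse permutation

assert torch.equal(restored, x)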

src/compressed_tensors/transform/apply.py

Lines changed: 35 additions & 0 deletions
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Dict
+
 import torch
+from accelerate.utils import has_offloaded_params
 from compressed_tensors import TRANSFORM_CONFIG_NAME
 from compressed_tensors.transform import TransformConfig, TransformFactory

@@ -34,3 +37,35 @@ def apply_transform_config(model: torch.nn.Module, config: TransformConfig):

     # attach config to model for compression/serialization
     setattr(model, TRANSFORM_CONFIG_NAME, config)
+
+    # ensure that tied weight transforms can be serialized without aliases
+    # In the future, this could be done by transformers or model compressor,
+    # which would make this more robust to changing dispatches after transforms
+    _tie_offloaded_tensors(model)
+
+
+def _tie_offloaded_tensors(model: torch.nn.Module):
+    """
+    When accelerate replaces tensors with meta tensors during offloading, the meta
+    tensors may not be identical, even if the offloaded values are identical.
+
+    However, transformers can only serialize correctly if meta tensors are identical
+    (see transformers#39263).
+
+    This function collects all meta tensors which have shared offloaded values and sets
+    those tensors to be identical so that they can be removed during serialization
+
+    :param model: model potentially containing offloaded meta tensors to fix
+    """
+
+    # ensure that if modules share an offloaded tensor pointer, the
+    # meta tensor is also identical (assigned to the first instance of the parameter)
+    ptr_to_meta: Dict[int, torch.nn.Parameter] = dict()
+    for module in model.modules():
+        if has_offloaded_params(module):
+            for key, _ in module.named_parameters(recurse=False):
+                offloaded_ptr = module._hf_hook.weights_map[key].data_ptr()
+
+                if offloaded_ptr not in ptr_to_meta:
+                    ptr_to_meta[offloaded_ptr] = getattr(module, key)
+                setattr(module, key, ptr_to_meta[offloaded_ptr])
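
A toy illustration (hypothetical modules, not this repo's API) of the invariant that _tie_offloaded_tensors restores: tied parameters must be literally the same object, so serialization can detect and drop the aliases by storage pointer:

import torch

shared = torch.nn.Parameter(torch.zeros(8, 8))
a = torch.nn.Linear(8, 8, bias=False)
b = torch.nn.Linear(8, 8, bias=False)
a.weight = shared
b.weight = shared  # tie: both modules now reference the same Parameter

# identical objects imply identical storage pointers, which is what
# tied-weight detection relies on when saving
assert a.weight is b.weight
assert a.weight.data_ptr() == b.weight.data_ptr()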

src/compressed_tensors/transform/factory/base.py

Lines changed: 2 additions & 36 deletions
@@ -13,8 +13,7 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import List, Optional, Set, Tuple
+from typing import List, Optional

 import torch
 import torch.nn.utils.parametrize as P
@@ -57,7 +56,6 @@ def __init__(self, name: str, scheme: TransformScheme, seed: Optional[int] = Non
         self.name = name
         self.scheme = scheme
         self.generator = torch.Generator()
-        self.transforms = list()
         if seed is not None:
             self.generator.manual_seed(seed)

@@ -101,8 +99,6 @@ def apply_to_model(self, model: Module, use_tqdm=True):
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
             self._apply_to_module(module, arg)

-        self._update_tied_weights()
-
     def _apply_to_module(self, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
@@ -120,7 +116,6 @@ def _apply_to_module(self, module: Module, args: TransformArgs):
         # create transform as submodule
         transform_name = f"{self.name}_{args.location}"
         transform = self.create_transform(module, args)
-        self.transforms.append(transform)
         register_offload_module(module, transform_name, transform)

         # register input transformation hook
@@ -165,31 +160,6 @@ def output_hook(_, _input, output):
         else:
             raise NotImplementedError()

-    def _update_tied_weights(self):
-        """
-        Populate the `_dynamic_tied_weights_keys` attribute of transforms,
-        which is used by transformers to detect and remove shared pointers
-        during saving
-        """
-        # map from data_ptrs to keys
-        ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
-        for transform in self.transforms:
-            for name, param in transform.named_parameters(recurse=False):
-                # NOTE: previously asserted that parent._hf_hook.place_submodules=False
-                if has_offloaded_params(transform):
-                    param = transform._hf_hook.weights_map[name]
-                ptr_to_keys[param.data_ptr()].append((transform, name))
-
-        # populate `_dynamic_tied_weights_keys` if there is more than one key
-        # and ensure that they share tensors
-        for shared_keys in ptr_to_keys.values():
-            if len(shared_keys) > 1:
-                tensor = getattr(shared_keys[0][0], shared_keys[0][1])
-
-                for transform, name in shared_keys:
-                    transform._dynamic_tied_weights_keys.add(name)
-                    setattr(transform, name, tensor)
-

 class TransformBase(InternalModule, ABC):
     """
@@ -198,11 +168,7 @@ class TransformBase(InternalModule, ABC):

     args: TransformArgs
     weight: Parameter
-    _dynamic_tied_weights_keys: Set[str]
-
-    def __init__(self):
-        super().__init__()
-        self._dynamic_tied_weights_keys = set()
+    _dynamic_tied_weights_keys: List[str] = ["weight"]

     @abstractmethod
     def forward(self, value: Tensor) -> Tensor:

src/compressed_tensors/transform/factory/hadamard.py

Lines changed: 4 additions & 2 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import List, Optional

 import torch
 from compressed_tensors.transform import TransformArgs, TransformScheme
@@ -52,7 +52,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert hasattr(module, "weight")
-        size = get_transform_size(module, args.location, self.scheme.head_dim)
+        size = get_transform_size(module, args.location, self.scheme.block_size)
         exec_device = get_execution_device(module)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64
@@ -84,6 +84,8 @@ def _create_permutation(self, weight: Parameter) -> Parameter:


 class HadamardTransform(TransformBase):
+    _dynamic_tied_weights_keys: List[str] = ["weight", "perm"]
+
     def __init__(
         self,
         weight: Parameter,
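
For intuition on the block_size parameter these factories now read (per the TransformScheme docstring in this commit, the transform becomes block diagonal with square blocks of that size), a conceptual sketch rather than this factory's actual code path:

import torch

# a normalized 2x2 Hadamard block; block_size=2 over a size-4 dimension
# corresponds to a block-diagonal transform built from two such blocks
block = torch.tensor([[1.0, 1.0], [1.0, -1.0]]) / (2 ** 0.5)
full = torch.block_diag(block, block)

# the block-diagonal matrix stays orthogonal, so the rotation remains invertible
assert torch.allclose(full @ full.T, torch.eye(4))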

src/compressed_tensors/transform/factory/matrix_multiply.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def create_transform(self, module: Module, args: TransformArgs):
         :param args: defines how the transform will be applied to the module
         """
         assert hasattr(module, "weight")
-        size = get_transform_size(module, args.location, self.scheme.head_dim)
+        size = get_transform_size(module, args.location, self.scheme.block_size)
         device = get_offloaded_device(module)
         precision = self.scheme.precision if args.is_online() else torch.float64

src/compressed_tensors/transform/transform_scheme.py

Lines changed: 18 additions & 2 deletions
@@ -17,7 +17,7 @@
 import torch
 from compressed_tensors.transform import TransformArgs
 from compressed_tensors.utils import TorchDtype
-from pydantic import BaseModel, ConfigDict, Field
+from pydantic import BaseModel, ConfigDict, Field, model_validator


 __all__ = ["TransformScheme"]
@@ -36,6 +36,8 @@ class TransformScheme(BaseModel):
     :param randomize: True if uniquely randomized transform weights should be used,
         otherwise use identical transform weights where applicable
     :param requires_grad: True if weights include gradients for training
+    :param block_size: If set, the transform matrix will be block diagonal, with each
+        block being a square matrix of this size.
     :param precision: Precision at which this transform should be applied during online
         rotations. Fused (offline) rotations are always performed in float64
     """
@@ -44,7 +46,21 @@ class TransformScheme(BaseModel):
     apply: List[TransformArgs] = Field(default_factory=list)
     randomize: bool = Field(default=False)
     requires_grad: bool = Field(default=False)
-    head_dim: Optional[int] = Field(default=None)
+    block_size: Optional[int] = Field(default=None)
+    head_dim: Optional[int] = Field(
+        default=None, deprecated="head_dim is deprecated, use block_size instead"
+    )
     precision: TorchDtype = Field(default=torch.float32)

+    @model_validator(mode="after")
+    def validate_model_after(model: "TransformScheme") -> "TransformScheme":
+        """
+        If head_dim is used instead of block_size, set block_size to head_dim
+        and remove head_dim
+        """
+        if model.block_size is None and model.head_dim is not None:
+            model.block_size = model.head_dim
+            model.head_dim = None
+        return model
+
     model_config = ConfigDict(extra="forbid")
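
A minimal sketch of the backward-compatibility behavior added by the validator; it assumes TransformScheme's existing required `type` field (e.g. "hadamard"), which is not shown in this hunk:

from compressed_tensors.transform import TransformScheme

# legacy configs that still pass head_dim are remapped to block_size by the
# model_validator above; type="hadamard" is an assumed valid value
scheme = TransformScheme(type="hadamard", head_dim=128)
assert scheme.block_size == 128
assert scheme.head_dim is None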
