
Commit c649159

dsikka and kylesayrs authored
[Quantization Format] Add functionality to infer format (#452)
* add format infer code
* update
* update
* add loguru
* use dense not None
* update
* Apply suggestion from @kylesayrs

Co-authored-by: Kyle Sayers <[email protected]>

---------

Co-authored-by: Kyle Sayers <[email protected]>
1 parent 42363c3 commit c649159

File tree

3 files changed: +208 -0 lines changed


src/compressed_tensors/config/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,5 +15,6 @@
 # flake8: noqa
 from .base import *
 from .dense import *
+from .format import *
 from .sparse_24_bitmask import *
 from .sparse_bitmask import *
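
Editor's note (not part of the commit): with the wildcard import added above, whatever the new format module exports (shown in the next file, presumably src/compressed_tensors/config/format.py given this import) is re-exported from the config package root. A minimal sketch of the resulting import path, inferred from the diff rather than stated on this page:

# Hypothetical usage: the helper defined in the new module (its __all__ lists
# infer_and_set_per_module_quantization_format) should also be reachable
# through the config package root via the wildcard import above.
from compressed_tensors.config import infer_and_set_per_module_quantization_format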
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
from compressed_tensors.config import CompressionFormat, SparsityStructure
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from compressed_tensors.quantization.utils import is_module_quantized
from loguru import logger


__all__ = ["infer_and_set_per_module_quantization_format"]


def _get_quant_compression_format(
    input_args: Optional[QuantizationArgs],
    weight_args: Optional[QuantizationArgs],
    sparsity_structure: Optional[str] = None,
) -> CompressionFormat:
    """
    Using the weight and input quantization args as well as an optional
    sparsity structure, determine the compression format that should be
    applied to a given module

    :param input_args: input quantization parameters
    :param weight_args: weight quantization parameters
    :param sparsity_structure: optional (global) model sparsity
        structure
    :return: CompressionFormat for the module
    """
    is_24_structure = (
        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
    )
    is_weight_only = weight_args is not None and input_args is None

    if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
        return CompressionFormat.nvfp4_pack_quantized

    if is_weight_only:  # w4a16 and w8a16
        is_valid_pack = (
            weight_args.num_bits in [4, 8]
            and weight_args.type == QuantizationType.INT.value
        )
        if not is_valid_pack:  # packing only valid for int4 and int8
            return CompressionFormat.naive_quantized

        if is_24_structure and weight_args.strategy in (
            QuantizationStrategy.CHANNEL.value,
            QuantizationStrategy.GROUP.value,
        ):
            # marlin24 kernel only applicable for channel/group quantization
            # Note: vLLM may only support group quant for marlin24
            return CompressionFormat.marlin_24
        return CompressionFormat.pack_quantized

    else:  # w8a8 float and int
        if (
            weight_args.type == QuantizationType.FLOAT.value
            and weight_args.num_bits == 8
        ):
            return CompressionFormat.float_quantized
        if weight_args.type == QuantizationType.INT.value:
            return CompressionFormat.int_quantized

        return CompressionFormat.naive_quantized


def set_per_module_format(
    module: torch.nn.Module, sparsity_structure: Optional[str] = None
):
    """
    Determine and set the per-module quantization format given quantization args
    and sparsity structure.

    :param module: module which has its quantization format inferred
    :param sparsity_structure: optional sparsity applied to the module
    """
    weight_scheme = module.quantization_scheme.weights
    input_scheme = module.quantization_scheme.input_activations
    if weight_scheme is None:
        return  # no weight quant - nothing to compress
    compression_format = _get_quant_compression_format(
        input_scheme, weight_scheme, sparsity_structure
    )

    # If a format is already set, check whether it matches the inferred one
    if module.quantization_scheme.format is not None:
        # If it does not, warn the user
        if module.quantization_scheme.format != compression_format.value:
            logger.warning(
                "The provided format for the module does not match the "
                "inferred format. Compression may fail "
            )
    else:
        # If not set, use the inferred format
        module.quantization_scheme.format = compression_format.value


def infer_and_set_per_module_quantization_format(
    model: torch.nn.Module,
    sparsity_structure: Optional[str] = None,
) -> List[str]:
    """
    Infers the quantization format for a model based on its state and provided
    compression arguments. Updates the quantization_scheme.format value
    based on the inferred format. Returns the unique list of formats in the model,
    or the dense format if no quantized modules are found.

    For a summary of the formats, see `docs/guides/compression_formats.md`.

    :param model: model to check for quantization
    :param sparsity_structure: optional sparsity applied to the module
    :return: compression format(s) appropriate for the model
    """
    unique_formats = []
    for submodule in model.modules():
        if is_module_quantized(submodule):
            assert hasattr(submodule, "quantization_scheme")
            set_per_module_format(submodule, sparsity_structure)
            if submodule.quantization_scheme.format not in unique_formats:
                unique_formats.append(submodule.quantization_scheme.format)

    if len(unique_formats) > 0:
        return unique_formats
    return [CompressionFormat.dense.value]
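
Editor's note (not part of the commit): as a quick sanity check of the decision logic above, a weight-only int4 scheme with no 2:4 sparsity should fall through to the pack-quantized format. A minimal sketch, assuming a working compressed-tensors install and reusing the preset_name_to_scheme helper exercised by the test file below:

import torch
from compressed_tensors.config.format import (
    infer_and_set_per_module_quantization_format,
)
from compressed_tensors.quantization import preset_name_to_scheme

# Attach a W4A16 (weight-only int4) scheme to a single Linear layer.
layer = torch.nn.Linear(16, 16)
layer.quantization_scheme = preset_name_to_scheme("W4A16", targets=["Linear"])

# Per the parametrized test below, this should resolve to "pack-quantized",
# and the inferred value is written back onto quantization_scheme.format.
formats = infer_and_set_per_module_quantization_format(
    layer, sparsity_structure="unstructured"
)
print(formats)                           # expected: ["pack-quantized"]
print(layer.quantization_scheme.format)  # expected: "pack-quantized"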
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

import pytest
import torch
from compressed_tensors.config.format import (
    infer_and_set_per_module_quantization_format,
)
from compressed_tensors.quantization import preset_name_to_scheme


@pytest.mark.parametrize(
    "preset,sparsity_structure,expected_format",
    [
        ["W8A8", "unstructured", "int-quantized"],
        ["W8A16", "unstructured", "pack-quantized"],
        ["W8A16", "2:4", "marlin-24"],
        ["W4A16", "unstructured", "pack-quantized"],
        ["W4A16", "2:4", "marlin-24"],
        ["FP8", "unstructured", "float-quantized"],
    ],
)
def test_infer_quant_format(preset, sparsity_structure, expected_format):
    quant_scheme = preset_name_to_scheme(preset, targets=["Linear"])

    dummy_model = torch.nn.Sequential(
        OrderedDict(
            [
                ("fc1", torch.nn.Linear(8, 16, bias=True)),
                ("fc2", torch.nn.Linear(16, 32, bias=True)),
                (
                    "block1",
                    torch.nn.Sequential(
                        OrderedDict(
                            [
                                ("fc1", torch.nn.Linear(32, 16, bias=True)),
                                ("fc2", torch.nn.Linear(16, 8, bias=True)),
                            ]
                        )
                    ),
                ),
            ]
        )
    )

    for _, module in dummy_model.named_modules():
        module.quantization_scheme = quant_scheme

    inferred_format = infer_and_set_per_module_quantization_format(
        dummy_model, sparsity_structure=sparsity_structure
    )
    assert inferred_format[0] == expected_format
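
Editor's note (not part of the commit): one scenario the test does not cover is a model whose submodules carry different schemes; by construction, infer_and_set_per_module_quantization_format returns every unique format it encounters. A minimal sketch under the same assumptions as the test above (preset names and expected format strings taken from its parametrize table):

from collections import OrderedDict

import torch
from compressed_tensors.config.format import (
    infer_and_set_per_module_quantization_format,
)
from compressed_tensors.quantization import preset_name_to_scheme

mixed_model = torch.nn.Sequential(
    OrderedDict(
        [
            ("fc1", torch.nn.Linear(8, 8)),
            ("fc2", torch.nn.Linear(8, 8)),
        ]
    )
)
# Give each submodule a different scheme: W8A8 (int-quantized) vs W4A16 (pack-quantized).
mixed_model.fc1.quantization_scheme = preset_name_to_scheme("W8A8", targets=["Linear"])
mixed_model.fc2.quantization_scheme = preset_name_to_scheme("W4A16", targets=["Linear"])

formats = infer_and_set_per_module_quantization_format(
    mixed_model, sparsity_structure="unstructured"
)
print(sorted(formats))  # expected: ["int-quantized", "pack-quantized"]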
