
Commit aecb127

Author: Sara Adkins

make preset more explicit (#105)
Parent: b341803

File tree

2 files changed: +49 -11 lines


src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 48 additions & 7 deletions
@@ -17,6 +17,7 @@

 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
+    QuantizationStrategy,
     QuantizationType,
 )
 from pydantic import BaseModel
@@ -110,15 +111,55 @@ def is_preset_scheme(name: str) -> bool:
     return name.upper() in PRESET_SCHEMES


-W8A8 = dict(weights=QuantizationArgs(), input_activations=QuantizationArgs())
+W8A8 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        symmetric=True,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        symmetric=True,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        dynamic=True,
+    ),
+)

-W4A16 = dict(weights=QuantizationArgs(num_bits=4, group_size=128))
+W8A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        symmetric=True,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+    )
+)

-FP8 = dict(
-    weights=QuantizationArgs(type=QuantizationType.FLOAT),
-    input_activations=QuantizationArgs(type=QuantizationType.FLOAT),
+W4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        symmetric=True,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=128,
+    )
 )

-PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8}
+FP8 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        symmetric=True,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        symmetric=True,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        dynamic=False,
+    ),
+)

-PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8}
+PRESET_SCHEMES = {"W8A8": W8A8, "W8A16": W8A16, "W4A16": W4A16, "FP8": FP8}
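
For reference, a minimal usage sketch (not part of the commit), assuming only the names visible in the diff above: PRESET_SCHEMES, is_preset_scheme, and the quant_scheme module path.

# Minimal sketch, assuming the module path and names shown in the diff above.
from compressed_tensors.quantization.quant_scheme import (
    PRESET_SCHEMES,
    is_preset_scheme,
)

# Lookup is case-insensitive: is_preset_scheme() upper-cases the name
# before checking membership in PRESET_SCHEMES.
assert is_preset_scheme("w8a16")

# With the presets spelled out, every field can be read off directly
# instead of depending on QuantizationArgs defaults.
w8a16_weights = PRESET_SCHEMES["W8A16"]["weights"]
print(w8a16_weights.num_bits)   # 8
print(w8a16_weights.symmetric)  # True
print(w8a16_weights.strategy)   # QuantizationStrategy.CHANNEL

Writing each field explicitly also keeps the presets stable if the QuantizationArgs defaults ever change, which appears to be the point of "make preset more explicit".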

tests/test_quantization/test_quant_config.py

Lines changed: 1 addition & 4 deletions
@@ -63,10 +63,7 @@ def test_need_config_groups():

 @pytest.mark.parametrize(
     "scheme_name",
-    [
-        "W8A8",
-        "W4A16",
-    ],
+    ["W8A8", "W8A16", "W4A16", "FP8"],
 )
 def test_load_scheme_from_preset(scheme_name: str):
     targets = ["Linear"]
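
The parametrization now covers all four presets, including the new W8A16. A hedged sketch of an equivalent check follows; the original test body beyond targets = ["Linear"] is not shown in this diff, so the assertions below are illustrative, not the repository's.

import pytest

from compressed_tensors.quantization.quant_scheme import (
    PRESET_SCHEMES,
    is_preset_scheme,
)


@pytest.mark.parametrize("scheme_name", ["W8A8", "W8A16", "W4A16", "FP8"])
def test_preset_is_registered(scheme_name: str):
    # Illustrative assertions only; the real test loads a full scheme
    # from the preset via the library's config machinery.
    assert is_preset_scheme(scheme_name)
    assert scheme_name in PRESET_SCHEMES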
