Commit d3ee6c2
Bump version for float8 dynamic quant and weight only quant configs
Summary: This PR changes the default VERSION for Float8DynamicActivationFloat8WeightConfig and Float8WeightOnlyConfig from 1 to 2 and deprecates the VERSION 1 configs and VERSION 1 quantized models; more details in #2649. It also extends the current config serialization to work with multiple config versions.

Deprecation Note: loading a VERSION 1 checkpoint now emits warnings:

```
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev"
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="bfloat16",
    device_map="cuda",
)

/data/users/jerryzh/ao/torchao/core/config.py:249: UserWarning: Stored version is not the same as current default version of the config: stored_version=1, current_version=2, please check the deprecation warning
  warnings.warn(
/data/users/jerryzh/ao/torchao/dtypes/floatx/float8_layout.py:113: UserWarning: Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see #2649 for more details
  warnings.warn(
```

Suggestion: upgrade torchao to 0.13 or later and generate the checkpoint again:

```
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
```

Or download the checkpoint again (please let us know if the checkpoint is not updated).

Test Plan: serialize a model with a VERSION 1 config, load it, and check that the warnings are properly printed:

```
python test/integration/test_loading_deprecated_checkpoint.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

stack-info: PR: #2650, branch: jerryzh168/stack/14
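As a concrete illustration of the suggested upgrade path, the sketch below loads a base model, quantizes it with the new default (VERSION 2) config, and saves a fresh checkpoint. The base model id and output directory are illustrative assumptions, not part of this PR; `safe_serialization=False` is used because tensor-subclass weights are generally saved in non-safetensors format.

```python
# Hedged sketch of the upgrade path: re-quantize with the VERSION 2 default.
# Model id and output dir are placeholders (assumptions, not from this PR).
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # assumed base model; substitute your own
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
# After this PR, VERSION defaults to 2, so no explicit VERSION is needed.
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
model.save_pretrained(
    "opt-125m-float8dq-row-v2",  # hypothetical output directory
    safe_serialization=False,    # tensor subclasses typically need non-safetensors
)
```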
1 parent 3b4bc98 commit d3ee6c2

8 files changed: +163 −80 lines changed
test/core/test_config.py

Lines changed: 8 additions & 6 deletions

```diff
@@ -7,6 +7,7 @@
 import json
 import os
 import tempfile
+import warnings
 from dataclasses import dataclass
 from unittest import mock
 
@@ -15,7 +16,6 @@
 
 from torchao.core.config import (
     AOBaseConfig,
-    VersionMismatchError,
     config_from_dict,
     config_to_dict,
 )
@@ -176,7 +176,7 @@ def test_disallowed_modules():
 
 
 def test_version_mismatch():
-    """Test that version mismatch raises an error during reconstruction."""
+    """Test that version mismatch prints a warning during reconstruction."""
     # Create a config
     dummy_config = DummyNonAllowedConfig()
     reconstructable = config_to_dict(dummy_config)
@@ -186,11 +186,13 @@ def test_version_mismatch():
 
     # Patch to allow the module but should still fail due to version mismatch
     with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
-        with pytest.raises(
-            VersionMismatchError,
-            match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
-        ):
+        with warnings.catch_warnings(record=True) as caught_warnings:
             config_from_dict(reconstructable)
+        assert any(
+            "Stored version is not the same as current default version of the config"
+            in str(w.message)
+            for w in caught_warnings
+        ), "Didn't get expected warning message for version mismatch"
 
 
 def test_default_version():
```
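The behavioral change this test captures is small: `config_from_dict` used to raise `VersionMismatchError` and now warns. A simplified sketch of the check, not the actual torchao source:

```python
# Simplified sketch (not the actual torchao implementation) of the version
# check the updated test exercises: a stored/current mismatch now emits a
# UserWarning instead of raising VersionMismatchError.
import warnings

def check_config_version(stored_version: int, current_version: int) -> None:
    if stored_version != current_version:
        warnings.warn(
            "Stored version is not the same as current default version of the "
            f"config: stored_version={stored_version}, "
            f"current_version={current_version}, "
            "please check the deprecation warning"
        )

check_config_version(1, 2)  # warns; previously this path raised
```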

test/dtypes/test_affine_quantized_float.py

Lines changed: 52 additions & 24 deletions

```diff
@@ -30,17 +30,14 @@
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_weight_only,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.granularity import (
     PerRow,
     PerTensor,
 )
-from torchao.quantization.quant_api import (
-    float8_static_activation_float8_weight,
-)
 from torchao.quantization.quant_primitives import (
     MappingType,
     _choose_scale_float8,
@@ -119,11 +116,13 @@ def test_fp8_linear_variants(
         )
         mode_map = {
             "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=granularity
+                Float8DynamicActivationFloat8WeightConfig,
+                granularity=granularity,
+                VERSION=1,
             ),
-            "weight-only": float8_weight_only,
+            "weight-only": partial(Float8WeightOnlyConfig, VERSION=1),
             "static": partial(
-                float8_static_activation_float8_weight,
+                Float8StaticActivationFloat8WeightConfig,
                 scale=scale,
                 granularity=granularity,
             ),
@@ -152,7 +151,7 @@
     )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
-            float8_dynamic_activation_float8_weight(granularity="invalid")
+            Float8DynamicActivationFloat8WeightConfig(granularity="invalid", VERSION=1)
 
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -162,7 +161,9 @@ def test_mismatched_granularity(self):
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
-            float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=(PerTensor(), PerRow()), VERSION=1
+            )
 
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -172,8 +173,9 @@ class UnsupportedGranularity:
             pass
 
         with pytest.raises(ValueError, match="Invalid granularity types"):
-            float8_dynamic_activation_float8_weight(
-                granularity=(UnsupportedGranularity(), UnsupportedGranularity())
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
+                VERSION=1,
             )
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -187,7 +189,10 @@ def test_per_row_with_float32(self):
         ):
             model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerRow())
+                model,
+                Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerRow(), VERSION=1
+                ),
             )
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -201,11 +206,13 @@ def test_serialization(self, mode: str):
 
         mode_map = {
             "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=PerTensor()
+                Float8DynamicActivationFloat8WeightConfig,
+                granularity=PerTensor(),
+                VERSION=1,
             ),
-            "weight-only": float8_weight_only,
+            "weight-only": partial(Float8WeightOnlyConfig, VERSION=1),
             "static": partial(
-                float8_static_activation_float8_weight,
+                Float8StaticActivationFloat8WeightConfig,
                 scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
                 granularity=PerTensor(),
             ),
@@ -275,7 +282,10 @@ def test_fp8_weight_dimension_warning(self):
             "torchao.quantization.quant_api", level="INFO"
         ) as log_context:
             quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerTensor())
+                model,
+                Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(), VERSION=1
+                ),
             )
         print(model)
 
@@ -320,7 +330,8 @@ def test_mm_float8dq_per_row(
         )
         test_linear = copy.deepcopy(ref_linear)
         quantize_(
-            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            test_linear,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
         )
 
         quant_weight = test_linear.weight
@@ -472,7 +483,10 @@ def test_float8_tensor_slicing_basic(self, granularity):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, VERSION=1
+            ),
         )
 
         weight_impl = model.weight.original_weight_tensor.tensor_impl
@@ -506,7 +520,10 @@ def test_float8_tensor_slicing_per_tensor(self):
         # Create and quantize with per-tensor granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), VERSION=1
+            ),
         )
 
         original_weight = model.weight
@@ -537,7 +554,8 @@ def test_float8_tensor_slicing_per_row(self):
         # Create and quantize with per-row granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1),
         )
 
         original_weight = model.weight  # Shape: (32, 64)
@@ -575,7 +593,10 @@ def test_float8_tensor_slicing_edge_cases(self):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), VERSION=1
+            ),
         )
 
         original_weight = model.weight
@@ -613,7 +634,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         quant_model = copy.deepcopy(ref_model)
         quantize_(
             quant_model,
-            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, VERSION=1
+            ),
         )
 
         # Create input with batch size that works well with slicing
@@ -743,7 +766,12 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
         m = torch.nn.Sequential(
             torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
         )
-        quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))
+        quantize_(
+            m,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, VERSION=1
+            ),
+        )
         m = torch.compile(m, mode=torch_compile_mode)
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
```
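These tests pin `VERSION=1` so the deprecated path stays covered until it is removed. The summary also notes that config serialization now handles multiple versions; assuming the stored dict records the config's VERSION (inferred from the warning text, not verified against the source), a round-trip would look like:

```python
# Sketch of multi-version config serialization, inferred from the summary
# and warning text; the exact round-trip behavior is an assumption.
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow

cfg_v1 = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), VERSION=1)
stored = config_to_dict(cfg_v1)      # serialized dict carries the config version
restored = config_from_dict(stored)  # warns: stored_version=1, current_version=2
assert restored.VERSION == 1         # assumed: the stored version is preserved
```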
test/float8/test_base.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -473,10 +473,10 @@ def test_quantize(self):
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
-        from torchao.quantization.quant_api import float8_weight_only, quantize_
+        from torchao.quantization import Float8WeightOnlyConfig, quantize_
 
-        quantize_(m, float8_weight_only())
-        assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
+        quantize_(m, Float8WeightOnlyConfig())
+        assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
             "Post quantization dtype should be torch.float8_e4m3fn"
         )
         with torch.no_grad():
```
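The changed assertion reflects where the raw float8 payload lives under the VERSION 2 tensor subclass: `weight.qdata` instead of `weight.tensor_impl.float8_data`. A minimal check mirroring the updated test (assumes a CUDA device with float8 support):

```python
# Mirrors the updated assertion above; assumes CUDA with sm89+.
import torch
import torch.nn as nn
from torchao.quantization import Float8WeightOnlyConfig, quantize_

m = nn.Sequential(nn.Linear(32, 32)).cuda()
quantize_(m, Float8WeightOnlyConfig())  # VERSION 2 is now the default
# VERSION 2 tensors expose the raw float8 payload as `qdata`.
assert m[0].weight.qdata.dtype == torch.float8_e4m3fn
```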
test/integration/test_loading_deprecated_checkpoint.py

Lines changed: 65 additions & 0 deletions

```diff
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import unittest
+import warnings
+
+import torch
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from torchao.utils import is_sm_at_least_89
+
+_MODEL_NAMES = [
+    "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
+]
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_89(), "Nedd sm89+")
+class TestLoadingDeprecatedCheckpoint(TestCase):
+    @common_utils.parametrize("model_name", _MODEL_NAMES)
+    def test_load_model_and_run(self, model_name):
+        """Test that we print correct warning message when loading a deprecated checkpoint"""
+        # Load and quantize model
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            quantized_model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype="bfloat16",
+                device_map="cuda",
+            )
+        assert any(
+            "Stored version is not the same as current default version of the config"
+            in str(w.message)
+            for w in caught_warnings
+        ), "Didn't get expected warning message for version mismatch"
+
+        assert any(
+            "Models quantized with VERSION 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
+            in str(w.message)
+            for w in caught_warnings
+        ), "Didn't get expected warning message for deprecation"
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        prompt = ("Hello, my name is",)
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+        ).to("cuda")
+        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
+        # make sure it runs
+        _ = tokenizer.batch_decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+
+common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
+
+if __name__ == "__main__":
+    run_tests()
```
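For users who cannot re-quantize immediately, one stopgap (an editorial suggestion, not something this PR recommends) is to silence the deprecation warning while still loading the VERSION 1 checkpoint:

```python
# Stopgap sketch: suppress the deprecation UserWarning when loading a
# VERSION 1 checkpoint. Editorial suggestion, not part of this PR.
import warnings
from transformers import AutoModelForCausalLM

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=".*deprecated.*", category=UserWarning)
    quantized_model = AutoModelForCausalLM.from_pretrained(
        "torchao-testing/opt-125m-float8dq-row-v1-0.13-dev",
        torch_dtype="bfloat16",
        device_map="cuda",
    )
```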

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 7 additions & 22 deletions

```diff
@@ -184,7 +184,6 @@ def test_fp8_linear_variants(
             config = Float8DynamicActivationFloat8WeightConfig(
                 granularity=granularity,
                 kernel_preference=kernel_preference,
-                VERSION=2,
             )
         else:
             assert mode == "weight-only", f"Unsupported mode: {mode}"
@@ -210,9 +209,7 @@
         "AssertionError: tensor(False, device='cuda:0') is not true : sqnr: -2.90625, will fix a bit later",
     )
     def test_slice(self, granularity):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
         dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
@@ -273,9 +270,7 @@ def test_slice(self, granularity):
 
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_preserves_aliasing(self, granularity):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
@@ -296,9 +291,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
 
         dtype = torch.bfloat16
         device = "cuda"
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
         quantize_(l, config)
 
@@ -335,9 +328,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
     @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
     def test_bmm(self):
         # only support per row quantization
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=PerRow(), VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
 
         class M(torch.nn.Module):
             def __init__(self, weight):
@@ -369,9 +360,7 @@ def forward(self, x):
         ],
     )
     def test_to_device(self, granularity, sizes):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         M, N, K = sizes
         dtype = torch.bfloat16
         for device in self.GPU_DEVICES:
@@ -401,9 +390,7 @@ def test_to_device(self, granularity, sizes):
         ],
     )
     def test_cat(self, granularity, sizes):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
         M, N, K = sizes
@@ -461,9 +448,7 @@ def test_moe_weight_reshape_ops(self):
         dtype = torch.bfloat16
         device = "cuda"
 
-        bmm_config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         moe_config = MoEQuantConfig(bmm_config)
 
         batch_size = 4
```