diff --git a/test/core/test_config.py b/test/core/test_config.py
index 91a5f67767..fc752d989e 100644
--- a/test/core/test_config.py
+++ b/test/core/test_config.py
@@ -7,6 +7,7 @@
 import json
 import os
 import tempfile
+import warnings
 from dataclasses import dataclass
 from unittest import mock

@@ -15,7 +16,6 @@
 from torchao.core.config import (
     AOBaseConfig,
-    VersionMismatchError,
     config_from_dict,
     config_to_dict,
 )
@@ -151,7 +151,9 @@ def test_reconstructable_dict_file_round_trip(config):
 # Define a dummy config in a non-allowed module
 @dataclass
 class DummyNonAllowedConfig(AOBaseConfig):
-    VERSION = 2
+    # NOTE: must be `version: int` (with the type annotation) to
+    # override the `version` variable inherited from AOBaseConfig
+    version: int = 2
     value: int = 42


@@ -172,11 +174,11 @@ def test_disallowed_modules():
         reconstructed = config_from_dict(reconstructable)
         assert isinstance(reconstructed, DummyNonAllowedConfig)
         assert reconstructed.value == 42
-        assert reconstructed.VERSION == 2
+        assert reconstructed.version == 2


 def test_version_mismatch():
-    """Test that version mismatch raises an error during reconstruction."""
+    """Test that a version mismatch prints a warning during reconstruction."""
     # Create a config
     dummy_config = DummyNonAllowedConfig()
     reconstructable = config_to_dict(dummy_config)
@@ -186,17 +188,19 @@ def test_version_mismatch():

     # Patch to allow the module but should still fail due to version mismatch
     with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}):
-        with pytest.raises(
-            VersionMismatchError,
-            match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2",
-        ):
+        with warnings.catch_warnings(record=True) as caught_warnings:
             config_from_dict(reconstructable)
+        assert any(
+            "Stored version is not the same as current default version of the config"
+            in str(w.message)
+            for w in caught_warnings
+        ), "Didn't get expected warning message for version mismatch"


 def test_default_version():
     """Making sure the default version for a new config inheriting from AOBaseConfig is always 1
-    because it's the default VERSION that all children has when they haven't explicitly
-    defined a VERSION class variable
+    because it's the default version that all children have when they haven't explicitly
+    defined a version class variable
     """

     @dataclass
@@ -204,7 +208,7 @@ class DummyConfig(AOBaseConfig):
         pass

     config = DummyConfig()
-    assert config.VERSION == 1, "Default version must be 1"
+    assert config.version == 1, "Default version must be 1"


 if __name__ == "__main__":
diff --git a/test/dtypes/test_affine_quantized_float.py b/test/dtypes/test_affine_quantized_float.py
index 1f88bdd65d..1dfed4dda8 100644
--- a/test/dtypes/test_affine_quantized_float.py
+++ b/test/dtypes/test_affine_quantized_float.py
@@ -30,17 +30,14 @@
 from torchao.float8.float8_utils import compute_error
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_weight_only,
+    Float8StaticActivationFloat8WeightConfig,
+    Float8WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.granularity import (
     PerRow,
     PerTensor,
 )
-from torchao.quantization.quant_api import (
-    float8_static_activation_float8_weight,
-)
 from torchao.quantization.quant_primitives import (
     MappingType,
     _choose_scale_float8,
@@ -119,11 +116,13 @@ def test_fp8_linear_variants(
         )
         mode_map = {
             "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=granularity
+                Float8DynamicActivationFloat8WeightConfig,
+                granularity=granularity,
+                version=1,
             ),
-            "weight-only": float8_weight_only,
+            "weight-only": partial(Float8WeightOnlyConfig, version=1),
             "static": partial(
-                float8_static_activation_float8_weight,
+                Float8StaticActivationFloat8WeightConfig,
                 scale=scale,
                 granularity=granularity,
             ),
         }
@@ -152,7 +151,7 @@
     )
     def test_invalid_granularity(self):
         with pytest.raises(ValueError, match="Invalid granularity specification"):
-            float8_dynamic_activation_float8_weight(granularity="invalid")
+            Float8DynamicActivationFloat8WeightConfig(granularity="invalid")

     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -162,7 +161,9 @@ def test_mismatched_granularity(self):
             ValueError,
             match="Different granularities for activation and weight are not supported",
         ):
-            float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=(PerTensor(), PerRow())
+            )

     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
@@ -172,8 +173,8 @@ class UnsupportedGranularity:
             pass

         with pytest.raises(ValueError, match="Invalid granularity types"):
-            float8_dynamic_activation_float8_weight(
-                granularity=(UnsupportedGranularity(), UnsupportedGranularity())
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=(UnsupportedGranularity(), UnsupportedGranularity()),
             )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -187,7 +188,8 @@ def test_per_row_with_float32(self):
         ):
             model = ToyLinearModel(64, 64).eval().to(torch.float32).to("cuda")
             quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerRow())
+                model,
+                Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
             )

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -201,15 +203,18 @@ def test_serialization(self, mode: str):

         mode_map = {
             "dynamic": partial(
-                float8_dynamic_activation_float8_weight, granularity=PerTensor()
+                Float8DynamicActivationFloat8WeightConfig,
+                granularity=PerTensor(),
+                version=1,
             ),
-            "weight-only": float8_weight_only,
+            "weight-only": partial(Float8WeightOnlyConfig, version=1),
             "static": partial(
-                float8_static_activation_float8_weight,
+                Float8StaticActivationFloat8WeightConfig,
                 scale=torch.tensor(1.0, dtype=torch.float32, device="cuda"),
                 granularity=PerTensor(),
             ),
         }
+
         factory = mode_map[mode]()
         quantize_(model, factory)
@@ -275,7 +280,10 @@ def test_fp8_weight_dimension_warning(self):
             "torchao.quantization.quant_api", level="INFO"
         ) as log_context:
             quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=PerTensor())
+                model,
+                Float8DynamicActivationFloat8WeightConfig(
+                    granularity=PerTensor(), version=1
+                ),
             )
             print(model)
@@ -320,7 +328,8 @@ def test_mm_float8dq_per_row(
         )
         test_linear = copy.deepcopy(ref_linear)
         quantize_(
-            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            test_linear,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1),
         )

         quant_weight = test_linear.weight
@@ -472,7 +481,10 @@ def test_float8_tensor_slicing_basic(self, granularity):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, version=1
+            ),
         )

         weight_impl = model.weight.original_weight_tensor.tensor_impl
@@ -506,7 +518,10 @@ def test_float8_tensor_slicing_per_tensor(self):
         # Create and quantize with per-tensor granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), version=1
+            ),
         )

         original_weight = model.weight
@@ -537,7 +552,8 @@ def test_float8_tensor_slicing_per_row(self):
         # Create and quantize with per-row granularity
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1),
         )

         original_weight = model.weight  # Shape: (32, 64)
@@ -575,7 +591,10 @@ def test_float8_tensor_slicing_edge_cases(self):
         # Create and quantize a model
         model = torch.nn.Linear(64, 32, bias=False).to(device).to(dtype)
         quantize_(
-            model, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor())
+            model,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerTensor(), version=1
+            ),
         )

         original_weight = model.weight
@@ -613,7 +632,9 @@ def test_float8_tensor_slicing_functional_correctness(self, granularity):
         quant_model = copy.deepcopy(ref_model)
         quantize_(
             quant_model,
-            Float8DynamicActivationFloat8WeightConfig(granularity=granularity),
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, version=1
+            ),
         )

         # Create input with batch size that works well with slicing
@@ -720,6 +741,7 @@ def test_preprocess_scale_3d_reshape(self):
         self.assertEqual(result.shape, expected_shape)

     @torch.no_grad()
+    @unittest.skip("test is flaky in CI, will turn on a bit later")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(
         not is_sm_at_least_90(), "Requires GPU with compute capability >= 9.0"
@@ -743,7 +765,13 @@ def test_expected_kernels_on_gpu(self, granularity, torch_compile_mode):
         m = torch.nn.Sequential(
             torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
         )
-        quantize_(m, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))
+        quantize_(
+            m,
+            Float8DynamicActivationFloat8WeightConfig(
+                granularity=granularity, version=1
+            ),
+        )
+
         m = torch.compile(m, mode=torch_compile_mode)
         x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
diff --git a/test/float8/test_base.py b/test/float8/test_base.py
index c19478e02a..c2b2c5488a 100644
--- a/test/float8/test_base.py
+++ b/test/float8/test_base.py
@@ -473,10 +473,10 @@ def test_quantize(self):
         m = nn.Sequential(nn.Linear(32, 32)).cuda()
         m = convert_to_float8_training(m)
         assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
-        from torchao.quantization.quant_api import float8_weight_only, quantize_
+        from torchao.quantization import Float8WeightOnlyConfig, quantize_

-        quantize_(m, float8_weight_only())
-        assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn, (
+        quantize_(m, Float8WeightOnlyConfig())
+        assert m[0].weight.qdata.dtype == torch.float8_e4m3fn, (
             "Post quantization dtype should be torch.float8_e4m3fn"
         )
         with torch.no_grad():
diff --git a/test/integration/test_loading_deprecated_checkpoint.py b/test/integration/test_loading_deprecated_checkpoint.py
new file mode 100644
index 0000000000..b7006eba36
--- /dev/null
+++ b/test/integration/test_loading_deprecated_checkpoint.py
@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import unittest
+import warnings
+
+import torch
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
+
+from torchao.utils import is_sm_at_least_89
+
+_MODEL_NAME_AND_VERSIONS = [
+    ("torchao-testing/opt-125m-float8dq-row-v1-0.13-dev", 1),
+]
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
+class TestLoadingDeprecatedCheckpoint(TestCase):
+    @common_utils.parametrize("model_name_and_version", _MODEL_NAME_AND_VERSIONS)
+    def test_load_model_and_run(self, model_name_and_version):
+        """Test that we print the correct warning message when loading a deprecated checkpoint"""
+        # Load and quantize model
+        model_name, version = model_name_and_version
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            quantized_model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype="bfloat16",
+                device_map="cuda",
+            )
+            assert any(
+                "Stored version is not the same as current default version of the config"
+                in str(w.message)
+                for w in caught_warnings
+            ), "Didn't get expected warning message for version mismatch"
+
+            assert any(
+                "Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated"
+                in str(w.message)
+                for w in caught_warnings
+            ), "Didn't get expected warning message for deprecation"
+        assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
+        assert (
+            quantized_model.config.quantization_config.quant_type.version == version
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        prompt = ("Hello, my name is",)
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
+        ).to("cuda")
+        generated_ids = quantized_model.generate(**inputs, max_new_tokens=128)
+        # make sure it runs
+        _ = tokenizer.batch_decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+
+common_utils.instantiate_parametrized_tests(TestLoadingDeprecatedCheckpoint)
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index e53f1412c2..5372bb280d 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -184,7 +184,6 @@ def test_fp8_linear_variants(
             config = Float8DynamicActivationFloat8WeightConfig(
                 granularity=granularity,
                 kernel_preference=kernel_preference,
-                VERSION=2,
             )
         else:
             assert mode == "weight-only", f"Unsupported mode: {mode}"
@@ -210,9 +209,7 @@ def test_fp8_linear_variants(
         "AssertionError: tensor(False, device='cuda:0') is not true : sqnr: -2.90625, will fix a bit later",
     )
     def test_slice(self, granularity):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
         dummy = torch.nn.Linear(256, 256, bias=False, dtype=dtype, device=device)
@@ -273,9 +270,7 @@ def test_slice(self, granularity):

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_slice_preserves_aliasing(self, granularity):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
@@ -296,9 +291,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
         dtype = torch.bfloat16
         device = "cuda"

-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
         quantize_(l, config)
@@ -335,9 +328,7 @@ def test_slice_and_copy_similar_to_vllm(self, granularity):
     @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     def test_bmm(self):
         # only support per row quantization
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=PerRow(), VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

         class M(torch.nn.Module):
             def __init__(self, weight):
@@ -369,9 +360,7 @@ def forward(self, x):
         ],
     )
     def test_to_device(self, granularity, sizes):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         M, N, K = sizes
         dtype = torch.bfloat16
         for device in self.GPU_DEVICES:
@@ -401,9 +390,7 @@ def test_to_device(self, granularity, sizes):
         ],
     )
     def test_cat(self, granularity, sizes):
-        config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
         M, N, K = sizes
@@ -461,9 +448,7 @@ def test_moe_weight_reshape_ops(self):
         dtype = torch.bfloat16
         device = "cuda"

-        bmm_config = Float8DynamicActivationFloat8WeightConfig(
-            granularity=granularity, VERSION=2
-        )
+        bmm_config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         moe_config = MoEQuantConfig(bmm_config)

         batch_size = 4
diff --git a/torchao/core/config.py b/torchao/core/config.py
index b7e85d6b3d..26e71360e2 100644
--- a/torchao/core/config.py
+++ b/torchao/core/config.py
@@ -8,13 +8,13 @@
 import enum
 import importlib
 import json
-from typing import Any, ClassVar, Dict
+import warnings
+from typing import Any, Dict

 import torch

 __all__ = [
     "AOBaseConfig",
-    "VersionMismatchError",
     "config_from_dict",
     "config_to_dict",
     "ALLOWED_AO_MODULES",
@@ -50,29 +50,21 @@ def _transform(
         """

     """
-    Note: this is not the version of AOBaseConfig, but the default version for all child configs
-    inheriting from AOBaseConfig, and it should be `_DEFAULT_VERSION` and never change
-    this is making sure all configs has a version defined, when they need to bump the version
-    they have to define a class variable VERSION for the child config to overwrite the default VERSION
-    that's defined here. Different child configs will maintain their own VERSION.
+    Note: this is not the version of AOBaseConfig itself, but the default version for instances of
+    all child configs inheriting from AOBaseConfig, and it should be `_DEFAULT_VERSION` and never change.
+    This makes sure all config instances have a version defined; when a child config needs to bump its
+    default version, it has to define an instance variable `version` to overwrite the default version
+    that's defined here. Different child config instances will maintain their own version.
-    default Version of a config, should never change
-    """
-    VERSION: ClassVar[int] = _DEFAULT_VERSION
+    Why is version an instance variable instead of a class variable? An instance-level version is needed
+    because when multiple versions co-exist, we need to be able to load objects with earlier versions;
+    a class-level version is global and can't achieve this goal, so we have to use an instance variable.
+    To overwrite this in subclasses, define `version: int` (with the type annotation).

-
-class VersionMismatchError(Exception):
-    """Raised when trying to deserialize a config with a different version"""
-
-    def __init__(self, type_path, stored_version, current_version):
-        self.type_path = type_path
-        self.stored_version = stored_version
-        self.current_version = current_version
-        message = (
-            f"Version mismatch for {type_path}: "
-            f"stored version {stored_version} != current version {current_version}"
-        )
-        super().__init__(message)
+    Default version of a config, should never change
+    """
+    version: int = _DEFAULT_VERSION


 class ConfigJSONEncoder(json.JSONEncoder):
@@ -84,14 +76,14 @@ def default(self, o):
             data_dict = {}
             # Process each attribute to handle nested objects
             for k, v in o.__dict__.items():
-                if not k.startswith("_") and k != "VERSION":
+                if not k.startswith("_") and k != "version":
                     # Recursively encode each value (important for nested objects)
                     data_dict[k] = self.encode_value(v)

             return {
                 # Only store the class name, not the full module path
                 "_type": o.__class__.__name__,
-                "_version": getattr(o.__class__, "VERSION", 1),
+                "_version": getattr(o, "version", _DEFAULT_VERSION),
                 "_data": data_dict,
             }

@@ -105,7 +97,7 @@ def default(self, o):

             return {
                 "_type": o.__class__.__name__,
-                "_version": getattr(o.__class__, "VERSION", 1),
+                "_version": getattr(o, "version", _DEFAULT_VERSION),
                 "_data": processed_data,
             }

@@ -114,13 +106,13 @@ def default(self, o):
             data_dict = {}
             # Process each field to handle nested objects
             for f in dataclasses.fields(o):
-                if f.name != "VERSION":
+                if f.name != "version":
                     data_dict[f.name] = self.encode_value(getattr(o, f.name))

             return {
                 # Only store the class name for dataclasses too
                 "_type": o.__class__.__name__,
-                "_version": getattr(o.__class__, "VERSION", 1),
+                "_version": getattr(o, "version", _DEFAULT_VERSION),
                 "_data": data_dict,
             }

@@ -218,7 +210,6 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
         An instance of the appropriate AOBaseConfig subclass

     Raises:
-        VersionMismatchError: If the stored version doesn't match the class version
         ValueError: If deserialization fails for other reasons
     """
     if not isinstance(data, dict):
@@ -228,7 +219,7 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
         raise ValueError("Input dictionary missing required '_type' or '_data' fields")

     type_path = data["_type"]
-    stored_version = data.get("_version", 1)
+    stored_version = data.get("_version", _DEFAULT_VERSION)
     obj_data = data["_data"]

     # Handle torch.dtype
@@ -253,10 +244,11 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
             f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}"
         )

-    # Check version - require exact match
-    current_version = getattr(cls, "VERSION", 1)
-    if stored_version != current_version:
-        raise VersionMismatchError(type_path, stored_version, current_version)
+    current_default_version = getattr(cls, "version", _DEFAULT_VERSION)
+    if stored_version != current_default_version:
+        warnings.warn(
+            f"Stored version is not the same as current default version of the config: {stored_version=}, {current_default_version=}, please check the deprecation warning"
+        )

     # Handle the case where obj_data is not a dictionary
     if not isinstance(obj_data, dict):
@@ -271,7 +263,11 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig:
         return obj_data

     # Process nested structures for dictionary obj_data
-    processed_data = {}
+    if stored_version != current_default_version:
+        processed_data = {"version": stored_version}
+    else:
+        processed_data = {}
+
     for key, value in obj_data.items():
         if isinstance(value, dict) and "_type" in value and "_data" in value:
             # Recursively handle nested configs
diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py
index e5ddc9e4bb..4afc5fdfee 100644
--- a/torchao/dtypes/floatx/float8_layout.py
+++ b/torchao/dtypes/floatx/float8_layout.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import warnings
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union

@@ -109,6 +110,9 @@ def __init__(
         transposed: bool,
         _layout: Layout,
     ):
+        warnings.warn(
+            "Models quantized with version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please upgrade torchao and quantize again, or download a newer torchao checkpoint, see https://github.com/pytorch/ao/issues/2649 for more details"
+        )
         self.float8_data = float8_data
         self.scale = scale
         self.transposed = transposed
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 42088a28bc..ce61000105 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1489,7 +1489,7 @@ class Float8WeightOnlyConfig(AOBaseConfig):
     Args:
         weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
-        VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor
+        version (int): the version of the config; version 1 uses AffineQuantizedTensor (which we plan to deprecate/split), version 2 uses Float8Tensor (default)

     Note:
         The actual matmul will be computed in original precision of the weight tensor.
@@ -1497,7 +1497,7 @@ class Float8WeightOnlyConfig(AOBaseConfig):

     weight_dtype: torch.dtype = e4m3_dtype
     set_inductor_config: bool = True
-    VERSION: int = 1
+    version: int = 2


 # for BC
@@ -1505,7 +1505,10 @@ class Float8WeightOnlyConfig(AOBaseConfig):


 def _float8_weight_only_quant_tensor(weight, config):
-    if config.VERSION == 1:
+    if config.version == 1:
+        warnings.warn(
+            "version 1 of Float8WeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
+        )
         from torchao.dtypes import to_affine_quantized_floatx

         block_size = tuple([1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]])
@@ -1517,7 +1520,7 @@ def _float8_weight_only_quant_tensor(weight, config):
             _layout=Float8Layout(mm_config=None),
         )
     else:
-        assert config.VERSION == 2, f"Unexpected version: {config.VERSION}"
+        assert config.version == 2, f"Unexpected version: {config.version}"
         weight_dtype = config.weight_dtype
         new_weight = Float8Tensor.to_float8(
             weight, float8_dtype=weight_dtype, granularity=PerRow()
@@ -1629,7 +1632,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
         activation_value_ub (Optional[float]): the upper bound for activation value for calculating scale
         kernel_preference (KernelPreference): kernel preference for ops like matmul, grouped matmul etc. By default (KernelPreference.AUTO) it will be chosen for user based on hardware or other information, this only needs to be set in weight
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
-        VERSION (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor
+        version (int): the version of the config; version 1 uses AffineQuantizedTensor (which we plan to deprecate/split), version 2 uses Float8Tensor (default)
     """

@@ -1641,7 +1644,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     activation_value_ub: Optional[float] = None
     kernel_preference: KernelPreference = KernelPreference.AUTO
     set_inductor_config: bool = True
-    VERSION: int = 1
+    version: int = 2

     def __post_init__(self):
         if self.mm_config is None:
@@ -1679,7 +1682,11 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
             "PerRow quantization only works for bfloat16 precision input weight"
         )

-    if config.VERSION == 1:
+    if config.version == 1:
+        warnings.warn(
+            "version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
+        )
+
         block_size = get_block_size(weight.shape[-2:], weight_granularity)
         if weight.dim() == 3:
             block_size = tuple([1] + list(block_size))
@@ -1701,7 +1708,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
             quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
         )
     else:
-        assert config.VERSION == 2, f"Unexpected version: {config.version}"
+        assert config.version == 2, f"Unexpected version: {config.version}"
         act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
            activation_dtype,
            activation_granularity,
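The end-to-end effect of the patch, as a minimal sketch (not part of the diff) assuming a torchao build with these changes applied; the `old_config`, `data`, and `restored` names are illustrative only:

```python
import warnings

from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow

# Configs now carry an instance-level `version`; pin it to 1 to mimic an old checkpoint.
old_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), version=1)
data = config_to_dict(old_config)  # serialized dict stores "_version": 1

# Reconstruction now warns (instead of raising the removed VersionMismatchError) when
# the stored version differs from the current default (2), and it preserves the stored
# version on the rebuilt instance so old checkpoints keep loading.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    restored = config_from_dict(data)

assert restored.version == 1
assert any("Stored version is not the same" in str(w.message) for w in caught)
```

Because `config_from_dict` injects `{"version": stored_version}` into the constructor kwargs on mismatch, code paths such as `_float8_dynamic_activation_float8_weight_quantize_tensor` can still dispatch on `config.version == 1` for deprecated checkpoints while new configs default to version 2.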