Remove old change_linear_weights_to_* APIs #2721

Open: wants to merge 11 commits into base: gh/andrewor14/21/base

Changes from 4 commits
4 changes: 3 additions & 1 deletion benchmarks/benchmark_aq.py
@@ -75,11 +75,13 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
"""
from torchao.quantization.quant_api import (
_get_subclass_inserter,
_in_features_greater_than_16,
_is_linear,
)
from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

def _in_features_greater_than_16(mod, *args):
return hasattr(mod, "in_features") and mod.in_features > 16

if filter_fn is None:
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
*args
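
# Editorial sketch, not part of this diff: exercising the reference helper
# defined above on a hypothetical toy model. Only the first Linear
# (in_features=32) passes the in_features > 16 filter; the second
# (in_features=16) stays in float.
import torch
import torch.nn as nn

toy_model = nn.Sequential(nn.Linear(32, 16), nn.Linear(16, 8)).eval()
if torch.cuda.is_available():
    toy_model = toy_model.cuda()  # the benchmark itself targets CUDA
_ref_change_linear_weights_to_int8_dqtensors(toy_model)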
6 changes: 0 additions & 6 deletions test/integration/test_integration.py
@@ -40,7 +40,6 @@
from torchao.quantization.quant_api import (
Float8DynamicActivationFloat8WeightConfig,
_replace_with_custom_fn_if_matches_filter,
change_linear_weights_to_int8_dqtensors,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
@@ -1829,11 +1828,6 @@ class TestAOTI(unittest.TestCase):
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
def test_aoti(self, api, test_device, test_dtype):
if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
self.skipTest(
f"{api} in {test_device} is not support for aoti compilation yet"
)

if (
test_device == "cuda"
and torch.cuda.is_available()
26 changes: 0 additions & 26 deletions test/quantization/test_quant_api.py
@@ -146,32 +146,6 @@ def forward(self, x):
return x


def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for int8 dynamic quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import (
_get_subclass_inserter,
_in_features_greater_than_16,
_is_linear,
)
from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

if filter_fn is None:
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
*args
)

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(
Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
),
filter_fn,
)


def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
"""
15 changes: 0 additions & 15 deletions torchao/quantization/README.md
@@ -125,44 +125,29 @@ be applied individually. While there are a large variety of quantization apis, t
#### A16W4 WeightOnly Quantization

```python
# for torch 2.4+
from torchao.quantization import quantize_, Int4WeightOnlyConfig
group_size = 32

# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization, which is expected to improve accuracy, via the
# use_hqq flag for `Int4WeightOnlyConfig` quantization
use_hqq = False
quantize_(model, Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq))

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
change_linear_weights_to_int4_woqtensors(model)
```

Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.

#### A16W8 Int8 WeightOnly Quantization

```python
# for torch 2.4+
from torchao.quantization import quantize_, Int8WeightOnlyConfig
quantize_(model, Int8WeightOnlyConfig())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
change_linear_weights_to_int8_woqtensors(model)
```

#### A8W8 Int8 Dynamic Quantization

```python
# for torch 2.4+
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig
quantize_(model, Int8DynamicActivationInt8WeightConfig())

# for torch 2.2.2 and 2.3
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(model)
```
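
A quick way to sanity-check any of the calls above is to compare the quantized model's output against a float reference. The following is a hedged sketch using a hypothetical toy model, not an excerpt from the README:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig

torch.manual_seed(0)
example_input = torch.randn(8, 256)
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 64)).eval()
ref_out = model(example_input)

# quantize in place, then rerun the same input
quantize_(model, Int8DynamicActivationInt8WeightConfig())
quant_out = model(example_input)

# signal-to-quantization-noise ratio in dB; higher means closer to the float reference
sqnr = (10 * torch.log10(ref_out.pow(2).mean() / (ref_out - quant_out).pow(2).mean())).item()
print(f"SQNR: {sqnr:.1f} dB")
```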

### A16W8 Float8 WeightOnly Quantization
106 changes: 0 additions & 106 deletions torchao/quantization/quant_api.py
@@ -120,9 +120,6 @@
ZeroPointDomain,
)
from .subclass import (
Contributor: this file can be deleted as well?

Contributor Author: Looks like classes from this file are still being used in autoquant. I don't think we can delete or even deprecate them yet...

Contributor Author: Ok, created an issue here instead: #2745. Let's do them separately because there are autoquant/benchmark dependencies.
Int4WeightOnlyQuantizedLinearWeight,
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
)
from .unified import Quantizer, TwoStepQuantizer
@@ -172,109 +169,6 @@
}


######
# TO BE DEPRECATED START
######
def _in_features_greater_than_16(mod, *args):
return hasattr(mod, "in_features") and mod.in_features > 16


# TODO: delete
def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
"""
Converts all linear weight tensors to the `Int8DynamicallyQuantizedLinearWeight`
Tensor subclass, effectively applying the same form of quantization
as apply_dynamic_quant while not modifying the linear modules.
"""
raise ImportError(
"This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs"
)

if filter_fn is None:
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
*args
)

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(
Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
),
filter_fn,
)
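
# Editorial sketch, not part of this diff: the deleted function above maps onto
# the config-based API. This assumes `quantize_` still accepts a `filter_fn`
# callable of the form (module, fully_qualified_name) -> bool, mirroring the
# old filter, and uses a hypothetical toy model purely for illustration.
import torch.nn as nn
from torchao.quantization import quantize_, Int8DynamicActivationInt8WeightConfig

_toy = nn.Sequential(nn.Linear(32, 8), nn.Linear(8, 8)).eval()

def _wide_linear_only(mod, fqn):
    # same predicate as _is_linear combined with _in_features_greater_than_16
    return isinstance(mod, nn.Linear) and mod.in_features > 16

quantize_(_toy, Int8DynamicActivationInt8WeightConfig(), filter_fn=_wide_linear_only)
# only the first Linear (in_features=32) is quantized; the 8-wide one is skipped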


# TODO: delete
def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
"""
Converts all linear weight tensors to the
`Int8WeightOnlyQuantizedLinearWeight` tensor subclass,
effectively applying the same form of quantization
as apply_weight_only_int8_quant while not modifying the linear modules.
"""
raise ImportError(
"This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs"
)

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(
Int8WeightOnlyQuantizedLinearWeight, enable_parametrization=False, **kwargs
),
_is_linear if filter_fn is None else filter_fn,
)
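
# Editorial sketch, not part of this diff: the weight-only variant above maps
# onto Int8WeightOnlyConfig. With no filter_fn, quantize_ targets every
# nn.Linear, which mirrors the old default of _is_linear.
import torch.nn as nn
from torchao.quantization import quantize_, Int8WeightOnlyConfig

_toy_wo = nn.Sequential(nn.Linear(16, 16)).eval()  # hypothetical model
quantize_(_toy_wo, Int8WeightOnlyConfig())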


# TODO: delete
def change_linear_weights_to_int4_woqtensors(
model,
groupsize=128,
inner_k_tiles=8,
filter_fn=None,
zero_point_domain=ZeroPointDomain.FLOAT,
preserve_zero=False,
):
"""
Converts all linear weight tensors to the
`Int4WeightOnlyQuantizedLinearWeight` tensor subclass,
effectively applying the same form of quantization
as apply_dynamic_quant while not modifying the linear modules.
Args:
`groupsize`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained, choices are [256, 128, 64, 32]
`inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
`filter_fn`: function that takes a nn.Module instance and fully qualified name of the module, \
returns True if we want to run `config` on
`zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, \
ZeroPointDomain.INT, ZeroPointDomain.NONE]
`preserve_zero`: whether to preserve zero, default is False
"""
raise ImportError(
"This API is deprecated for pytorch 2.4+, please checkout quantization/README.md for most up to date APIs"
)

if filter_fn is None:
filter_fn = _is_linear

_replace_with_custom_fn_if_matches_filter(
model,
_get_subclass_inserter(
Int4WeightOnlyQuantizedLinearWeight,
enable_parametrization=False,
groupsize=groupsize,
inner_k_tiles=inner_k_tiles,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
),
filter_fn,
)
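
# Editorial sketch, not part of this diff: for the int4 case, the group size
# carries over directly to Int4WeightOnlyConfig (the parameter documented in
# the README above). How inner_k_tiles, zero_point_domain and preserve_zero map
# onto the config's layout options is not shown here and should be checked
# against the current API. Int4 weight-only kernels generally expect a CUDA
# bfloat16 model, so this hypothetical example is gated on CUDA availability.
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Int4WeightOnlyConfig

if torch.cuda.is_available():
    _toy_int4 = nn.Sequential(nn.Linear(256, 256)).to(torch.bfloat16).cuda().eval()
    quantize_(_toy_int4, Int4WeightOnlyConfig(group_size=128))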


########
# TO BE DEPRECATED END
########


def _replace_with_custom_fn_if_matches_filter(
model,
replacement_fn,