
Commit a4b7de0

Bump torchao pin and use v2 torchao tensors (#14171)
This PR bumps the torchao pin in ExecuTorch and adjusts the code to rely less on deprecated features. In particular:

* The torchao/experimental folder is being deprecated, so the embedding / tied embedding quantizers are switched to their new home.
* v1 tensors based on AffineQuantizedTensor + QDQLayout are being deprecated, so ExecuTorch is switched to v2 tensors. See pytorch/ao#2967.
1 parent c638851 · commit a4b7de0
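In practice, the v1 → v2 migration applied across the call sites below has this shape. A minimal sketch, not the exact diff; the placeholder module and group size are illustrative:

```python
import torch
import torch.nn as nn

from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup

model = nn.Sequential(nn.Linear(64, 32))  # placeholder module
group_size = 32

# v1 (deprecated): quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
# v2: an explicit config naming the weight dtype and granularity.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size),
    ),
)
```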

File tree

5 files changed (+36, −14 lines):

* backends/apple/coreml/test/test_coreml_recipes.py
* backends/vulkan/test/test_vulkan_delegate.py
* backends/xnnpack/test/ops/test_linear.py
* examples/models/llama/source_transformation/quantize.py
* third-party/ao
backends/apple/coreml/test/test_coreml_recipes.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -3,6 +3,7 @@
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.
 
 
+import copy
 import unittest
 
 import coremltools as ct
@@ -152,8 +153,9 @@ def forward(self, x):
         # Test with different group sizes
         for group_size in [8, 16, 32]:
             with self.subTest(group_size=group_size):
+                model_to_export = copy.deepcopy(model)
                 session = export(
-                    model=model,
+                    model=model_to_export,
                     example_inputs=example_inputs,
                     export_recipe=ExportRecipe.get_recipe(
                         CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
@@ -219,8 +221,9 @@ def forward(self, x):
         # Test with different group sizes
         for group_size in [16, 32, 64]:
             with self.subTest(group_size=group_size):
+                model_to_export = copy.deepcopy(model)
                 session = export(
-                    model=model,
+                    model=model_to_export,
                     example_inputs=example_inputs,
                     export_recipe=ExportRecipe.get_recipe(
                         CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
```
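A note on the new `copy.deepcopy` calls: the presumed rationale is that these recipes quantize the module in place during export, so exporting the same instance for several group sizes would re-quantize an already-quantized model. A sketch of the pattern, with a stand-in for the test's model:

```python
import copy

import torch.nn as nn

model = nn.Sequential(nn.Linear(64, 32))  # stand-in for the test model

for group_size in [8, 16, 32]:
    # Presumed rationale: export/quantization mutates the module in
    # place, so each subtest gets its own pristine copy.
    model_to_export = copy.deepcopy(model)
```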

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -2680,14 +2680,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     def apply_8da4w_quantization(self):
         """Apply TorchAO 8da4w quantization (int8 dynamic activation + int4 weight)."""
         from torchao.quantization import (
-            int8_dynamic_activation_int4_weight,
+            Int8DynamicActivationIntxWeightConfig,
             quantize_,
         )
+        from torchao.quantization.granularity import PerGroup
         from torchao.utils import unwrap_tensor_subclass
 
         quantize_(
             self,
-            int8_dynamic_activation_int4_weight(group_size=self.group_size),
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4, granularity=PerGroup(self.group_size)
+            ),
         )
         unwrap_tensor_subclass(self)
         return self
```
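The "8da4w" scheme named in the docstring is int8 dynamic activation with int4 weights, quantized groupwise. A minimal standalone sketch of the new-style call, mirroring the kwargs in the diff above (the `nn.Linear` and group size are placeholders; the group size must divide the layer's `in_features`):

```python
import torch
import torch.nn as nn

from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup
from torchao.utils import unwrap_tensor_subclass

module = nn.Sequential(nn.Linear(128, 64))  # placeholder; 32 divides 128
quantize_(
    module,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4, granularity=PerGroup(32)
    ),
)
unwrap_tensor_subclass(module)  # flatten tensor subclasses, as the test does before lowering
```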

backends/xnnpack/test/ops/test_linear.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -34,8 +34,9 @@
 from torch.export.graph_signature import ExportGraphSignature, InputKind
 
 try:
+    from torchao.quantization.granularity import PerGroup
     from torchao.quantization.quant_api import (
-        int8_dynamic_activation_int4_weight,
+        Int8DynamicActivationIntxWeightConfig,
         quantize_,
     )
     from torchao.utils import unwrap_tensor_subclass
@@ -391,7 +392,12 @@ def _test_groupwise_dq_linear(
         """
         Helper function to test groupwise dynamic quantized linear op with different configurations.
         """
-        quantize_(mod, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(
+            mod,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
+            ),
+        )
         unwrap_tensor_subclass(mod)
         DynamicallyQuantizedPartitioner = XnnpackPartitioner(
             config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
```
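Note that the new imports stay inside the file's existing `try:` block, so environments without torchao keep skipping these tests instead of failing at import time. A sketch of that guard; the `torchao_installed` flag name is hypothetical, not from this file:

```python
try:
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import (
        Int8DynamicActivationIntxWeightConfig,
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass

    torchao_installed = True  # hypothetical flag consulted by the tests
except ImportError:
    torchao_installed = False
```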

examples/models/llama/source_transformation/quantize.py

Lines changed: 17 additions & 7 deletions
```diff
@@ -116,7 +116,6 @@ def quantize(  # noqa C901
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
 
-        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,
@@ -136,7 +135,7 @@ def quantize(  # noqa C901
                     PerAxis(0) if group_size == 0 else PerGroup(group_size)
                 ),
                 weight_mapping_type=MappingType.SYMMETRIC,
-                layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+                intx_packing_format="opaque_torchao_auto",
             ),
         )
         model = unwrap_tensor_subclass(model)
@@ -148,10 +147,21 @@ def quantize(  # noqa C901
         # TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
         group_size = 128
 
-        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+        from torchao.quantization import (
+            Int8DynamicActivationIntxWeightConfig,
+            quantize_,
+        )
+        from torchao.quantization.granularity import PerGroup
         from torchao.utils import unwrap_tensor_subclass
 
-        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
+            ),
+        )
+
         model = unwrap_tensor_subclass(model)
 
         # TODO: deal with checkpoint / computation dtype decoupling.
@@ -744,9 +754,9 @@ def get_quant_embedding_transform(
     dtype_override: Optional[DType] = None,
 ):
     if embedding_quantize.startswith("torchao:"):
-        from torchao.experimental.quant_api import (
+        from torchao.prototype.quantization.embedding.api import (
             EmbeddingQuantizer,
-            SharedEmbeddingQuantizer,
+            TiedEmbeddingQuantizer,
         )
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import MappingType
@@ -780,7 +790,7 @@ def _torchao_embedding_quantizer(model):
             use_fallback=False,
         ).quantize(model)
     else:
-        SharedEmbeddingQuantizer(
+        TiedEmbeddingQuantizer(
             weight_dtype=weight_dtype,
             granularity=granularity,
             mapping_type=mapping_type,
```
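Two migrations happen in this file: the `layout=` object is replaced by the string-valued `intx_packing_format="opaque_torchao_auto"`, and the embedding quantizers move from `torchao.experimental` to their new prototype home. A hedged sketch of the new import path; the constructor kwargs follow the call sites above, while the toy model and `PerGroup(32)` are placeholders:

```python
import torch
import torch.nn as nn

from torchao.prototype.quantization.embedding.api import EmbeddingQuantizer
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import MappingType

model = nn.Sequential(nn.Embedding(1000, 64))  # placeholder model

# Quantize embedding weights to int4 with groupwise symmetric scales.
EmbeddingQuantizer(
    weight_dtype=torch.int4,
    granularity=PerGroup(32),
    mapping_type=MappingType.SYMMETRIC,
    use_fallback=False,
).quantize(model)
```

Per the else branch above, `TiedEmbeddingQuantizer` (the renamed `SharedEmbeddingQuantizer`) takes the same `weight_dtype` / `granularity` / `mapping_type` kwargs and covers models whose input and output embeddings share weights.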

third-party/ao

Submodule ao updated 103 files
