Skip to content

Commit 15ad063

Browse files
committed
Bump torchao pin and use v2 torchao tensors
1 parent 9fa1b27 commit 15ad063

File tree

4 files changed

+31
-12
lines changed

4 files changed

+31
-12
lines changed

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2680,14 +2680,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     def apply_8da4w_quantization(self):
         """Apply TorchAO 8da4w quantization (int8 dynamic activation + int4 weight)."""
         from torchao.quantization import (
-            int8_dynamic_activation_int4_weight,
+            Int8DynamicActivationIntxWeightConfig,
             quantize_,
         )
+        from torchao.quantization.granularity import PerGroup
         from torchao.utils import unwrap_tensor_subclass

         quantize_(
             self,
-            int8_dynamic_activation_int4_weight(group_size=self.group_size),
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4, granularity=PerGroup(self.group_size)
+            ),
         )
         unwrap_tensor_subclass(self)
         return self

backends/xnnpack/test/ops/test_linear.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@
 from torch.export.graph_signature import ExportGraphSignature, InputKind

 try:
+    from torchao.quantization.granularity import PerGroup
     from torchao.quantization.quant_api import (
-        int8_dynamic_activation_int4_weight,
+        Int8DynamicActivationIntxWeightConfig,
         quantize_,
     )
     from torchao.utils import unwrap_tensor_subclass
@@ -391,7 +392,12 @@ def _test_groupwise_dq_linear(
         """
         Helper function to test groupwise dynamic quantized linear op with different configurations.
         """
-        quantize_(mod, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(
+            mod,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
+            ),
+        )
         unwrap_tensor_subclass(mod)
         DynamicallyQuantizedPartitioner = XnnpackPartitioner(
             config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,

examples/models/llama/source_transformation/quantize.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def quantize(  # noqa C901
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])

-        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,
@@ -136,7 +135,7 @@ def quantize(  # noqa C901
                     PerAxis(0) if group_size == 0 else PerGroup(group_size)
                 ),
                 weight_mapping_type=MappingType.SYMMETRIC,
-                layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+                intx_packing_format="opaque_torchao_auto",
             ),
         )
         model = unwrap_tensor_subclass(model)
@@ -148,10 +147,21 @@ def quantize(  # noqa C901
         # TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
         group_size = 128

-        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+        from torchao.quantization import (
+            Int8DynamicActivationIntxWeightConfig,
+            quantize_,
+        )
+        from torchao.quantization.granularity import PerGroup
         from torchao.utils import unwrap_tensor_subclass

-        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
+            ),
+        )
+
         model = unwrap_tensor_subclass(model)

         # TODO: deal with checkpoint / computation dtype decoupling.
@@ -751,9 +761,9 @@ def get_quant_embedding_transform(
     dtype_override: Optional[DType] = None,
 ):
     if embedding_quantize.startswith("torchao:"):
-        from torchao.experimental.quant_api import (
+        from torchao.prototype.quantization.embedding.api import (
             EmbeddingQuantizer,
-            SharedEmbeddingQuantizer,
+            TiedEmbeddingQuantizer,
         )
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import MappingType
@@ -787,7 +797,7 @@ def _torchao_embedding_quantizer(model):
                 use_fallback=False,
             ).quantize(model)
         else:
-            SharedEmbeddingQuantizer(
+            TiedEmbeddingQuantizer(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 mapping_type=mapping_type,

third-party/ao

Submodule ao updated 103 files

0 commit comments

Comments (0)