Commit ccbddc5

up

1 parent dfff847 commit ccbddc5

1 file changed: +6 −1 lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 6 additions & 1 deletion
@@ -16,14 +16,14 @@
 
 from executorch.extension.llm.export.builder import DType
 
-from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     MappingType,
     quantize_,
 )
+from torchao.utils import unwrap_tensor_subclass
 
 
 try:
@@ -125,6 +125,8 @@ def quantize(  # noqa C901
     assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
     bitwidth = int(matches[0][0])
 
+    from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+
     with torch.no_grad():
         # Computation dtype is fixed to fp32 in the implementation of quantize_, so
         # no way to decouple checkpoint and computation dtype.
@@ -139,6 +141,7 @@ def quantize(  # noqa C901
             layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
         ),
     )
+    model = unwrap_tensor_subclass(model)
     if verbose:
         print("quantized model:", model)
     return model
@@ -157,6 +160,7 @@ def quantize(  # noqa C901
             weight_mapping_type=MappingType.SYMMETRIC,
         ),
     )
+    model = unwrap_tensor_subclass(model)
     # TODO: deal with checkpoint / computation dtype decoupling.
     if verbose:
         print("quantized model:", model)
@@ -798,6 +802,7 @@ def _embedding_quantizer(model):
             ),
             lambda m, fqn: isinstance(m, nn.Embedding),
         )
+        model = unwrap_tensor_subclass(model)
         return model
 
     return _embedding_quantizer
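
For context, the change wires torchao's unwrap_tensor_subclass into each quantization path: after quantize_ rewrites the linear (or embedding) weights as tensor subclasses, unwrapping converts them back into plain tensors via module parametrization, which torch.export-based lowering can trace. Below is a minimal sketch of the same pattern on a toy model; the model shape, group size, and int4 weight dtype are illustrative assumptions, not executorch's actual configuration.

import torch
import torch.nn as nn

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass

# Toy stand-in for the llama model (illustrative only).
model = nn.Sequential(nn.Linear(256, 256))

# Quantize linear layers in place: int8 dynamic activations,
# low-bit (intx) weights with per-group granularity.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,          # assumed bit width
        weight_granularity=PerGroup(32),  # assumed group size
    ),
)

# The quantized weights are now tensor subclasses. Unwrap them into
# plain tensors (via parametrization) so torch.export can trace the
# model during lowering.
model = unwrap_tensor_subclass(model)

Note the ordering: unwrapping happens after each quantize_ call but before export, which is why the commit adds it to all three code paths (the two linear-quantization branches and the embedding quantizer).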
