 
 from executorch.extension.llm.export.builder import DType
 
-from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     MappingType,
     quantize_,
 )
+from torchao.utils import unwrap_tensor_subclass
 
 
 try:
@@ -125,6 +125,8 @@ def quantize( # noqa C901
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
 
+        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+
         with torch.no_grad():
             # Computation dtype is fixed to fp32 in the implementation of quantize_, so
             # no way to decouple checkpoint and computation dtype.
@@ -139,6 +141,7 @@ def quantize( # noqa C901
                     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
+            model = unwrap_tensor_subclass(model)
         if verbose:
             print("quantized model:", model)
         return model
@@ -157,6 +160,7 @@ def quantize( # noqa C901
                 weight_mapping_type=MappingType.SYMMETRIC,
             ),
         )
+        model = unwrap_tensor_subclass(model)
         # TODO: deal with checkpoint / computation dtype decoupling.
         if verbose:
             print("quantized model:", model)
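
Both linear-quantization paths above now pair `quantize_` with `unwrap_tensor_subclass`, which replaces torchao's tensor subclasses with plain tensors (managed via parametrizations) so the quantized model can be traced for export. A minimal sketch of the pattern, assuming a torchao build with the low-bit kernels available; the dtype and granularity values are illustrative, not the bitwidth the surrounding code parses:

```python
import torch
import torch.nn as nn

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass

# Stand-in for the model being exported; quantize_ rewrites nn.Linear in place.
model = nn.Sequential(nn.Linear(256, 256)).eval()

with torch.no_grad():
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,  # illustrative; the real code derives this
            weight_granularity=PerGroup(32),  # illustrative group size
            weight_mapping_type=MappingType.SYMMETRIC,
            layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
        ),
    )

# The line this diff adds after each quantize_ call: swap tensor subclasses
# for plain tensors so torch.export can trace the quantized model.
model = unwrap_tensor_subclass(model)
```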
@@ -798,6 +802,7 @@ def _embedding_quantizer(model):
             ),
             lambda m, fqn: isinstance(m, nn.Embedding),
         )
+        model = unwrap_tensor_subclass(model)
         return model
 
     return _embedding_quantizer
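
The embedding quantizer follows the same recipe, with a filter function restricting `quantize_` to `nn.Embedding` modules. A rough sketch under the same caveats; the `weight_dtype` and `granularity` values below are assumptions, not the settings this file uses:

```python
import torch
import torch.nn as nn

from torchao.quantization.granularity import PerAxis
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Embedding(1000, 64)).eval()

# Weight-only intx quantization; the lambda limits the transform to embeddings.
quantize_(
    model,
    IntxWeightOnlyConfig(
        weight_dtype=torch.int8,  # assumed bitwidth
        granularity=PerAxis(0),  # assumed per-row granularity
    ),
    lambda m, fqn: isinstance(m, nn.Embedding),
)

# Same export-compatibility step as in the linear paths.
model = unwrap_tensor_subclass(model)
```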