@@ -16,14 +16,14 @@
 
 from executorch.extension.llm.export.builder import DType
 
-from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
     IntxWeightOnlyConfig,
     MappingType,
     quantize_,
 )
+from torchao.utils import unwrap_tensor_subclass
 
 
 try:
@@ -125,6 +125,8 @@ def quantize(  # noqa C901
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
 
+        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+
         with torch.no_grad():
             # Computation dtype is fixed to fp32 in the implementation of quantize_, so
             # no way to decouple checkpoint and computation dtype.
@@ -139,6 +141,7 @@ def quantize(  # noqa C901
                     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
+            model = unwrap_tensor_subclass(model)
         if verbose:
             print("quantized model:", model)
         return model
@@ -157,6 +160,7 @@ def quantize(  # noqa C901
                 weight_mapping_type=MappingType.SYMMETRIC,
             ),
         )
+        model = unwrap_tensor_subclass(model)
         # TODO: deal with checkpoint / computation dtype decoupling.
         if verbose:
             print("quantized model:", model)
@@ -798,6 +802,7 @@ def _embedding_quantizer(model):
             ),
             lambda m, fqn: isinstance(m, nn.Embedding),
         )
+        model = unwrap_tensor_subclass(model)
         return model
 
     return _embedding_quantizer
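
For reference, a minimal sketch of the quantize-then-unwrap pattern this diff applies in each branch: quantize_ rewrites the module's weights in place as tensor-subclass instances, and unwrap_tensor_subclass (from torchao.utils, per the new import above) flattens them back into plain tensors so the module stays traceable for export. The toy module, weight_dtype, and PerGroup(32) granularity below are illustrative assumptions; only weight_mapping_type, the config names, and the two torchao entry points appear in the diff itself.

import torch
import torch.nn as nn

from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass

# Placeholder module standing in for the real model being quantized.
model = nn.Sequential(nn.Linear(256, 256))

# quantize_ mutates the model in place, replacing each Linear weight with a
# quantized tensor-subclass wrapper.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,  # assumed kwarg; bit width is illustrative
        weight_granularity=PerGroup(32),  # assumed kwarg; group size is illustrative
        weight_mapping_type=MappingType.SYMMETRIC,
    ),
)

# Flatten the subclass parameters back into plain tensors so torch.export /
# ExecuTorch lowering can trace the module.
model = unwrap_tensor_subclass(model)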