@@ -116,7 +116,6 @@ def quantize(  # noqa C901
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])

-        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import (
             Int8DynamicActivationIntxWeightConfig,
@@ -136,7 +135,7 @@ def quantize(  # noqa C901
                     PerAxis(0) if group_size == 0 else PerGroup(group_size)
                 ),
                 weight_mapping_type=MappingType.SYMMETRIC,
-                layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
+                intx_packing_format="opaque_torchao_auto",
             ),
         )
         model = unwrap_tensor_subclass(model)
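This hunk swaps torchao's tensor-subclass layout argument for the newer `intx_packing_format` string. A minimal sketch of the resulting call, assuming a torchao build whose `Int8DynamicActivationIntxWeightConfig` accepts `intx_packing_format` (as this diff targets); the toy module, bit width, and group size are illustrative, not from the PR:

    import torch
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import (
        Int8DynamicActivationIntxWeightConfig,
        MappingType,
        quantize_,
    )
    from torchao.utils import unwrap_tensor_subclass

    # Stand-in for the model being exported.
    model = torch.nn.Sequential(torch.nn.Linear(256, 128))

    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,                    # low-bit weight dtype
            weight_granularity=PerGroup(32),            # group-wise weight scales
            weight_mapping_type=MappingType.SYMMETRIC,
            intx_packing_format="opaque_torchao_auto",  # replaces layout=PackedLinearInt8DynamicActivationIntxWeightLayout()
        ),
    )
    # Convert subclass weights back to plain tensors before export.
    model = unwrap_tensor_subclass(model)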
@@ -148,10 +147,21 @@ def quantize(  # noqa C901
             # TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
             group_size = 128

-        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+        from torchao.quantization import (
+            Int8DynamicActivationIntxWeightConfig,
+            quantize_,
+        )
+        from torchao.quantization.granularity import PerGroup
         from torchao.utils import unwrap_tensor_subclass

-        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
+            ),
+        )
+
         model = unwrap_tensor_subclass(model)

         # TODO: deal with checkpoint / computation dtype decoupling.
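The 8da4w branch drops the deprecated `int8_dynamic_activation_int4_weight` helper in favor of the same config class used above. Under this diff, the old one-liner and the new config-based call are meant to be equivalent; a short sketch of the mapping (the module is a stand-in):

    import torch
    from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
    from torchao.quantization.granularity import PerGroup
    from torchao.utils import unwrap_tensor_subclass

    group_size = 128  # the diff's default when none is supplied
    model = torch.nn.Sequential(torch.nn.Linear(512, 512))

    # Before: quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
    # After: int4 weights with per-group scales; 8-bit dynamic activations come with the config.
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(group_size),
        ),
    )
    model = unwrap_tensor_subclass(model)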
@@ -751,9 +761,9 @@ def get_quant_embedding_transform(
     dtype_override: Optional[DType] = None,
 ):
     if embedding_quantize.startswith("torchao:"):
-        from torchao.experimental.quant_api import (
+        from torchao.prototype.quantization.embedding.api import (
             EmbeddingQuantizer,
-            SharedEmbeddingQuantizer,
+            TiedEmbeddingQuantizer,
         )
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import MappingType
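The embedding quantizers move out of `torchao.experimental.quant_api` into the prototype namespace; their call signatures are unchanged in this hunk. A standalone sketch using the new import path, with kwargs mirroring the call sites below (the embedding shape and granularity are illustrative):

    import torch
    from torchao.prototype.quantization.embedding.api import EmbeddingQuantizer
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import MappingType

    model = torch.nn.Sequential(torch.nn.Embedding(1000, 64))

    EmbeddingQuantizer(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        mapping_type=MappingType.SYMMETRIC,
        use_fallback=False,  # same flag as the call site below
    ).quantize(model)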
@@ -787,7 +797,7 @@ def _torchao_embedding_quantizer(model):
                 use_fallback=False,
             ).quantize(model)
         else:
-            SharedEmbeddingQuantizer(
+            TiedEmbeddingQuantizer(
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 mapping_type=mapping_type,
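`TiedEmbeddingQuantizer` is the renamed `SharedEmbeddingQuantizer`; the name follows the usual "weight tying" terminology for an embedding table that shares its weights with the output projection, so both copies are quantized consistently. A sketch on a tiny tied model, assuming the quantizer exposes the same `.quantize(model)` entry point as `EmbeddingQuantizer` above:

    import torch
    from torchao.prototype.quantization.embedding.api import TiedEmbeddingQuantizer
    from torchao.quantization.granularity import PerAxis
    from torchao.quantization.quant_api import MappingType

    class TinyTied(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.tok_embeddings = torch.nn.Embedding(1000, 64)
            self.output = torch.nn.Linear(64, 1000, bias=False)
            self.output.weight = self.tok_embeddings.weight  # tie embedding and head

    model = TinyTied()
    TiedEmbeddingQuantizer(
        weight_dtype=torch.int4,
        granularity=PerAxis(0),  # per-channel along the embedding rows
        mapping_type=MappingType.SYMMETRIC,
    ).quantize(model)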