@@ -595,19 +595,16 @@ def __init__(
595595
596596 @torch .no_grad ()
597597 def create_quantized_state_dict (self , packed = False ) -> Dict :
598+ from torchao .quantization .granularity import PerAxis , PerGroup
599+ from torchao .quantization .quant_api import (
600+ IntxWeightOnlyConfig ,
601+ MappingType ,
602+ quantize_ ,
603+ )
604+
598605 cur_state_dict = self .mod .state_dict ()
599606
600- if self .bitwidth == 2 :
601- range_min = - 2
602- range_max = 1
603- elif self .bitwidth == 4 :
604- range_min = - 8
605- range_max = 7
606- elif self .bitwidth == 8 :
607- range_min = - 128
608- range_max = 127
609- else :
610- raise ValueError (f"Unsupported bitwidth { self .bitwidth } " )
607+ assert self .bitwidth in [2 , 4 , 8 ], f"Unsupported bitwidth { self .bitwidth } "
611608
612609 for fqn , mod in self .mod .named_modules ():
613610 if isinstance (mod , nn .Embedding ):
@@ -619,18 +616,22 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
619616 print (
620617 f"quantize { fqn , mod } with group_size { self .group_size } , bitwidth { self .bitwidth } "
621618 )
622- weight , scales , _ = dynamically_quantize_per_channel (
623- (
624- mod .weight .to (dtype = self .precision )
625- if self .precision
626- else mod .weight
619+ tmp_model = nn .Embedding (mod .weight .shape [0 ], mod .weight .shape [1 ])
620+ if self .precision :
621+ tmp_model = tmp_model .to (dtype = self .precision )
622+ tmp_model .weight = nn .Parameter (mod .weight )
623+ config = IntxWeightOnlyConfig (
624+ weight_dtype = getattr (torch , f"int{ self .bitwidth } " ),
625+ granularity = (
626+ PerAxis (0 )
627+ if (self .group_size is None or self .group_size == 0 )
628+ else PerGroup (self .group_size )
627629 ),
628- range_min ,
629- range_max ,
630- torch .int8 ,
631- self .group_size ,
632- scales_dtype = mod .weight .dtype ,
630+ mapping_type = MappingType .SYMMETRIC ,
633631 )
632+ quantize_ (tmp_model , config , lambda m , fqn : isinstance (m , nn .Embedding ))
633+ weight = tmp_model .weight .qdata # pyre-ignore[16]
634+ scales = tmp_model .weight .scale # pyre-ignore[16]
634635
635636 if packed :
636637 if self .bitwidth == 2 :
0 commit comments