@@ -595,19 +595,16 @@ def __init__(
595
595
596
596
@torch .no_grad ()
597
597
def create_quantized_state_dict (self , packed = False ) -> Dict :
598
+ from torchao .quantization .granularity import PerAxis , PerGroup
599
+ from torchao .quantization .quant_api import (
600
+ IntxWeightOnlyConfig ,
601
+ MappingType ,
602
+ quantize_ ,
603
+ )
604
+
598
605
cur_state_dict = self .mod .state_dict ()
599
606
600
- if self .bitwidth == 2 :
601
- range_min = - 2
602
- range_max = 1
603
- elif self .bitwidth == 4 :
604
- range_min = - 8
605
- range_max = 7
606
- elif self .bitwidth == 8 :
607
- range_min = - 128
608
- range_max = 127
609
- else :
610
- raise ValueError (f"Unsupported bitwidth { self .bitwidth } " )
607
+ assert self .bitwidth in [2 , 4 , 8 ], f"Unsupported bitwidth { self .bitwidth } "
611
608
612
609
for fqn , mod in self .mod .named_modules ():
613
610
if isinstance (mod , nn .Embedding ):
@@ -619,18 +616,22 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
619
616
print (
620
617
f"quantize { fqn , mod } with group_size { self .group_size } , bitwidth { self .bitwidth } "
621
618
)
622
- weight , scales , _ = dynamically_quantize_per_channel (
623
- (
624
- mod .weight .to (dtype = self .precision )
625
- if self .precision
626
- else mod .weight
619
+ tmp_model = nn .Embedding (mod .weight .shape [0 ], mod .weight .shape [1 ])
620
+ if self .precision :
621
+ tmp_model = tmp_model .to (dtype = self .precision )
622
+ tmp_model .weight = nn .Parameter (mod .weight )
623
+ config = IntxWeightOnlyConfig (
624
+ weight_dtype = getattr (torch , f"int{ self .bitwidth } " ),
625
+ granularity = (
626
+ PerAxis (0 )
627
+ if (self .group_size is None or self .group_size == 0 )
628
+ else PerGroup (self .group_size )
627
629
),
628
- range_min ,
629
- range_max ,
630
- torch .int8 ,
631
- self .group_size ,
632
- scales_dtype = mod .weight .dtype ,
630
+ mapping_type = MappingType .SYMMETRIC ,
633
631
)
632
+ quantize_ (tmp_model , config , lambda m , fqn : isinstance (m , nn .Embedding ))
633
+ weight = tmp_model .weight .qdata # pyre-ignore[16]
634
+ scales = tmp_model .weight .scale # pyre-ignore[16]
634
635
635
636
if packed :
636
637
if self .bitwidth == 2 :
0 commit comments