@@ -510,6 +510,223 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
            self.precision,
        )

#########################################################################
#####                embedding table quantization                  ######


def replace_embedding_weight_only_grouped_int8_per_channel(
    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
):
    # Recursively swap every nn.Embedding in the module tree for a
    # QuantizedGroupEmbedding holding int8 (or packed sub-byte) weights.
    for name, child in module.named_children():
        if isinstance(child, nn.Embedding):
            setattr(
                module,
                name,
                QuantizedGroupEmbedding(
                    device=device,
                    vocab_size=child.weight.shape[0],
                    embedding_dim=child.weight.shape[1],
                    group_size=group_size,
                    dtype=child.weight.dtype,
                    packed=packed,
                    bitwidth=bitwidth,
                ),
            )
        else:
            replace_embedding_weight_only_grouped_int8_per_channel(
                child, device, bitwidth, group_size, packed
            )


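# A minimal usage sketch of the recursive swap above (illustrative, not part
# of the original diff): the replacement is purely structural, so it can be
# exercised on a hypothetical toy module. The int8 weight and fp16 scales
# buffers still have to be populated from a quantized state dict afterwards.
def _demo_replace_embeddings() -> None:
    toy = nn.Sequential(nn.Embedding(100, 64), nn.Linear(64, 64))
    replace_embedding_weight_only_grouped_int8_per_channel(
        toy, device="cpu", bitwidth=8, group_size=32
    )
    assert isinstance(toy[0], QuantizedGroupEmbedding)
    assert toy[0].weight.dtype == torch.int8  # quantized storage, zero-filled

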
class EmbeddingQuantHandler(QuantHandler):
    def __init__(
        self,
        mod,
        device="cpu",
        *,
        bitwidth: int = 8,
        group_size: Optional[int] = None,
        packed=False,
        precision: Optional[torch.dtype] = None,
    ):
        # `packed` may arrive as the string "True"/"False" (e.g. from a CLI
        # flag), so normalize it to a bool first.
        if isinstance(packed, str):
            packed = packed == "True"
        self.mod = mod
        self.device = device
        self.group_size = group_size
        self.bitwidth = bitwidth
        self.packed = packed
        # Dtype of the weights right before quantization.
        self.precision = precision
        if (bitwidth not in [2, 4]) and packed:
            raise RuntimeError("packed mode only supports bitwidth 2 or 4")
    @torch.no_grad()
    def create_quantized_state_dict(self, packed=False) -> Dict:
        cur_state_dict = self.mod.state_dict()

        if self.bitwidth == 2:
            range_min = -2
            range_max = 1
        elif self.bitwidth == 4:
            range_min = -8
            range_max = 7
        elif self.bitwidth == 8:
            range_min = -128
            range_max = 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, nn.Embedding):
                print(
                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                )
                # Symmetric per-group quantization of the embedding rows; cast
                # to `precision` first when one was requested.
                weight, scales, _ = dynamically_quantize_per_channel(
                    (
                        mod.weight.to(dtype=self.precision)
                        if self.precision
                        else mod.weight
                    ),
                    range_min,
                    range_max,
                    torch.int8,
                    self.group_size,
                    scales_dtype=mod.weight.dtype,
                )

                if packed:
                    if self.bitwidth == 2:
                        if weight.shape[-1] % 4 != 0:
                            raise RuntimeError("automatic padding not implemented yet")
                        # Shift [-2, 1] to [0, 3] so each value fits in 2 bits,
                        # then pack four consecutive 2-bit values per uint8.
                        weight_range_shifted = weight.add(2).view(torch.uint8)
                        weight_view = weight_range_shifted.view(
                            weight.shape[0], weight.shape[1] // 4, 4
                        )
                        weight_0 = weight_view[:, :, 0]
                        weight_1 = weight_view[:, :, 1] << 2
                        weight_2 = weight_view[:, :, 2] << 4
                        weight_3 = weight_view[:, :, 3] << 6
                        weight_packed = weight_0 + weight_1 + weight_2 + weight_3
                        weight = weight_packed
                    elif self.bitwidth == 4:
                        if weight.shape[-1] % 2 != 0:
                            raise RuntimeError("automatic padding not implemented yet")
                        # Shift [-8, 7] to [0, 15] so each value fits in 4 bits,
                        # then pack two consecutive 4-bit values per uint8.
                        weight_range_shifted = weight.add(8).view(torch.uint8)
                        weight_view = weight_range_shifted.view(
                            weight.shape[0], weight.shape[1] // 2, 2
                        )
                        weight_even = weight_view[:, :, 0] * 16  # left shift 4 bits
                        weight_odd = weight_view[:, :, 1]
                        weight_packed = weight_even + weight_odd
                        weight = weight_packed

                weight = weight.to(device=self.device)
                scales = scales.to(device=self.device)
                # Update the state dict in place.
                cur_state_dict[f"{fqn}.weight"] = weight
                # squeeze makes the scales unidimensional when group_size
                # equals the row size
                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

        return cur_state_dict

    def convert_for_runtime(self) -> nn.Module:
        replace_embedding_weight_only_grouped_int8_per_channel(
            self.mod, self.device, self.bitwidth, self.group_size, self.packed
        )
        return self.mod

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict, assign=True)
        return self.mod


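# A minimal sketch (not part of the original diff) of the 4-bit packing scheme
# used in create_quantized_state_dict, round-tripped with plain torch ops:
# values in [-8, 7] are shifted to [0, 15], and two 4-bit nibbles share one
# uint8 (the even-indexed element in the high nibble, the odd one in the low).
def _demo_4bit_pack_roundtrip() -> None:
    w = torch.randint(-8, 8, (4, 6), dtype=torch.int8)
    shifted = w.add(8).view(torch.uint8).view(4, 3, 2)
    packed = shifted[:, :, 0] * 16 + shifted[:, :, 1]  # same math as above
    # Unpack: high nibble holds the even element, low nibble the odd element.
    high = (packed >> 4).to(torch.int8) - 8
    low = (packed & 0x0F).to(torch.int8) - 8
    assert torch.equal(torch.stack([high, low], dim=-1).view(4, 6), w)

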
class QuantizedGroupEmbedding(torch.nn.Module):
    def __init__(
        self,
        device,
        vocab_size: int,
        embedding_dim: int,
        group_size: Optional[int] = None,
        dtype=torch.half,
        packed=False,
        bitwidth: int = 8,
    ) -> None:
        super().__init__()
        if group_size is None or group_size == 0:
            group_size = embedding_dim
        self.group_size = group_size
        self.dtype = dtype
        self.packed = packed
        self.bitwidth = bitwidth
        if not packed:
            self.register_buffer(
                "weight",
                torch.zeros(
                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
                ),
            )
        else:  # packed
            if bitwidth == 2:
                self.register_buffer(
                    "weight",
                    torch.zeros(
                        (vocab_size, embedding_dim // 4),
                        dtype=torch.uint8,
                        device=device,
                    ),
                )
            elif bitwidth == 4:
                self.register_buffer(
                    "weight",
                    torch.zeros(
                        (vocab_size, embedding_dim // 2),
                        dtype=torch.uint8,
                        device=device,
                    ),
                )

        groups_per_row = (embedding_dim + group_size - 1) // group_size
        if groups_per_row > 1:
            self.register_buffer(
                "scales",
                torch.ones(
                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
                ),
            )
        else:
            self.register_buffer(
                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
            )

    @torch.no_grad()
    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        if not self.packed:  # 8bit
            return torch.ops.quantized_decomposed.embedding_byte.dtype(
                self.weight, self.scales, None, -128, 127, indices, dtype=self.dtype
            )
        else:  # packed
            if self.bitwidth == 2:
                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
                    self.weight, self.scales, None, -2, 1, indices, dtype=self.dtype
                )

            # Remaining case (always return to make pyre happy)
            assert self.bitwidth == 4
            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
            )

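# A minimal end-to-end sketch (not part of the original diff); the toy model
# and parameter choices are hypothetical. quantized_model() quantizes the
# state dict, swaps in QuantizedGroupEmbedding modules, and loads the packed
# weights with assign=True. Note that forward() dispatches to the
# torch.ops.quantized_decomposed.embedding_* ops, which must be registered
# (e.g. via the ExecuTorch quantized-ops library) before it can actually run.
def _demo_embedding_quant_handler() -> None:
    toy = nn.Sequential(nn.Embedding(1000, 128))
    handler = EmbeddingQuantHandler(
        toy, device="cpu", bitwidth=4, group_size=32, packed=True
    )
    qmodel = handler.quantized_model()
    emb = qmodel[0]
    assert emb.weight.dtype == torch.uint8
    assert emb.weight.shape == (1000, 64)  # two 4-bit values per byte
    assert emb.scales.shape == (1000, 4)  # 128 / 32 = 4 groups per row
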

############################ Source Transform Start #######################
