 
 from sentencepiece import SentencePieceProcessor
 
+from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationIntxWeightConfig,
+    IntxWeightOnlyConfig,
+    MappingType,
+    quantize_,
+)
+
 
 try:
     from fairseq2.nn.embedding import (
@@ -118,15 +127,6 @@ def quantize( # noqa C901
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])
 
-        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
-        from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import (
-            Int8DynamicActivationIntxWeightConfig,
-            MappingType,
-            quantize_,
-        )
-        from torchao.utils import unwrap_tensor_subclass
-
         with torch.no_grad():
             # Computation dtype is fixed to fp32 in the implementation of quantize_, so
             # no way to decouple checkpoint and computation dtype.
@@ -141,7 +141,6 @@ def quantize( # noqa C901
                     layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
-            model = unwrap_tensor_subclass(model)
         if verbose:
             print("quantized model:", model)
         return model
@@ -150,14 +149,17 @@ def quantize( # noqa C901
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
 
-        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
-        from torchao.utils import unwrap_tensor_subclass
-
-        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
-        model = unwrap_tensor_subclass(model)
-
+        quantize_(
+            model,
+            Int8DynamicActivationIntxWeightConfig(
+                weight_dtype=torch.int4,
+                weight_granularity=(
+                    PerAxis(0) if group_size == 0 else PerGroup(group_size)
+                ),
+                weight_mapping_type=MappingType.SYMMETRIC,
+            ),
+        )
         # TODO: deal with checkpoint / computation dtype decoupling.
-
         if verbose:
             print("quantized model:", model)
         return model
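
For context, the removed one-liner and the new config map onto each other directly. A minimal standalone sketch of the replacement API (not part of this patch; the toy model and group size 32 are arbitrary, and a torchao build exposing these names is assumed):

    import torch
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import (
        Int8DynamicActivationIntxWeightConfig,
        MappingType,
        quantize_,
    )

    # Toy model: quantize_ rewrites every nn.Linear it finds in place.
    model = torch.nn.Sequential(torch.nn.Linear(256, 256))

    # Old API: quantize_(model, int8_dynamic_activation_int4_weight(group_size=32)),
    # followed by unwrap_tensor_subclass(model). The config below expresses the same
    # 8-bit dynamic activation / 4-bit symmetric grouped-weight scheme, with no
    # unwrap step.
    quantize_(
        model,
        Int8DynamicActivationIntxWeightConfig(
            weight_dtype=torch.int4,
            weight_granularity=PerGroup(32),
            weight_mapping_type=MappingType.SYMMETRIC,
        ),
    )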
@@ -563,254 +565,32 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         )
 
 
-#########################################################################
-#####          embedding table quantization               ######
-
-
-def replace_embedding_weight_only_grouped_int8_per_channel(
-    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
-):
-    for name, child in module.named_children():
-        # print(f"name: {name}")
-        if isinstance(child, nn.Embedding):
-            # print(f"{name, child}")
-            # print(f"weights size: {child.weight.size()}")
-            setattr(
-                module,
-                name,
-                QuantizedGroupEmbedding(
-                    device=device,
-                    vocab_size=child.weight.shape[0],
-                    embedding_dim=child.weight.shape[1],
-                    group_size=group_size,
-                    dtype=child.weight.dtype,
-                    packed=packed,
-                    bitwidth=bitwidth,
-                ),
-            )
-        else:
-            replace_embedding_weight_only_grouped_int8_per_channel(
-                child, device, bitwidth, group_size, packed
-            )
-
-
-class EmbeddingQuantHandler(QuantHandler):
-    def __init__(
-        self,
-        mod,
-        device="cpu",
-        *,
-        bitwidth: int = 8,
-        group_size: Optional[int] = None,
-        packed=False,
-        precision: Optional[torch.dtype] = None,
-    ):
-        if isinstance(packed, str):
-            packed = packed == "True"
-        self.mod = mod
-        self.device = device
-        self.group_size = group_size
-        self.bitwidth = bitwidth
-        self.packed = packed
-        # Dtype of the weights right before quantization.
-        self.precision = precision
-        if (bitwidth not in [2, 4]) and packed:
-            raise RuntimeError("pack only works with bitsize 2, 4")
-
-    @torch.no_grad()
-    def create_quantized_state_dict(self, packed=False) -> Dict:
-        cur_state_dict = self.mod.state_dict()
-
-        if self.bitwidth == 2:
-            range_min = -2
-            range_max = 1
-        elif self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
-
-        for fqn, mod in self.mod.named_modules():
-            if isinstance(mod, nn.Embedding):
-                # print("****")
-                # print(f"Embedding identified: {fqn, mod}")
-                # print(f"weights size: {mod.weight.size()}")
-                # print(f"quantize {fqn}...")
-
-                print(
-                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
-                )
-                weight, scales, _ = dynamically_quantize_per_channel(
-                    (
-                        mod.weight.to(dtype=self.precision)
-                        if self.precision
-                        else mod.weight
-                    ),
-                    range_min,
-                    range_max,
-                    torch.int8,
-                    self.group_size,
-                    scales_dtype=mod.weight.dtype,
-                )
-
-                if packed:
-                    if self.bitwidth == 2:
-                        if weight.shape[-1] % 4 != 0:
-                            raise RuntimeError("automatic padding not implemented yet")
-                        weight_range_shifted = weight.add(2).view(torch.uint8)
-                        weight_view = weight_range_shifted.view(
-                            weight.shape[0], weight.shape[1] // 4, 4
-                        )
-                        weight_0 = weight_view[:, :, 0]
-                        weight_1 = weight_view[:, :, 1] << 2
-                        weight_2 = weight_view[:, :, 2] << 4
-                        weight_3 = weight_view[:, :, 3] << 6
-                        weight_packed = weight_0 + weight_1 + weight_2 + weight_3
-                        weight = weight_packed
-                    elif self.bitwidth == 4:
-                        if weight.shape[-1] % 2 != 0:
-                            raise RuntimeError("automatic padding not implemented yet")
-                        weight_range_shifted = weight.add(8).view(torch.uint8)
-                        weight_view = weight_range_shifted.view(
-                            weight.shape[0], weight.shape[1] // 2, 2
-                        )
-                        weight_even = weight_view[:, :, 0] * 16  # left shift 4
-                        weight_odd = weight_view[:, :, 1]
-                        weight_packed = weight_even + weight_odd
-                        weight = weight_packed
-
-                weight = weight.to(device=self.device)
-                scales = scales.to(device=self.device)
-                # Update state dict
-                cur_state_dict[f"{fqn}.weight"] = weight
-                # squeeze makes group_size=rowsize unidimensional
-                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
-
-        return cur_state_dict
-
-    def convert_for_runtime(self) -> nn.Module:
-        replace_embedding_weight_only_grouped_int8_per_channel(
-            self.mod, self.device, self.bitwidth, self.group_size, self.packed
-        )
-        return self.mod
-
-    def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
-        self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict, assign=True)
-        return self.mod
-
-
-class QuantizedGroupEmbedding(torch.nn.Module):
-    def __init__(
-        self,
-        device,
-        vocab_size: int,
-        embedding_dim: int,
-        group_size: Optional[int] = None,
-        dtype=torch.half,
-        packed=False,
-        bitwidth: int = 8,
-    ) -> None:
-        super().__init__()
-        if group_size is None or group_size == 0:
-            group_size = embedding_dim
-        self.group_size = group_size
-        self.dtype = dtype
-        self.packed = packed
-        self.bitwidth = bitwidth
-        if not packed:
-            self.register_buffer(
-                "weight",
-                torch.zeros(
-                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
-                ),
-            )
-        else:  # packed
-            if bitwidth == 2:
-                self.register_buffer(
-                    "weight",
-                    torch.zeros(
-                        (vocab_size, embedding_dim // 4),
-                        dtype=torch.uint8,
-                        device=device,
-                    ),
-                )
-            elif bitwidth == 4:
-                self.register_buffer(
-                    "weight",
-                    torch.zeros(
-                        (vocab_size, embedding_dim // 2),
-                        dtype=torch.uint8,
-                        device=device,
-                    ),
-                )
-
-        groups_per_row = (embedding_dim + group_size - 1) // group_size
-        if groups_per_row > 1:
-            self.register_buffer(
-                "scales",
-                torch.ones(
-                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
-                ),
-            )
-        else:
-            self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
-            )
-
-    @torch.no_grad()
-    def forward(self, indices: torch.Tensor) -> torch.Tensor:
-        if not self.packed:  # 8bit
-            return torch.ops.quantized_decomposed.embedding_byte.dtype(
-                self.weight, self.scales, None, -128, 127, indices, dtype=self.dtype
-            )
-        else:  # packed
-            if self.bitwidth == 2:
-                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
-                    self.weight, self.scales, None, -2, 1, indices, dtype=self.dtype
-                )
+############################ Source Transform Start #######################
 
-            # Remaining case (always return to make pyre happy)
-            assert self.bitwidth == 4
-            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-                self.weight, self.scales, None, -8, 7, indices, dtype=self.dtype
-            )
 
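
The packing scheme that the deleted handler implemented is easiest to see on one row. A minimal sketch (values chosen for illustration) of the 2-bit branch above, which stores four 2-bit codes per byte, lowest-order value first:

    import torch

    weight = torch.tensor([[-2, -1, 0, 1]], dtype=torch.int8)  # one row of four 2-bit values
    shifted = weight.add(2).view(torch.uint8)                  # map [-2, 1] onto [0, 3]
    view = shifted.view(weight.shape[0], weight.shape[1] // 4, 4)
    packed = (
        view[:, :, 0] + (view[:, :, 1] << 2) + (view[:, :, 2] << 4) + (view[:, :, 3] << 6)
    )
    print(packed)  # tensor([[228]], dtype=torch.uint8), i.e. 0b11100100

The 4-bit branch is the same idea with two codes per byte and a range shift of 8 instead of 2.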
+def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
+    use_torchao = args.embedding_quantize.startswith("torchao:")
+    if use_torchao:
+        quant_args = args.embedding_quantize.split(":")[1].split(",")
+    else:
+        quant_args = args.embedding_quantize.split(",")
 
-############################ Source Transform Start #######################
+    bitwidth = int(quant_args[0])
+    group_size = quant_args[1]
+    if group_size in ["none", "None", "0"]:
+        group_size = 0
+    group_size = int(group_size)
+    is_symmetric = (quant_args[2].lower() == "true") if len(quant_args) > 2 else True
 
+    weight_dtype = getattr(torch, f"int{bitwidth}")
+    granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
+    mapping_type = MappingType.SYMMETRIC if is_symmetric else MappingType.ASYMMETRIC
 
-def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
-    if args.embedding_quantize.startswith("torchao:"):
+    if use_torchao:
         from torchao.experimental.quant_api import (
             EmbeddingQuantizer,
             SharedEmbeddingQuantizer,
         )
-        from torchao.quantization.granularity import PerAxis, PerGroup
-        from torchao.quantization.quant_api import MappingType
-
-        quant_args = args.embedding_quantize.split(":")[1].split(",")
-        if len(quant_args) == 2:
-            bitwidth, group_size = quant_args
-            is_asymmetric = True
-        else:
-            bitwidth, group_size, is_asymmetric = quant_args
-
-        if group_size in ["none", "None", "0"]:
-            group_size = 0
-
-        group_size = int(group_size)
-        bitwidth = int(bitwidth)
-        is_asymmetric = bool(is_asymmetric)
-        weight_dtype = getattr(torch, f"int{bitwidth}")
-        granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
-        mapping_type = (
-            MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
-        )
 
         def _torchao_embedding_quantizer(model):
             with torch.no_grad():
@@ -831,20 +611,23 @@ def _torchao_embedding_quantizer(model):
 
         return _torchao_embedding_quantizer
 
-    bitwidth, group_size = args.embedding_quantize.split(",")
-    if group_size == "none" or group_size == "None" or group_size == "0":
-        group_size = None
-    else:
-        group_size = int(group_size)
-    bitwidth = int(bitwidth)
-    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
-    return lambda model: EmbeddingQuantHandler(
-        model,
-        bitwidth=bitwidth,
-        group_size=group_size,
-        packed=(bitwidth in [2, 4]),
-        precision=torch_dtype,
-    ).quantized_model()
+    def _quantize_embedding(model):
+        assert weight_dtype in [
+            torch.int2,
+            torch.int4,
+            torch.int8,
+        ], "Only 2, 4, or 8-bit embeddings are supported unless using torchao"
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=weight_dtype,
+                granularity=granularity,
+                mapping_type=mapping_type,
+            ),
+            # quantize_ targets nn.Linear by default; restrict it to embeddings here.
+            lambda m, fqn: isinstance(m, nn.Embedding),
+        )
+        return model
+
+    return _quantize_embedding
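
End to end, the new transform turns an embedding_quantize string into an IntxWeightOnlyConfig applied to the model's embedding tables. A minimal usage sketch (the Namespace stands in for the parsed CLI args and is hypothetical):

    import argparse
    import torch

    args = argparse.Namespace(embedding_quantize="4,32")  # bitwidth 4, group size 32
    transform = get_quant_embedding_transform(args)

    model = torch.nn.Sequential(torch.nn.Embedding(1000, 64))
    model = transform(model)  # int4 weight-only, PerGroup(32), symmetric mapping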
 
 
 def get_quant_weight_transform(