@@ -494,6 +494,7 @@ def replace_embedding_weight_only_grouped_int8_per_channel(
494494 group_size = group_size ,
495495 dtype = child .weight .dtype ,
496496 packed = packed ,
497+ bitwidth = bitwidth ,
497498 ),
498499 )
499500 else :
@@ -519,14 +520,17 @@ def __init__(
519520 self .group_size = group_size
520521 self .bitwidth = bitwidth
521522 self .packed = packed
522- if (bitwidth != 4 ) and packed :
523- raise RuntimeError ("pack only works with bitsize 4" )
523+ if (bitwidth not in [ 2 , 4 ] ) and packed :
524+ raise RuntimeError ("pack only works with bitsize 2, 4" )
524525
525526 @torch .no_grad ()
526527 def create_quantized_state_dict (self , packed = False ) -> Dict :
527528 cur_state_dict = self .mod .state_dict ()
528529
529- if self .bitwidth == 4 :
530+ if self .bitwidth == 2 :
531+ range_min = - 2
532+ range_max = 1
533+ elif self .bitwidth == 4 :
530534 range_min = - 8
531535 range_max = 7
532536 elif self .bitwidth == 8 :
@@ -555,17 +559,30 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
555559 )
556560
557561 if packed :
558- if weight .shape [- 1 ] % 2 != 0 :
559- raise RuntimeError ("automatic padding not implemented yet" )
560-
561- weight_range_shifted = weight .add (8 ).view (torch .uint8 )
562- weight_view = weight_range_shifted .view (
563- weight .shape [0 ], weight .shape [1 ] // 2 , 2
564- )
565- weight_even = weight_view [:, :, 0 ] * 16 # left shift 4
566- weight_odd = weight_view [:, :, 1 ]
567- weight_packed = weight_even + weight_odd
568- weight = weight_packed
562+ if self .bitwidth == 2 :
563+ if weight .shape [- 1 ] % 4 != 0 :
564+ raise RuntimeError ("automatic padding not implemented yet" )
565+ weight_range_shifted = weight .add (2 ).view (torch .uint8 )
566+ weight_view = weight_range_shifted .view (
567+ weight .shape [0 ], weight .shape [1 ] // 4 , 4
568+ )
569+ weight_0 = weight_view [:, :, 0 ]
570+ weight_1 = weight_view [:, :, 1 ] << 2
571+ weight_2 = weight_view [:, :, 2 ] << 4
572+ weight_3 = weight_view [:, :, 3 ] << 6
573+ weight_packed = weight_0 + weight_1 + weight_2 + weight_3
574+ weight = weight_packed
575+ elif self .bitwidth == 4 :
576+ if weight .shape [- 1 ] % 2 != 0 :
577+ raise RuntimeError ("automatic padding not implemented yet" )
578+ weight_range_shifted = weight .add (8 ).view (torch .uint8 )
579+ weight_view = weight_range_shifted .view (
580+ weight .shape [0 ], weight .shape [1 ] // 2 , 2
581+ )
582+ weight_even = weight_view [:, :, 0 ] * 16 # left shift 4
583+ weight_odd = weight_view [:, :, 1 ]
584+ weight_packed = weight_even + weight_odd
585+ weight = weight_packed
569586
570587 weight = weight .to (device = self .device )
571588 scales = scales .to (device = self .device )
@@ -598,13 +615,15 @@ def __init__(
598615 group_size : Optional [int ] = None ,
599616 dtype = torch .half ,
600617 packed = False ,
618+ bitwidth : int = 8 ,
601619 ) -> None :
602620 super ().__init__ ()
603621 if group_size is None or group_size == 0 :
604622 group_size = embedding_dim
605623 self .group_size = group_size
606624 self .dtype = dtype
607625 self .packed = packed
626+ self .bitwidth = bitwidth
608627 if not packed :
609628 self .register_buffer (
610629 "weight" ,
@@ -613,12 +632,25 @@ def __init__(
613632 ),
614633 )
615634 else : # packed
616- self .register_buffer (
617- "weight" ,
618- torch .empty (
619- (vocab_size , embedding_dim // 2 ), dtype = torch .uint8 , device = device
620- ),
621- )
635+ if bitwidth == 2 :
636+ self .register_buffer (
637+ "weight" ,
638+ torch .empty (
639+ (vocab_size , embedding_dim // 4 ),
640+ dtype = torch .uint8 ,
641+ device = device ,
642+ ),
643+ )
644+ elif bitwidth == 4 :
645+ self .register_buffer (
646+ "weight" ,
647+ torch .empty (
648+ (vocab_size , embedding_dim // 2 ),
649+ dtype = torch .uint8 ,
650+ device = device ,
651+ ),
652+ )
653+
622654 groups_per_row = (embedding_dim + group_size - 1 ) // group_size
623655 if groups_per_row > 1 :
624656 self .register_buffer (
@@ -638,7 +670,14 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
638670 return torch .ops .quantized_decomposed .embedding_byte .dtype (
639671 self .weight , self .scales , None , 0 , 0 , indices , dtype = self .dtype
640672 )
641- else : # 4bit packed
673+ else : # packed
674+ if self .bitwidth == 2 :
675+ return torch .ops .quantized_decomposed .embedding_2bit .dtype (
676+ self .weight , self .scales , None , 0 , 0 , indices , dtype = self .dtype
677+ )
678+
679+ # Remaining case (always return to make pyre happy)
680+ assert self .bitwidth == 4
642681 return torch .ops .quantized_decomposed .embedding_4bit .dtype (
643682 self .weight , self .scales , None , 0 , 0 , indices , dtype = self .dtype
644683 )
@@ -658,7 +697,7 @@ def get_quant_embedding_transform(args):
658697 model ,
659698 bitwidth = bitwidth ,
660699 group_size = group_size ,
661- packed = (bitwidth == 4 ),
700+ packed = (bitwidth in [ 2 , 4 ] ),
662701 ).quantized_model ()
663702
664703
0 commit comments