@@ -73,69 +73,24 @@ def quantize(  # noqa C901
         # Add quantization mode options here: group size, bit width, etc.
         return WeightOnlyInt8QuantHandler(model).quantized_model()
     elif qmode.startswith("torchao:"):
-        import glob
-        import os
-
-        libs = glob.glob(
-            os.path.abspath(
-                os.path.join(
-                    os.environ.get("CMAKE_INSTALL_PREFIX", ""),
-                    "lib/libtorchao_ops_aten.*",
-                )
-            )
-        )
-        assert (
-            len(libs) == 1
-        ), f"Expected 1 library but got {len(libs)}"
-        logging.info(f"Loading custom ops library: {libs[0]}")
-        torch.ops.load_library(libs[0])
-
-        logging.warning(
-            "When qmode is torchao, the groupsize is obtained from the qmode string with regex parse; blocksize is ignored."
-        )
-        embedding_pattern = r"emb.(\d+),(\d+)"
-        linear_pattern = r"lin8da.(\d+),(\d+)"
-
-        matches = re.findall(linear_pattern, qmode)
-        if matches:
-            assert (
-                len(matches) == 1
-            ), f"Expected 1 match for linear_pattern but got {len(matches)}"
-            bitwidth = int(matches[0][0])
-            groupsize = int(matches[0][1])
-            from torchao.experimental.quant_api import (
-                Int8DynActIntxWeightLinearQuantizer,
-            )
-
-            with torch.no_grad():
-                model = Int8DynActIntxWeightLinearQuantizer(
-                    device="cpu",
-                    precision=torch_dtype,
-                    groupsize=groupsize,
-                    bitwidth=bitwidth,
-                    has_weight_zeros=False,
-                ).quantize(model)
-
-        matches = re.findall(embedding_pattern, qmode)
-        if matches:
-            assert (
-                len(matches) == 1
-            ), f"Expected 1 match for embedding_pattern but got {len(matches)}"
-            bitwidth = int(matches[0][0])
-            groupsize = int(matches[0][1])
-            from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
-
-            with torch.no_grad():
-                model = IntxWeightEmbeddingQuantizer(
-                    device="cpu",
-                    precision=torch_dtype,
-                    bitwidth=bitwidth,
-                    groupsize=groupsize,
-                ).quantize(model)
+        pattern = r"torchao:8da(\d+)w"
+        matches = re.findall(pattern, qmode)
+        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
+        bitwidth = int(matches[0][0])
+        _load_torchao_ops_aten()
+        from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer
+
+        with torch.no_grad():
+            model = Int8DynActIntxWeightLinearQuantizer(
+                device="cpu",
+                precision=torch.float32,
+                groupsize=group_size,
+                bitwidth=bitwidth,
+                has_weight_zeros=False,
+            ).quantize(model)
 
         if verbose:
             print("quantized model:", model)
-
         return model
     elif qmode == "8da4w":
         # Check for required args
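For reference, a minimal sketch of what the new parse does, assuming a qmode string such as `torchao:8da4w` (the value here is illustrative, not from the PR):

```python
import re

# Sketch: extracting the bitwidth from a qmode like "torchao:8da4w".
pattern = r"torchao:8da(\d+)w"
matches = re.findall(pattern, "torchao:8da4w")
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
# With a single capture group, re.findall returns a list of strings,
# so matches[0] is "4" and matches[0][0] is its first character --
# this relies on the bitwidth being a single digit.
bitwidth = int(matches[0][0])
print(bitwidth)  # 4
```

Note that `group_size` in the new branch comes from the function's existing argument rather than being parsed out of the qmode string, unlike the old `lin8da.(\d+),(\d+)` parse that ignored the blocksize argument.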
@@ -760,6 +715,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 
 
 def get_quant_embedding_transform(args):
+    if args.embedding_quantize.startswith("torchao:"):
+        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
+        group_size = int(group_size)
+        bitwidth = int(bitwidth)
+        _load_torchao_ops_aten()
+        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer
+
+        def _torchao_embedding_quantizer(model):
+            with torch.no_grad():
+                model = IntxWeightEmbeddingQuantizer(
+                    device="cpu",
+                    precision=torch.float32,
+                    bitwidth=bitwidth,
+                    groupsize=group_size,
+                ).quantize(model)
+            return model
+
+        return _torchao_embedding_quantizer
+
     bitwidth, group_size = args.embedding_quantize.split(",")
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
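A quick sketch of the string format this new branch expects, assuming a flag value like `torchao:4,32` (bitwidth 4, group size 32; the value is illustrative):

```python
# Sketch: parsing an embedding_quantize value of the form
# "torchao:<bitwidth>,<groupsize>", e.g. a hypothetical "torchao:4,32".
embedding_quantize = "torchao:4,32"
bitwidth, group_size = embedding_quantize.split(":")[1].split(",")
bitwidth, group_size = int(bitwidth), int(group_size)
print(bitwidth, group_size)  # 4 32
```

Values without the `torchao:` prefix fall through to the existing `bitwidth,group_size` parse below, so the old flag format keeps working unchanged.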
@@ -801,4 +775,23 @@ def get_quant_weight_transform(args, dtype_override, verbose):
     )
 
 
+def _load_torchao_ops_aten():
+    import glob
+    import os
+
+    libs = glob.glob(
+        os.path.abspath(
+            os.path.join(
+                os.environ.get("CMAKE_INSTALL_PREFIX", ""),
+                "lib/libtorchao_ops_aten.*",
+            )
+        )
+    )
+    assert (
+        len(libs) == 1
+    ), f"Expected 1 library but got {len(libs)}"
+    logging.info(f"Loading custom ops library: {libs[0]}")
+    torch.ops.load_library(libs[0])
+
+
 ############################ Source Transform End #######################
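A sketch of what the new `_load_torchao_ops_aten` helper resolves at runtime, assuming `CMAKE_INSTALL_PREFIX` points at a torchao install tree (the path below is hypothetical):

```python
import glob
import os

# Hypothetical install prefix; in practice this is set by the build environment.
os.environ["CMAKE_INSTALL_PREFIX"] = "/opt/torchao"
libs = glob.glob(
    os.path.abspath(
        os.path.join(
            os.environ.get("CMAKE_INSTALL_PREFIX", ""),
            "lib/libtorchao_ops_aten.*",
        )
    )
)
# The glob is expected to match exactly one shared library
# (e.g. libtorchao_ops_aten.so on Linux, .dylib on macOS);
# torch.ops.load_library(libs[0]) then registers the custom ops with ATen.
```

Factoring this into a helper lets the linear and embedding quantization paths share the load logic instead of duplicating the glob-and-assert block inline, as the old `qmode` branch did.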