@@ -172,6 +172,138 @@ def embedding_byte_dtype_out_meta(
172172 )
173173
174174
# Schemas for the 2-bit embedding operator and its overloads. All four share
# the same positional arguments; the overloads add a keyword-only `dtype`,
# a keyword-only mutable `out` tensor, or both.

# Default overload.
quantized_decomposed_lib.define(
    "embedding_2bit("
    "Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices"
    ") -> Tensor"
)

# `.dtype` overload: adds an optional keyword-only ScalarType.
quantized_decomposed_lib.define(
    "embedding_2bit.dtype("
    "Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, "
    "*, ScalarType? dtype=None"
    ") -> Tensor"
)

# `.out` overload: adds a keyword-only mutable `out` tensor, which is also
# the declared return alias.
quantized_decomposed_lib.define(
    "embedding_2bit.out("
    "Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, "
    "*, Tensor(a!) out"
    ") -> Tensor(a!)"
)

# `.dtype_out` overload: both the optional `dtype` and the mutable `out`.
quantized_decomposed_lib.define(
    "embedding_2bit.dtype_out("
    "Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
    "int weight_quant_min, int weight_quant_max, Tensor indices, "
    "*, ScalarType? dtype=None, Tensor(a!) out"
    ") -> Tensor(a!)"
)
194+
195+
@impl(quantized_decomposed_lib, "embedding_2bit", "CompositeExplicitAutograd")
def embedding_2bit(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
) -> torch.Tensor:
    """Unpack a 2-bit packed weight, dequantize it, and run an embedding lookup.

    Each element of ``weight`` is split into four 2-bit fields (bits 0-1,
    2-3, 4-5 and 6-7). The fields are laid out lowest bits first along the
    second dimension, reinterpreted as int8 and shifted by -2, then
    dequantized per channel group before ``indices`` are looked up.

    Args:
        weight: Packed 2-D weight tensor.
            NOTE(review): presumably a 1-byte integer dtype (e.g. uint8);
            confirm against ``embedding_weight_checks``.
        weight_scales: Quantization scales. When 2-D, its second dimension
            is used as the number of groups per row; otherwise one group
            per row is assumed.
        weight_zero_points: Optional zero points, forwarded unchanged to the
            dequantize op.
        weight_quant_min: Forwarded unchanged to the dequantize op.
        weight_quant_max: Forwarded unchanged to the dequantize op.
        indices: Indices forwarded to ``aten.embedding``.

    Returns:
        The result of ``aten.embedding`` on the dequantized weight; the
        dequantize op is asked for ``weight_scales.dtype`` as output dtype.
    """
    embedding_weight_checks(weight, weight_scales, weight_zero_points)

    # Every packed element expands to four values, hence the factor of 4.
    packed_cols = weight.size(1)
    groups_per_row = weight_scales.size(1) if weight_scales.dim() == 2 else 1
    group_size = (4 * packed_cols) // groups_per_row

    # Isolate the four 2-bit fields; the new trailing dim keeps them adjacent
    # so that flattening yields low-bits-first order within each element.
    fields = torch.stack(
        (
            weight & 3,
            (weight & 12) >> 2,
            (weight & 48) >> 4,
            (weight & 192) >> 6,
        ),
        dim=-1,
    )
    flattened = fields.view(weight.shape[0], -1)

    # Reinterpret as int8 and move the [0, 3] field range to [-2, 1].
    signed = flattened.view(torch.int8).add(-2)

    dequantized = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        signed,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        signed.dtype,
        group_size,
        weight_scales.dtype,
    )
    return torch.ops.aten.embedding.default(dequantized, indices)
228+
229+
@register_fake("quantized_decomposed::embedding_2bit.out")
def embedding_2bit_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    out: torch.Tensor,
) -> torch.Tensor:
    """Fake kernel registered for ``quantized_decomposed::embedding_2bit.out``.

    Delegates to ``embedding_2bit`` with the same arguments and returns its
    result. ``out`` is accepted to match the operator schema but is not read
    or written here.
    """
    result = embedding_2bit(
        weight=weight,
        weight_scales=weight_scales,
        weight_zero_points=weight_zero_points,
        weight_quant_min=weight_quant_min,
        weight_quant_max=weight_quant_max,
        indices=indices,
    )
    return result
248+
249+
@impl(quantized_decomposed_lib, "embedding_2bit.dtype", "CompositeExplicitAutograd")
def embedding_2bit_dtype(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    """Unpack a 2-bit packed weight, dequantize it to ``dtype``, and look up rows.

    Each element of ``weight`` is split into four 2-bit fields (bits 0-1,
    2-3, 4-5 and 6-7). The fields are laid out lowest bits first along the
    second dimension, reinterpreted as int8 and shifted by -2, then
    dequantized per channel group before ``indices`` are looked up.

    Args:
        weight: Packed 2-D weight tensor.
            NOTE(review): presumably a 1-byte integer dtype (e.g. uint8);
            confirm against ``embedding_weight_checks``.
        weight_scales: Quantization scales. When 2-D, its second dimension
            is used as the number of groups per row; otherwise one group
            per row is assumed.
        weight_zero_points: Optional zero points, forwarded unchanged to the
            dequantize op.
        weight_quant_min: Forwarded unchanged to the dequantize op.
        weight_quant_max: Forwarded unchanged to the dequantize op.
        indices: Indices forwarded to ``aten.embedding``.
        dtype: Output dtype handed to the dequantize op as-is (including
            ``None``; how ``None`` is treated is up to that op).

    Returns:
        The result of ``aten.embedding`` on the dequantized weight.
    """
    embedding_weight_checks(weight, weight_scales, weight_zero_points)

    # Every packed element expands to four values, hence the factor of 4.
    packed_cols = weight.size(1)
    groups_per_row = weight_scales.size(1) if weight_scales.dim() == 2 else 1
    group_size = (4 * packed_cols) // groups_per_row

    # Isolate the four 2-bit fields; the new trailing dim keeps them adjacent
    # so that flattening yields low-bits-first order within each element.
    fields = torch.stack(
        (
            weight & 3,
            (weight & 12) >> 2,
            (weight & 48) >> 4,
            (weight & 192) >> 6,
        ),
        dim=-1,
    )
    flattened = fields.view(weight.shape[0], -1)

    # Reinterpret as int8 and move the [0, 3] field range to [-2, 1].
    signed = flattened.view(torch.int8).add(-2)

    dequantized = torch.ops.quantized_decomposed.dequantize_per_channel_group.default(
        signed,
        weight_scales,
        weight_zero_points,
        weight_quant_min,
        weight_quant_max,
        signed.dtype,
        group_size,
        dtype,
    )
    return torch.ops.aten.embedding.default(dequantized, indices)
283+
284+
@register_fake("quantized_decomposed::embedding_2bit.dtype_out")
def embedding_2bit_dtype_out_meta(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
    out: torch.Tensor,
) -> torch.Tensor:
    """Fake kernel registered for ``quantized_decomposed::embedding_2bit.dtype_out``.

    Delegates to ``embedding_2bit_dtype`` with the same arguments and returns
    its result. ``out`` is accepted to match the operator schema but is not
    read or written here.
    """
    result = embedding_2bit_dtype(
        weight=weight,
        weight_scales=weight_scales,
        weight_zero_points=weight_zero_points,
        weight_quant_min=weight_quant_min,
        weight_quant_max=weight_quant_max,
        indices=indices,
        dtype=dtype,
    )
    return result
305+
306+
175307quantized_decomposed_lib .define (
176308 "embedding_4bit(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
177309 "int weight_quant_min, int weight_quant_max, Tensor indices) -> Tensor" ,
0 commit comments