2222    "get_quant_patterns_and_replacements" ,
2323]
2424
25+ 
26+ from  torch  import  Tensor 
27+ from  torch .library  import  custom_op 
28+ @custom_op ("quant_fusion::_pack_embedding_weight" , mutates_args = ()) 
29+ def  _pack_embedding_weight (weight : Tensor , bitwidth : int ) ->  Tensor :
30+     num_embeddings , embedding_dim  =  weight .shape 
31+ 
32+     if  bitwidth  ==  2 :
33+         assert  embedding_dim  %  4  ==  0 , "embedding_dim must be divisible by 4" 
34+         weight_range_shifted  =  weight .add (2 ).view (torch .uint8 )
35+         weight_view  =  weight_range_shifted .view (
36+             num_embeddings , embedding_dim  //  4 , 4 
37+         )
38+         weight_0  =  weight_view [:, :, 0 ]
39+         weight_1  =  weight_view [:, :, 1 ] <<  2 
40+         weight_2  =  weight_view [:, :, 2 ] <<  4 
41+         weight_3  =  weight_view [:, :, 3 ] <<  6 
42+         packed_weight  =  weight_0  +  weight_1  +  weight_2  +  weight_3 
43+         return  packed_weight 
44+     elif  bitwidth  ==  4 :
45+         assert  embedding_dim  %  2  ==  0 , "embedding_dim must be divisible by 2" 
46+         weight_range_shifted  =  weight .add (8 ).view (torch .uint8 )
47+         weight_view  =  weight_range_shifted .view (
48+             weight .shape [0 ], weight .shape [1 ] //  2 , 2 
49+         )
50+         weight_even  =  weight_view [:, :, 0 ] *  16   # left shift 4 
51+         weight_odd  =  weight_view [:, :, 1 ]
52+         packed_weight  =  weight_even  +  weight_odd 
53+         return  packed_weight 
54+     elif  bitwidth  ==  8 :
55+         return  weight 
56+     
57+     raise  RuntimeError (f"Unsupported bitwidth { bitwidth }  " )
58+ 
59+ 
# Use register_fake to add a ``FakeTensor`` kernel for the operator
@_pack_embedding_weight.register_fake
def _(weight, bitwidth):
    """Shape/dtype-only kernel used during tracing/export.

    Mirrors the real op: 2-/4-bit packing folds ``values_per_byte`` values
    into each uint8 byte; the 8-bit path is a no-op in the real kernel, so
    the input dtype is preserved there.
    """
    # Parameter renamed from ``bit_width`` to ``bitwidth`` to match the real
    # op's schema, so keyword calls resolve identically on both kernels.
    assert bitwidth in [2, 4, 8]
    num_embeddings, embedding_dim = weight.shape
    values_per_byte = 8 // bitwidth
    assert embedding_dim % values_per_byte == 0
    # The real kernel returns ``weight`` unchanged for bitwidth 8, so only
    # the 2-/4-bit paths produce uint8.
    dtype = weight.dtype if bitwidth == 8 else torch.uint8
    return torch.empty(
        num_embeddings,
        embedding_dim // values_per_byte,
        dtype=dtype,
        device=weight.device,
    )
68+ 
69+ 
2570# TODO: extending an existing library that is defined in OSS might be a bit 
2671# confusing, we can investigate if it is possible to define a new library 
2772
@@ -70,7 +115,7 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
70115        weight_zero_points  is  None  or  weight_zero_points .dtype  ==  weight_scales .dtype 
71116    ), "Expecting weight_zero_points to be None or have same dtype as weight_scales" 
72117    assert  (
73-         weight_zero_points  is  None  or  weight_zero_points .dim () ==   1 
118+         weight_zero_points  is  None  or  weight_zero_points .dim () in  [ 1 ,  2 ] 
74119    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found { weight_zero_points .dim ()}  " 
75120    assert  weight_zero_points  is  None  or  weight_zero_points .size (0 ) ==  weight .size (
76121        0 
@@ -233,6 +278,19 @@ def embedding_2bit(
233278    )
234279    return  torch .ops .aten .embedding .default (weight , indices )
235280
@register_fake("quantized_decomposed::embedding_2bit")
def _(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
):
    """FakeTensor kernel: output shape of a 2-bit-packed embedding lookup.

    ``weight`` holds four 2-bit values per byte, so the unpacked embedding
    dim is four times the packed one. Only shape/dtype/device matter for a
    fake, so avoid instantiating a real ``nn.Embedding`` (which allocated and
    randomly initialized the whole table on every trace).
    """
    embedding_dim = weight.shape[1] * 4
    # NOTE(review): dtype follows the previous nn.Embedding-based fake
    # (default float dtype); confirm against the real kernel's output dtype.
    return torch.empty((*indices.shape, embedding_dim), device=weight.device)
236294
237295@register_fake ("quantized_decomposed::embedding_2bit.out" ) 
238296def  embedding_2bit_out_meta (
@@ -253,7 +311,6 @@ def embedding_2bit_out_meta(
253311        indices ,
254312    )
255313
256- 
257314@impl (quantized_decomposed_lib , "embedding_2bit.dtype" , "CompositeExplicitAutograd" ) 
258315def  embedding_2bit_dtype (
259316    weight : torch .Tensor ,
@@ -295,6 +352,20 @@ def embedding_2bit_dtype(
295352    )
296353    return  torch .ops .aten .embedding .default (weight , indices )
297354
@register_fake("quantized_decomposed::embedding_2bit.dtype")
def _(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    """FakeTensor kernel for the dtype variant of the 2-bit embedding op.

    Same shape computation as the non-dtype variant (packed dim * 4), with
    the result cast via ``.to(dtype)`` exactly as before. Avoids allocating
    a real ``nn.Embedding`` table, which a fake kernel never needs.
    """
    embedding_dim = weight.shape[1] * 4
    out = torch.empty((*indices.shape, embedding_dim), device=weight.device)
    # Preserve the original fake's cast semantics, including dtype=None.
    return out.to(dtype)
298369
299370@register_fake ("quantized_decomposed::embedding_2bit.dtype_out" ) 
300371def  embedding_2bit_dtype_out_meta (
@@ -377,6 +448,19 @@ def embedding_4bit(
377448    )
378449    return  torch .ops .aten .embedding .default (weight , indices )
379450
@register_fake("quantized_decomposed::embedding_4bit")
def _(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
):
    """FakeTensor kernel: output shape of a 4-bit-packed embedding lookup.

    ``weight`` holds two 4-bit values per byte, so the unpacked embedding
    dim is twice the packed one. Only shape/dtype/device matter for a fake,
    so avoid instantiating a real ``nn.Embedding`` (which allocated and
    randomly initialized the whole table on every trace).
    """
    embedding_dim = weight.shape[1] * 2
    # NOTE(review): dtype follows the previous nn.Embedding-based fake
    # (default float dtype); confirm against the real kernel's output dtype.
    return torch.empty((*indices.shape, embedding_dim), device=weight.device)
380464
381465@register_fake ("quantized_decomposed::embedding_4bit.out" ) 
382466def  embedding_4bit_out_meta (
@@ -437,6 +521,20 @@ def embedding_4bit_dtype(
437521    )
438522    return  torch .ops .aten .embedding .default (weight , indices )
439523
@register_fake("quantized_decomposed::embedding_4bit.dtype")
def _(
    weight: torch.Tensor,
    weight_scales: torch.Tensor,
    weight_zero_points: Optional[torch.Tensor],
    weight_quant_min: int,
    weight_quant_max: int,
    indices: torch.Tensor,
    dtype: Optional[torch.dtype],
) -> torch.Tensor:
    """FakeTensor kernel for the dtype variant of the 4-bit embedding op.

    Same shape computation as the non-dtype variant (packed dim * 2), with
    the result cast via ``.to(dtype)`` exactly as before. Avoids allocating
    a real ``nn.Embedding`` table, which a fake kernel never needs.
    """
    embedding_dim = weight.shape[1] * 2
    out = torch.empty((*indices.shape, embedding_dim), device=weight.device)
    # Preserve the original fake's cast semantics, including dtype=None.
    return out.to(dtype)
440538
441539@register_fake ("quantized_decomposed::embedding_4bit.dtype_out" ) 
442540def  embedding_4bit_dtype_out_meta (
@@ -872,6 +970,76 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
872970        )
873971    ]
874972
def _get_embedding_ops_patterns_and_replacements_torchao() -> List[Tuple[Callable, Callable, List[Callable]]]:
    """Return (pattern, replacement, match_filters) triples that fuse a
    torchao ``dequantize_affine`` followed by ``aten.embedding`` into the
    corresponding ``quantized_decomposed.embedding_*`` op.

    Covers byte (8-bit), 2-bit, and 4-bit weights, each with and without an
    explicit output dtype. The quant_min/quant_max constants baked into each
    pattern (-128/127, -2/1, -8/7) select which bitwidth it matches; the
    2-/4-bit replacements additionally pack the int8 table via
    ``quant_fusion._pack_embedding_weight``. All callables are traced and
    lowered to edge ops before being handed to the subgraph rewriter, so the
    bodies below must stay structurally equivalent to the graphs they match.
    """

    def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):
        # Groupwise dequant (block size [1, group_size]) then table lookup.
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
        # embedding_byte expects zero points in the same dtype as the scales.
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_byte.default(
            int_data,
            scale,
            zero_point_dtype_cast,
            -128,
            127,
            indices,
        )

    def embedding_byte_dtype_pattern(indices, int_data, group_size, scale, zero_point, output_dtype):
        # Same as above but dequantize_affine carries an explicit output dtype
        # ('INT' is the zero-point domain argument).
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127, 'INT', output_dtype
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_byte_dtype_replacement(indices, int_data, group_size, scale, zero_point, output_dtype):
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_byte.dtype(
            int_data,
            scale,
            zero_point_dtype_cast,
            -128,
            127,
            indices,
            dtype=output_dtype
        )

    def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point):
        # 2-bit range is [-2, 1].
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
        # Pack four 2-bit values per byte before handing off to the fused op.
        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 2)
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_2bit.default(packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices)

    def embedding_2bit_dtype_pattern(indices, int_data, group_size, scale, zero_point, output_dtype):
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1, 'INT', output_dtype
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_2bit_dtype_replacement(indices, int_data, group_size, scale, zero_point, output_dtype):
        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 2)
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_2bit.dtype(packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices, dtype=output_dtype)

    def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point):
        # 4-bit range is [-8, 7].
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
        # Pack two 4-bit values per byte before handing off to the fused op.
        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 4)
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_4bit.default(packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices)

    def embedding_4bit_dtype_pattern(indices, int_data, group_size, scale, zero_point, output_dtype):
        dq = torch.ops.torchao.dequantize_affine.default(
            int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7, 'INT', output_dtype
        )
        return torch.ops.aten.embedding.default(dq, indices)

    def embedding_4bit_dtype_replacement(indices, int_data, group_size, scale, zero_point, output_dtype):
        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 4)
        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
        return torch.ops.quantized_decomposed.embedding_4bit.dtype(packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices, dtype=output_dtype)

    # Each entry: (traced pattern graph, traced replacement graph, match filters).
    # No extra match filters are needed; the constants in each pattern already
    # disambiguate the bitwidth.
    return [
        (_trace_and_lower_to_edge_ops(embedding_byte_pattern), _trace_and_lower_to_edge_ops(embedding_byte_replacement), []),
        (_trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern), _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement), []),
        (_trace_and_lower_to_edge_ops(embedding_2bit_pattern), _trace_and_lower_to_edge_ops(embedding_2bit_replacement), []),
        (_trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern), _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement), []),
        (_trace_and_lower_to_edge_ops(embedding_4bit_pattern), _trace_and_lower_to_edge_ops(embedding_4bit_replacement), []),
        (_trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern), _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement), []),
    ]
1042+ 
8751043
8761044def  _get_embedding_ops_patterns_and_replacements () ->  (
8771045    List [Tuple [Callable , Callable , List [Callable ]]]
@@ -1167,5 +1335,6 @@ def get_quant_patterns_and_replacements() -> (
11671335            * _get_slice_patterns_and_replacements (),
11681336            # *_get_fixed_qparams_ops_patterns_and_replacements(), 
11691337            * _get_embedding_ops_patterns_and_replacements (),
1338+             * _get_embedding_ops_patterns_and_replacements_torchao (),
11701339        ]
11711340    )
0 commit comments