2222 "get_quant_patterns_and_replacements" ,
2323]
2424
25+
26+ from torch import Tensor
27+ from torch .library import custom_op
28+ @custom_op ("quant_fusion::_pack_embedding_weight" , mutates_args = ())
29+ def _pack_embedding_weight (weight : Tensor , bitwidth : int ) -> Tensor :
30+ num_embeddings , embedding_dim = weight .shape
31+
32+ if bitwidth == 2 :
33+ assert embedding_dim % 4 == 0 , "embedding_dim must be divisible by 4"
34+ weight_range_shifted = weight .add (2 ).view (torch .uint8 )
35+ weight_view = weight_range_shifted .view (
36+ num_embeddings , embedding_dim // 4 , 4
37+ )
38+ weight_0 = weight_view [:, :, 0 ]
39+ weight_1 = weight_view [:, :, 1 ] << 2
40+ weight_2 = weight_view [:, :, 2 ] << 4
41+ weight_3 = weight_view [:, :, 3 ] << 6
42+ packed_weight = weight_0 + weight_1 + weight_2 + weight_3
43+ return packed_weight
44+ elif bitwidth == 4 :
45+ assert embedding_dim % 2 == 0 , "embedding_dim must be divisible by 2"
46+ weight_range_shifted = weight .add (8 ).view (torch .uint8 )
47+ weight_view = weight_range_shifted .view (
48+ weight .shape [0 ], weight .shape [1 ] // 2 , 2
49+ )
50+ weight_even = weight_view [:, :, 0 ] * 16 # left shift 4
51+ weight_odd = weight_view [:, :, 1 ]
52+ packed_weight = weight_even + weight_odd
53+ return packed_weight
54+ elif bitwidth == 8 :
55+ return weight
56+
57+ raise RuntimeError (f"Unsupported bitwidth { bitwidth } " )
58+
59+
60+ # Use register_fake to add a ``FakeTensor`` kernel for the operator
61+ @_pack_embedding_weight .register_fake
62+ def _ (weight , bit_width ):
63+ assert bit_width in [2 , 4 , 8 ]
64+ num_embeddings , embedding_dim = weight .shape
65+ values_per_byte = 8 // bit_width
66+ assert embedding_dim % values_per_byte == 0
67+ return torch .empty (num_embeddings , embedding_dim // values_per_byte , dtype = torch .uint8 , device = weight .device )
68+
69+
2570# TODO: extending an existing library that is defined in OSS might be a bit
2671# confusing, we can investigate if it is possible to define a new library
2772
@@ -70,7 +115,7 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
70115 weight_zero_points is None or weight_zero_points .dtype == weight_scales .dtype
71116 ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
72117 assert (
73- weight_zero_points is None or weight_zero_points .dim () == 1
118+ weight_zero_points is None or weight_zero_points .dim () in [ 1 , 2 ]
74119 ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found { weight_zero_points .dim ()} "
75120 assert weight_zero_points is None or weight_zero_points .size (0 ) == weight .size (
76121 0
@@ -233,6 +278,19 @@ def embedding_2bit(
233278 )
234279 return torch .ops .aten .embedding .default (weight , indices )
235280
281+ @register_fake ("quantized_decomposed::embedding_2bit" )
282+ def _ (
283+ weight : torch .Tensor ,
284+ weight_scales : torch .Tensor ,
285+ weight_zero_points : Optional [torch .Tensor ],
286+ weight_quant_min : int ,
287+ weight_quant_max : int ,
288+ indices : torch .Tensor ,
289+ ):
290+ num_embeddings , packed_embedding_dim = weight .shape
291+ embedding_dim = packed_embedding_dim * 4
292+ embedding = torch .nn .Embedding (num_embeddings , embedding_dim , device = weight .device )
293+ return embedding (indices )
236294
237295@register_fake ("quantized_decomposed::embedding_2bit.out" )
238296def embedding_2bit_out_meta (
@@ -253,7 +311,6 @@ def embedding_2bit_out_meta(
253311 indices ,
254312 )
255313
256-
257314@impl (quantized_decomposed_lib , "embedding_2bit.dtype" , "CompositeExplicitAutograd" )
258315def embedding_2bit_dtype (
259316 weight : torch .Tensor ,
@@ -295,6 +352,20 @@ def embedding_2bit_dtype(
295352 )
296353 return torch .ops .aten .embedding .default (weight , indices )
297354
355+ @register_fake ("quantized_decomposed::embedding_2bit.dtype" )
356+ def _ (
357+ weight : torch .Tensor ,
358+ weight_scales : torch .Tensor ,
359+ weight_zero_points : Optional [torch .Tensor ],
360+ weight_quant_min : int ,
361+ weight_quant_max : int ,
362+ indices : torch .Tensor ,
363+ dtype : Optional [torch .dtype ],
364+ ) -> torch .Tensor :
365+ num_embeddings , packed_embedding_dim = weight .shape
366+ embedding_dim = packed_embedding_dim * 4
367+ embedding = torch .nn .Embedding (num_embeddings , embedding_dim , device = weight .device )
368+ return embedding (indices ).to (dtype )
298369
299370@register_fake ("quantized_decomposed::embedding_2bit.dtype_out" )
300371def embedding_2bit_dtype_out_meta (
@@ -377,6 +448,19 @@ def embedding_4bit(
377448 )
378449 return torch .ops .aten .embedding .default (weight , indices )
379450
451+ @register_fake ("quantized_decomposed::embedding_4bit" )
452+ def _ (
453+ weight : torch .Tensor ,
454+ weight_scales : torch .Tensor ,
455+ weight_zero_points : Optional [torch .Tensor ],
456+ weight_quant_min : int ,
457+ weight_quant_max : int ,
458+ indices : torch .Tensor ,
459+ ):
460+ num_embeddings , packed_embedding_dim = weight .shape
461+ embedding_dim = packed_embedding_dim * 2
462+ embedding = torch .nn .Embedding (num_embeddings , embedding_dim , device = weight .device )
463+ return embedding (indices )
380464
381465@register_fake ("quantized_decomposed::embedding_4bit.out" )
382466def embedding_4bit_out_meta (
@@ -437,6 +521,20 @@ def embedding_4bit_dtype(
437521 )
438522 return torch .ops .aten .embedding .default (weight , indices )
439523
524+ @register_fake ("quantized_decomposed::embedding_4bit.dtype" )
525+ def _ (
526+ weight : torch .Tensor ,
527+ weight_scales : torch .Tensor ,
528+ weight_zero_points : Optional [torch .Tensor ],
529+ weight_quant_min : int ,
530+ weight_quant_max : int ,
531+ indices : torch .Tensor ,
532+ dtype : Optional [torch .dtype ],
533+ ) -> torch .Tensor :
534+ num_embeddings , packed_embedding_dim = weight .shape
535+ embedding_dim = packed_embedding_dim * 2
536+ embedding = torch .nn .Embedding (num_embeddings , embedding_dim , device = weight .device )
537+ return embedding (indices ).to (dtype )
440538
441539@register_fake ("quantized_decomposed::embedding_4bit.dtype_out" )
442540def embedding_4bit_dtype_out_meta (
@@ -872,6 +970,76 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
872970 )
873971 ]
874972
973+ def _get_embedding_ops_patterns_and_replacements_torchao () -> List [Tuple [Callable , Callable , List [Callable ]]]:
974+ def embedding_byte_pattern (indices , int_data , group_size , scale , zero_point ):
975+ dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 128 , 127 )
976+ return torch .ops .aten .embedding .default (dq , indices )
977+ def embedding_byte_replacement (indices , int_data , group_size , scale , zero_point ):
978+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
979+ return torch .ops .quantized_decomposed .embedding_byte .default (
980+ int_data ,
981+ scale ,
982+ zero_point_dtype_cast ,
983+ - 128 ,
984+ 127 ,
985+ indices ,
986+ )
987+ def embedding_byte_dtype_pattern (indices , int_data , group_size , scale , zero_point , output_dtype ):
988+ dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 128 , 127 , 'INT' , output_dtype )
989+ return torch .ops .aten .embedding .default (dq , indices )
990+ def embedding_byte_dtype_replacement (indices , int_data , group_size , scale , zero_point , output_dtype ):
991+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
992+ return torch .ops .quantized_decomposed .embedding_byte .dtype (
993+ int_data ,
994+ scale ,
995+ zero_point_dtype_cast ,
996+ - 128 ,
997+ 127 ,
998+ indices ,
999+ dtype = output_dtype
1000+ )
1001+
1002+ def embedding_2bit_pattern (indices , int_data , group_size , scale , zero_point ):
1003+ dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 2 , 1 )
1004+ return torch .ops .aten .embedding .default (dq , indices )
1005+ def embedding_2bit_replacement (indices , int_data , group_size , scale , zero_point ):
1006+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (int_data , 2 )
1007+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1008+ return torch .ops .quantized_decomposed .embedding_2bit .default (packed_int_data , scale , zero_point_dtype_cast , - 2 , 1 , indices )
1009+
1010+ def embedding_2bit_dtype_pattern (indices , int_data , group_size , scale , zero_point , output_dtype ):
1011+ dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 2 , 1 , 'INT' , output_dtype )
1012+ return torch .ops .aten .embedding .default (dq , indices )
1013+ def embedding_2bit_dtype_replacement (indices , int_data , group_size , scale , zero_point , output_dtype ):
1014+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (int_data , 2 )
1015+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1016+ return torch .ops .quantized_decomposed .embedding_2bit .dtype (packed_int_data , scale , zero_point_dtype_cast , - 2 , 1 , indices , dtype = output_dtype )
1017+
1018+ def embedding_4bit_pattern (indices , int_data , group_size , scale , zero_point ):
1019+ dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 8 , 7 )
1020+ return torch .ops .aten .embedding .default (dq , indices )
1021+ def embedding_4bit_replacement (indices , int_data , group_size , scale , zero_point ):
1022+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (int_data , 4 )
1023+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1024+ return torch .ops .quantized_decomposed .embedding_4bit .default (packed_int_data , scale , zero_point_dtype_cast , - 8 , 7 , indices )
1025+
1026+ def embedding_4bit_dtype_pattern (indices , int_data , group_size , scale , zero_point , output_dtype ):
1027+ dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 8 , 7 , 'INT' , output_dtype )
1028+ return torch .ops .aten .embedding .default (dq , indices )
1029+ def embedding_4bit_dtype_replacement (indices , int_data , group_size , scale , zero_point , output_dtype ):
1030+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (int_data , 4 )
1031+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1032+ return torch .ops .quantized_decomposed .embedding_4bit .dtype (packed_int_data , scale , zero_point_dtype_cast , - 8 , 7 , indices , dtype = output_dtype )
1033+
1034+ return [
1035+ (_trace_and_lower_to_edge_ops (embedding_byte_pattern ), _trace_and_lower_to_edge_ops (embedding_byte_replacement ), []),
1036+ (_trace_and_lower_to_edge_ops (embedding_byte_dtype_pattern ), _trace_and_lower_to_edge_ops (embedding_byte_dtype_replacement ), []),
1037+ (_trace_and_lower_to_edge_ops (embedding_2bit_pattern ), _trace_and_lower_to_edge_ops (embedding_2bit_replacement ), []),
1038+ (_trace_and_lower_to_edge_ops (embedding_2bit_dtype_pattern ), _trace_and_lower_to_edge_ops (embedding_2bit_dtype_replacement ), []),
1039+ (_trace_and_lower_to_edge_ops (embedding_4bit_pattern ), _trace_and_lower_to_edge_ops (embedding_4bit_replacement ), []),
1040+ (_trace_and_lower_to_edge_ops (embedding_4bit_dtype_pattern ), _trace_and_lower_to_edge_ops (embedding_4bit_dtype_replacement ), []),
1041+ ]
1042+
8751043
8761044def _get_embedding_ops_patterns_and_replacements () -> (
8771045 List [Tuple [Callable , Callable , List [Callable ]]]
@@ -1167,5 +1335,6 @@ def get_quant_patterns_and_replacements() -> (
11671335 * _get_slice_patterns_and_replacements (),
11681336 # *_get_fixed_qparams_ops_patterns_and_replacements(),
11691337 * _get_embedding_ops_patterns_and_replacements (),
1338+ * _get_embedding_ops_patterns_and_replacements_torchao (),
11701339 ]
11711340 )
0 commit comments