2525
2626from torch import Tensor
2727from torch .library import custom_op
28+
29+
2830@custom_op ("quant_fusion::_pack_embedding_weight" , mutates_args = ())
2931def _pack_embedding_weight (weight : Tensor , bitwidth : int ) -> Tensor :
3032 num_embeddings , embedding_dim = weight .shape
3133
3234 if bitwidth == 2 :
3335 assert embedding_dim % 4 == 0 , "embedding_dim must be divisible by 4"
3436 weight_range_shifted = weight .add (2 ).view (torch .uint8 )
35- weight_view = weight_range_shifted .view (
36- num_embeddings , embedding_dim // 4 , 4
37- )
37+ weight_view = weight_range_shifted .view (num_embeddings , embedding_dim // 4 , 4 )
3838 weight_0 = weight_view [:, :, 0 ]
3939 weight_1 = weight_view [:, :, 1 ] << 2
4040 weight_2 = weight_view [:, :, 2 ] << 4
@@ -53,7 +53,7 @@ def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:
5353 return packed_weight
5454 elif bitwidth == 8 :
5555 return weight
56-
56+
5757 raise RuntimeError (f"Unsupported bitwidth { bitwidth } " )
5858
5959
@@ -64,7 +64,12 @@ def _(weight, bit_width):
6464 num_embeddings , embedding_dim = weight .shape
6565 values_per_byte = 8 // bit_width
6666 assert embedding_dim % values_per_byte == 0
67- return torch .empty (num_embeddings , embedding_dim // values_per_byte , dtype = torch .uint8 , device = weight .device )
67+ return torch .empty (
68+ num_embeddings ,
69+ embedding_dim // values_per_byte ,
70+ dtype = torch .uint8 ,
71+ device = weight .device ,
72+ )
6873
6974
7075# TODO: extending an existing library that is defined in OSS might be a bit
@@ -114,9 +119,10 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
114119 assert (
115120 weight_zero_points is None or weight_zero_points .dtype == weight_scales .dtype
116121 ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
117- assert (
118- weight_zero_points is None or weight_zero_points .dim () in [1 , 2 ]
119- ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found { weight_zero_points .dim ()} "
122+ assert weight_zero_points is None or weight_zero_points .dim () in [
123+ 1 ,
124+ 2 ,
125+ ], f"Expecting weight_zero_points tensor to be None or have dim()==1, but found { weight_zero_points .dim ()} "
120126 assert weight_zero_points is None or weight_zero_points .size (0 ) == weight .size (
121127 0
122128 ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found { weight .size ()} and { weight_zero_points .size ()} "
@@ -278,6 +284,7 @@ def embedding_2bit(
278284 )
279285 return torch .ops .aten .embedding .default (weight , indices )
280286
287+
281288@register_fake ("quantized_decomposed::embedding_2bit" )
282289def _ (
283290 weight : torch .Tensor ,
@@ -286,12 +293,13 @@ def _(
286293 weight_quant_min : int ,
287294 weight_quant_max : int ,
288295 indices : torch .Tensor ,
289- ):
296+ ):
290297 num_embeddings , packed_embedding_dim = weight .shape
291298 embedding_dim = packed_embedding_dim * 4
292299 embedding = torch .nn .Embedding (num_embeddings , embedding_dim , device = weight .device )
293300 return embedding (indices )
294301
302+
295303@register_fake ("quantized_decomposed::embedding_2bit.out" )
296304def embedding_2bit_out_meta (
297305 weight : torch .Tensor ,
@@ -311,6 +319,7 @@ def embedding_2bit_out_meta(
311319 indices ,
312320 )
313321
322+
314323@impl (quantized_decomposed_lib , "embedding_2bit.dtype" , "CompositeExplicitAutograd" )
315324def embedding_2bit_dtype (
316325 weight : torch .Tensor ,
@@ -352,6 +361,7 @@ def embedding_2bit_dtype(
352361 )
353362 return torch .ops .aten .embedding .default (weight , indices )
354363
364+
355365@register_fake ("quantized_decomposed::embedding_2bit.dtype" )
356366def _ (
357367 weight : torch .Tensor ,
@@ -361,12 +371,13 @@ def _(
361371 weight_quant_max : int ,
362372 indices : torch .Tensor ,
363373 dtype : Optional [torch .dtype ],
364- ) -> torch .Tensor :
374+ ) -> torch .Tensor :
365375 num_embeddings , packed_embedding_dim = weight .shape
366376 embedding_dim = packed_embedding_dim * 4
367377 embedding = torch .nn .Embedding (num_embeddings , embedding_dim , device = weight .device )
368378 return embedding (indices ).to (dtype )
369379
380+
370381@register_fake ("quantized_decomposed::embedding_2bit.dtype_out" )
371382def embedding_2bit_dtype_out_meta (
372383 weight : torch .Tensor ,
@@ -448,6 +459,7 @@ def embedding_4bit(
448459 )
449460 return torch .ops .aten .embedding .default (weight , indices )
450461
462+
451463@register_fake ("quantized_decomposed::embedding_4bit" )
452464def _ (
453465 weight : torch .Tensor ,
@@ -456,12 +468,13 @@ def _(
456468 weight_quant_min : int ,
457469 weight_quant_max : int ,
458470 indices : torch .Tensor ,
459- ):
471+ ):
460472 num_embeddings , packed_embedding_dim = weight .shape
461473 embedding_dim = packed_embedding_dim * 2
462474 embedding = torch .nn .Embedding (num_embeddings , embedding_dim , device = weight .device )
463475 return embedding (indices )
464476
477+
465478@register_fake ("quantized_decomposed::embedding_4bit.out" )
466479def embedding_4bit_out_meta (
467480 weight : torch .Tensor ,
@@ -521,6 +534,7 @@ def embedding_4bit_dtype(
521534 )
522535 return torch .ops .aten .embedding .default (weight , indices )
523536
537+
524538@register_fake ("quantized_decomposed::embedding_4bit.dtype" )
525539def _ (
526540 weight : torch .Tensor ,
@@ -530,12 +544,13 @@ def _(
530544 weight_quant_max : int ,
531545 indices : torch .Tensor ,
532546 dtype : Optional [torch .dtype ],
533- ) -> torch .Tensor :
547+ ) -> torch .Tensor :
534548 num_embeddings , packed_embedding_dim = weight .shape
535549 embedding_dim = packed_embedding_dim * 2
536550 embedding = torch .nn .Embedding (num_embeddings , embedding_dim , device = weight .device )
537551 return embedding (indices ).to (dtype )
538552
553+
539554@register_fake ("quantized_decomposed::embedding_4bit.dtype_out" )
540555def embedding_4bit_dtype_out_meta (
541556 weight : torch .Tensor ,
@@ -970,10 +985,16 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
970985 )
971986 ]
972987
973- def _get_embedding_ops_patterns_and_replacements_torchao () -> List [Tuple [Callable , Callable , List [Callable ]]]:
988+
989+ def _get_embedding_ops_patterns_and_replacements_torchao () -> ( # noqa C901
990+ List [Tuple [Callable , Callable , List [Callable ]]]
991+ ):
974992 def embedding_byte_pattern (indices , int_data , group_size , scale , zero_point ):
975- dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 128 , 127 )
993+ dq = torch .ops .torchao .dequantize_affine .default (
994+ int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 128 , 127
995+ )
976996 return torch .ops .aten .embedding .default (dq , indices )
997+
977998 def embedding_byte_replacement (indices , int_data , group_size , scale , zero_point ):
978999 zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
9791000 return torch .ops .quantized_decomposed .embedding_byte .default (
@@ -984,10 +1005,26 @@ def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point)
9841005 127 ,
9851006 indices ,
9861007 )
987- def embedding_byte_dtype_pattern (indices , int_data , group_size , scale , zero_point , output_dtype ):
988- dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 128 , 127 , 'INT' , output_dtype )
1008+
1009+ def embedding_byte_dtype_pattern (
1010+ indices , int_data , group_size , scale , zero_point , output_dtype
1011+ ):
1012+ dq = torch .ops .torchao .dequantize_affine .default (
1013+ int_data ,
1014+ [1 , group_size ],
1015+ scale ,
1016+ zero_point ,
1017+ torch .int8 ,
1018+ - 128 ,
1019+ 127 ,
1020+ "INT" ,
1021+ output_dtype ,
1022+ )
9891023 return torch .ops .aten .embedding .default (dq , indices )
990- def embedding_byte_dtype_replacement (indices , int_data , group_size , scale , zero_point , output_dtype ):
1024+
1025+ def embedding_byte_dtype_replacement (
1026+ indices , int_data , group_size , scale , zero_point , output_dtype
1027+ ):
9911028 zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
9921029 return torch .ops .quantized_decomposed .embedding_byte .dtype (
9931030 int_data ,
@@ -996,48 +1033,136 @@ def embedding_byte_dtype_replacement(indices, int_data, group_size, scale, zero_
9961033 - 128 ,
9971034 127 ,
9981035 indices ,
999- dtype = output_dtype
1036+ dtype = output_dtype ,
10001037 )
1001-
1038+
10021039 def embedding_2bit_pattern (indices , int_data , group_size , scale , zero_point ):
1003- dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 2 , 1 )
1040+ dq = torch .ops .torchao .dequantize_affine .default (
1041+ int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 2 , 1
1042+ )
10041043 return torch .ops .aten .embedding .default (dq , indices )
1044+
10051045 def embedding_2bit_replacement (indices , int_data , group_size , scale , zero_point ):
1006- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (int_data , 2 )
1046+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1047+ int_data , 2
1048+ )
10071049 zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1008- return torch .ops .quantized_decomposed .embedding_2bit .default (packed_int_data , scale , zero_point_dtype_cast , - 2 , 1 , indices )
1050+ return torch .ops .quantized_decomposed .embedding_2bit .default (
1051+ packed_int_data , scale , zero_point_dtype_cast , - 2 , 1 , indices
1052+ )
10091053
1010- def embedding_2bit_dtype_pattern (indices , int_data , group_size , scale , zero_point , output_dtype ):
1011- dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 2 , 1 , 'INT' , output_dtype )
1054+ def embedding_2bit_dtype_pattern (
1055+ indices , int_data , group_size , scale , zero_point , output_dtype
1056+ ):
1057+ dq = torch .ops .torchao .dequantize_affine .default (
1058+ int_data ,
1059+ [1 , group_size ],
1060+ scale ,
1061+ zero_point ,
1062+ torch .int8 ,
1063+ - 2 ,
1064+ 1 ,
1065+ "INT" ,
1066+ output_dtype ,
1067+ )
10121068 return torch .ops .aten .embedding .default (dq , indices )
1013- def embedding_2bit_dtype_replacement (indices , int_data , group_size , scale , zero_point , output_dtype ):
1014- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (int_data , 2 )
1069+
1070+ def embedding_2bit_dtype_replacement (
1071+ indices , int_data , group_size , scale , zero_point , output_dtype
1072+ ):
1073+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1074+ int_data , 2
1075+ )
10151076 zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1016- return torch .ops .quantized_decomposed .embedding_2bit .dtype (packed_int_data , scale , zero_point_dtype_cast , - 2 , 1 , indices , dtype = output_dtype )
1017-
1077+ return torch .ops .quantized_decomposed .embedding_2bit .dtype (
1078+ packed_int_data ,
1079+ scale ,
1080+ zero_point_dtype_cast ,
1081+ - 2 ,
1082+ 1 ,
1083+ indices ,
1084+ dtype = output_dtype ,
1085+ )
1086+
10181087 def embedding_4bit_pattern (indices , int_data , group_size , scale , zero_point ):
1019- dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 8 , 7 )
1088+ dq = torch .ops .torchao .dequantize_affine .default (
1089+ int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 8 , 7
1090+ )
10201091 return torch .ops .aten .embedding .default (dq , indices )
1092+
10211093 def embedding_4bit_replacement (indices , int_data , group_size , scale , zero_point ):
1022- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (int_data , 4 )
1094+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1095+ int_data , 4
1096+ )
10231097 zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1024- return torch .ops .quantized_decomposed .embedding_4bit .default (packed_int_data , scale , zero_point_dtype_cast , - 8 , 7 , indices )
1025-
1026- def embedding_4bit_dtype_pattern (indices , int_data , group_size , scale , zero_point , output_dtype ):
1027- dq = torch .ops .torchao .dequantize_affine .default (int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 8 , 7 , 'INT' , output_dtype )
1098+ return torch .ops .quantized_decomposed .embedding_4bit .default (
1099+ packed_int_data , scale , zero_point_dtype_cast , - 8 , 7 , indices
1100+ )
1101+
1102+ def embedding_4bit_dtype_pattern (
1103+ indices , int_data , group_size , scale , zero_point , output_dtype
1104+ ):
1105+ dq = torch .ops .torchao .dequantize_affine .default (
1106+ int_data ,
1107+ [1 , group_size ],
1108+ scale ,
1109+ zero_point ,
1110+ torch .int8 ,
1111+ - 8 ,
1112+ 7 ,
1113+ "INT" ,
1114+ output_dtype ,
1115+ )
10281116 return torch .ops .aten .embedding .default (dq , indices )
1029- def embedding_4bit_dtype_replacement (indices , int_data , group_size , scale , zero_point , output_dtype ):
1030- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (int_data , 4 )
1117+
1118+ def embedding_4bit_dtype_replacement (
1119+ indices , int_data , group_size , scale , zero_point , output_dtype
1120+ ):
1121+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1122+ int_data , 4
1123+ )
10311124 zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1032- return torch .ops .quantized_decomposed .embedding_4bit .dtype (packed_int_data , scale , zero_point_dtype_cast , - 8 , 7 , indices , dtype = output_dtype )
1125+ return torch .ops .quantized_decomposed .embedding_4bit .dtype (
1126+ packed_int_data ,
1127+ scale ,
1128+ zero_point_dtype_cast ,
1129+ - 8 ,
1130+ 7 ,
1131+ indices ,
1132+ dtype = output_dtype ,
1133+ )
10331134
10341135 return [
1035- (_trace_and_lower_to_edge_ops (embedding_byte_pattern ), _trace_and_lower_to_edge_ops (embedding_byte_replacement ), []),
1036- (_trace_and_lower_to_edge_ops (embedding_byte_dtype_pattern ), _trace_and_lower_to_edge_ops (embedding_byte_dtype_replacement ), []),
1037- (_trace_and_lower_to_edge_ops (embedding_2bit_pattern ), _trace_and_lower_to_edge_ops (embedding_2bit_replacement ), []),
1038- (_trace_and_lower_to_edge_ops (embedding_2bit_dtype_pattern ), _trace_and_lower_to_edge_ops (embedding_2bit_dtype_replacement ), []),
1039- (_trace_and_lower_to_edge_ops (embedding_4bit_pattern ), _trace_and_lower_to_edge_ops (embedding_4bit_replacement ), []),
1040- (_trace_and_lower_to_edge_ops (embedding_4bit_dtype_pattern ), _trace_and_lower_to_edge_ops (embedding_4bit_dtype_replacement ), []),
1136+ (
1137+ _trace_and_lower_to_edge_ops (embedding_byte_pattern ),
1138+ _trace_and_lower_to_edge_ops (embedding_byte_replacement ),
1139+ [],
1140+ ),
1141+ (
1142+ _trace_and_lower_to_edge_ops (embedding_byte_dtype_pattern ),
1143+ _trace_and_lower_to_edge_ops (embedding_byte_dtype_replacement ),
1144+ [],
1145+ ),
1146+ (
1147+ _trace_and_lower_to_edge_ops (embedding_2bit_pattern ),
1148+ _trace_and_lower_to_edge_ops (embedding_2bit_replacement ),
1149+ [],
1150+ ),
1151+ (
1152+ _trace_and_lower_to_edge_ops (embedding_2bit_dtype_pattern ),
1153+ _trace_and_lower_to_edge_ops (embedding_2bit_dtype_replacement ),
1154+ [],
1155+ ),
1156+ (
1157+ _trace_and_lower_to_edge_ops (embedding_4bit_pattern ),
1158+ _trace_and_lower_to_edge_ops (embedding_4bit_replacement ),
1159+ [],
1160+ ),
1161+ (
1162+ _trace_and_lower_to_edge_ops (embedding_4bit_dtype_pattern ),
1163+ _trace_and_lower_to_edge_ops (embedding_4bit_dtype_replacement ),
1164+ [],
1165+ ),
10411166 ]
10421167
10431168
0 commit comments