@@ -986,25 +986,46 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
986986 ]
987987
988988
989- def _get_embedding_ops_patterns_and_replacements_torchao () -> ( # noqa C901
989+ def _get_embedding_ops_patterns_and_replacements_torchao (node_value_dict ) -> ( # noqa C901
990990 List [Tuple [Callable , Callable , List [Callable ]]]
991991):
992+ assert node_value_dict is not None , "node_value_dict cannot be None"
993+ def get_embedding_replacement_filter (has_nonzero_zero_point ):
994+ def _filter (match , original_graph , pattern_graph ):
995+ def get_val (name ):
996+ node = [n for n in match .nodes_map if n .name == name ][0 ]
997+ val = match .nodes_map [node ]
998+ if isinstance (val , torch .fx .Node ) and val .target in node_value_dict :
999+ return node_value_dict [val .target ]
1000+ return val
1001+
1002+ zero_point = get_val ("zero_point" )
1003+ all_zero = (zero_point == 0 ).all ().item ()
1004+ if has_nonzero_zero_point :
1005+ return not all_zero
1006+ else :
1007+ return all_zero
1008+ return _filter
1009+
9921010 def embedding_byte_pattern (indices , int_data , group_size , scale , zero_point ):
9931011 dq = torch .ops .torchao .dequantize_affine .default (
9941012 int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 128 , 127
9951013 )
9961014 return torch .ops .aten .embedding .default (dq , indices )
9971015
998- def embedding_byte_replacement (indices , int_data , group_size , scale , zero_point ):
999- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1000- return torch .ops .quantized_decomposed .embedding_byte .default (
1001- int_data ,
1002- scale ,
1003- zero_point_dtype_cast ,
1004- - 128 ,
1005- 127 ,
1006- indices ,
1007- )
1016+ def get_embedding_byte_replacement (has_nonzero_zero_point ):
1017+ def embedding_byte_replacement (indices , int_data , group_size , scale , zero_point ):
1018+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1019+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1020+ return torch .ops .quantized_decomposed .embedding_byte .default (
1021+ int_data ,
1022+ scale ,
1023+ zero_point_dtype_cast ,
1024+ - 128 ,
1025+ 127 ,
1026+ indices ,
1027+ )
1028+ return embedding_byte_replacement
10081029
10091030 def embedding_byte_dtype_pattern (
10101031 indices , int_data , group_size , scale , zero_point , output_dtype
@@ -1021,34 +1042,40 @@ def embedding_byte_dtype_pattern(
10211042 )
10221043 return torch .ops .aten .embedding .default (dq , indices )
10231044
1024- def embedding_byte_dtype_replacement (
1025- indices , int_data , group_size , scale , zero_point , output_dtype
1026- ):
1027- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1028- return torch .ops .quantized_decomposed .embedding_byte .dtype (
1029- int_data ,
1030- scale ,
1031- zero_point_dtype_cast ,
1032- - 128 ,
1033- 127 ,
1034- indices ,
1035- dtype = output_dtype ,
1036- )
1045+ def get_embedding_byte_dtype_replacement (has_nonzero_zero_point ):
1046+ def embedding_byte_dtype_replacement (
1047+ indices , int_data , group_size , scale , zero_point , output_dtype
1048+ ):
1049+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1050+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1051+ return torch .ops .quantized_decomposed .embedding_byte .dtype (
1052+ int_data ,
1053+ scale ,
1054+ zero_point_dtype_cast ,
1055+ - 128 ,
1056+ 127 ,
1057+ indices ,
1058+ dtype = output_dtype ,
1059+ )
1060+ return embedding_byte_dtype_replacement
10371061
10381062 def embedding_2bit_pattern (indices , int_data , group_size , scale , zero_point ):
10391063 dq = torch .ops .torchao .dequantize_affine .default (
10401064 int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 2 , 1
10411065 )
10421066 return torch .ops .aten .embedding .default (dq , indices )
10431067
1044- def embedding_2bit_replacement (indices , int_data , group_size , scale , zero_point ):
1045- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1046- int_data , 2
1047- )
1048- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1049- return torch .ops .quantized_decomposed .embedding_2bit .default (
1050- packed_int_data , scale , zero_point_dtype_cast , - 2 , 1 , indices
1051- )
1068+ def get_embedding_2bit_replacement (has_nonzero_zero_point ):
1069+ def embedding_2bit_replacement (indices , int_data , group_size , scale , zero_point ):
1070+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1071+ int_data , 2
1072+ )
1073+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1074+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1075+ return torch .ops .quantized_decomposed .embedding_2bit .default (
1076+ packed_int_data , scale , zero_point_dtype_cast , - 2 , 1 , indices
1077+ )
1078+ return embedding_2bit_replacement
10521079
10531080 def embedding_2bit_dtype_pattern (
10541081 indices , int_data , group_size , scale , zero_point , output_dtype
@@ -1065,37 +1092,43 @@ def embedding_2bit_dtype_pattern(
10651092 )
10661093 return torch .ops .aten .embedding .default (dq , indices )
10671094
1068- def embedding_2bit_dtype_replacement (
1069- indices , int_data , group_size , scale , zero_point , output_dtype
1070- ):
1071- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1072- int_data , 2
1073- )
1074- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1075- return torch .ops .quantized_decomposed .embedding_2bit .dtype (
1076- packed_int_data ,
1077- scale ,
1078- zero_point_dtype_cast ,
1079- - 2 ,
1080- 1 ,
1081- indices ,
1082- dtype = output_dtype ,
1083- )
1095+ def get_embedding_2bit_dtype_replacement (has_nonzero_zero_point ):
1096+ def embedding_2bit_dtype_replacement (
1097+ indices , int_data , group_size , scale , zero_point , output_dtype
1098+ ):
1099+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1100+ int_data , 2
1101+ )
1102+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1103+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1104+ return torch .ops .quantized_decomposed .embedding_2bit .dtype (
1105+ packed_int_data ,
1106+ scale ,
1107+ zero_point_dtype_cast ,
1108+ - 2 ,
1109+ 1 ,
1110+ indices ,
1111+ dtype = output_dtype ,
1112+ )
1113+ return embedding_2bit_dtype_replacement
10841114
10851115 def embedding_4bit_pattern (indices , int_data , group_size , scale , zero_point ):
10861116 dq = torch .ops .torchao .dequantize_affine .default (
10871117 int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 8 , 7
10881118 )
10891119 return torch .ops .aten .embedding .default (dq , indices )
10901120
1091- def embedding_4bit_replacement (indices , int_data , group_size , scale , zero_point ):
1092- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1093- int_data , 4
1094- )
1095- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1096- return torch .ops .quantized_decomposed .embedding_4bit .default (
1097- packed_int_data , scale , zero_point_dtype_cast , - 8 , 7 , indices
1098- )
1121+ def get_embedding_4bit_replacement (has_nonzero_zero_point ):
1122+ def embedding_4bit_replacement (indices , int_data , group_size , scale , zero_point ):
1123+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1124+ int_data , 4
1125+ )
1126+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1127+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1128+ return torch .ops .quantized_decomposed .embedding_4bit .default (
1129+ packed_int_data , scale , zero_point_dtype_cast , - 8 , 7 , indices
1130+ )
1131+ return embedding_4bit_replacement
10991132
11001133 def embedding_4bit_dtype_pattern (
11011134 indices , int_data , group_size , scale , zero_point , output_dtype
@@ -1112,53 +1145,86 @@ def embedding_4bit_dtype_pattern(
11121145 )
11131146 return torch .ops .aten .embedding .default (dq , indices )
11141147
1115- def embedding_4bit_dtype_replacement (
1116- indices , int_data , group_size , scale , zero_point , output_dtype
1117- ):
1118- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1119- int_data , 4
1120- )
1121- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1122- return torch .ops .quantized_decomposed .embedding_4bit .dtype (
1123- packed_int_data ,
1124- scale ,
1125- zero_point_dtype_cast ,
1126- - 8 ,
1127- 7 ,
1128- indices ,
1129- dtype = output_dtype ,
1130- )
1148+ def get_embedding_4bit_dtype_replacement (has_nonzero_zero_point ):
1149+ def embedding_4bit_dtype_replacement (
1150+ indices , int_data , group_size , scale , zero_point , output_dtype
1151+ ):
1152+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1153+ int_data , 4
1154+ )
1155+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1156+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1157+ return torch .ops .quantized_decomposed .embedding_4bit .dtype (
1158+ packed_int_data ,
1159+ scale ,
1160+ zero_point_dtype_cast ,
1161+ - 8 ,
1162+ 7 ,
1163+ indices ,
1164+ dtype = output_dtype ,
1165+ )
1166+ return embedding_4bit_dtype_replacement
11311167
11321168 return [
11331169 (
11341170 _trace_and_lower_to_edge_ops (embedding_byte_pattern ),
1135- _trace_and_lower_to_edge_ops (embedding_byte_replacement ),
1136- [],
1171+ _trace_and_lower_to_edge_ops (get_embedding_byte_replacement (False )),
1172+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1173+ ),
1174+ (
1175+ _trace_and_lower_to_edge_ops (embedding_byte_pattern ),
1176+ _trace_and_lower_to_edge_ops (get_embedding_byte_replacement (True )),
1177+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11371178 ),
11381179 (
11391180 _trace_and_lower_to_edge_ops (embedding_byte_dtype_pattern ),
1140- _trace_and_lower_to_edge_ops (embedding_byte_dtype_replacement ),
1141- [],
1181+ _trace_and_lower_to_edge_ops (get_embedding_byte_dtype_replacement (False )),
1182+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1183+ ),
1184+ (
1185+ _trace_and_lower_to_edge_ops (embedding_byte_dtype_pattern ),
1186+ _trace_and_lower_to_edge_ops (get_embedding_byte_dtype_replacement (True )),
1187+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11421188 ),
11431189 (
11441190 _trace_and_lower_to_edge_ops (embedding_2bit_pattern ),
1145- _trace_and_lower_to_edge_ops (embedding_2bit_replacement ),
1146- [],
1191+ _trace_and_lower_to_edge_ops (get_embedding_2bit_replacement (False )),
1192+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1193+ ),
1194+ (
1195+ _trace_and_lower_to_edge_ops (embedding_2bit_pattern ),
1196+ _trace_and_lower_to_edge_ops (get_embedding_2bit_replacement (True )),
1197+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11471198 ),
11481199 (
11491200 _trace_and_lower_to_edge_ops (embedding_2bit_dtype_pattern ),
1150- _trace_and_lower_to_edge_ops (embedding_2bit_dtype_replacement ),
1151- [],
1201+ _trace_and_lower_to_edge_ops (get_embedding_2bit_dtype_replacement (False )),
1202+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1203+ ),
1204+ (
1205+ _trace_and_lower_to_edge_ops (embedding_2bit_dtype_pattern ),
1206+ _trace_and_lower_to_edge_ops (get_embedding_2bit_dtype_replacement (True )),
1207+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11521208 ),
11531209 (
11541210 _trace_and_lower_to_edge_ops (embedding_4bit_pattern ),
1155- _trace_and_lower_to_edge_ops (embedding_4bit_replacement ),
1156- [],
1211+ _trace_and_lower_to_edge_ops (get_embedding_4bit_replacement (has_nonzero_zero_point = False )),
1212+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1213+ ),
1214+ (
1215+ _trace_and_lower_to_edge_ops (embedding_4bit_pattern ),
1216+ _trace_and_lower_to_edge_ops (get_embedding_4bit_replacement (has_nonzero_zero_point = True )),
1217+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11571218 ),
11581219 (
11591220 _trace_and_lower_to_edge_ops (embedding_4bit_dtype_pattern ),
1160- _trace_and_lower_to_edge_ops (embedding_4bit_dtype_replacement ),
1161- [],
1221+ _trace_and_lower_to_edge_ops (get_embedding_4bit_dtype_replacement (has_nonzero_zero_point = False )),
1222+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1223+ ),
1224+ (
1225+ _trace_and_lower_to_edge_ops (embedding_4bit_dtype_pattern ),
1226+ _trace_and_lower_to_edge_ops (get_embedding_4bit_dtype_replacement (has_nonzero_zero_point = True )),
1227+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11621228 ),
11631229 ]
11641230
@@ -1445,7 +1511,7 @@ def replacement(x, x_scale, x_zero_point, x_qmin, x_qmax):
14451511"""
14461512
14471513
1448- def get_quant_patterns_and_replacements () -> (
1514+ def get_quant_patterns_and_replacements (node_value_dict ) -> (
14491515 List [Tuple [Callable , Callable , List [Callable ]]]
14501516):
14511517
@@ -1457,6 +1523,6 @@ def get_quant_patterns_and_replacements() -> (
14571523 * _get_slice_patterns_and_replacements (),
14581524 # *_get_fixed_qparams_ops_patterns_and_replacements(),
14591525 * _get_embedding_ops_patterns_and_replacements (),
1460- * _get_embedding_ops_patterns_and_replacements_torchao (),
1526+ * _get_embedding_ops_patterns_and_replacements_torchao (node_value_dict ),
14611527 ]
14621528 )
0 commit comments