@@ -986,25 +986,45 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
986986 ]
987987
988988
989- def _get_embedding_ops_patterns_and_replacements_torchao () -> ( # noqa C901
989+ def _get_embedding_ops_patterns_and_replacements_torchao (node_value_dict ) -> ( # noqa C901
990990 List [Tuple [Callable , Callable , List [Callable ]]]
991991):
992+ def get_embedding_replacement_filter (has_nonzero_zero_point ):
993+ def _filter (match , original_graph , pattern_graph ):
994+ def get_val (name ):
995+ node = [n for n in match .nodes_map if n .name == name ][0 ]
996+ val = match .nodes_map [node ]
997+ if isinstance (val , torch .fx .Node ) and val .target in node_value_dict :
998+ return node_value_dict [val .target ]
999+ return val
1000+
1001+ zero_point = get_val ("zero_point" )
1002+ all_zero = (zero_point == 0 ).all ().item ()
1003+ if has_nonzero_zero_point :
1004+ return not all_zero
1005+ else :
1006+ return all_zero
1007+ return _filter
1008+
9921009 def embedding_byte_pattern (indices , int_data , group_size , scale , zero_point ):
9931010 dq = torch .ops .torchao .dequantize_affine .default (
9941011 int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 128 , 127
9951012 )
9961013 return torch .ops .aten .embedding .default (dq , indices )
9971014
998- def embedding_byte_replacement (indices , int_data , group_size , scale , zero_point ):
999- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1000- return torch .ops .quantized_decomposed .embedding_byte .default (
1001- int_data ,
1002- scale ,
1003- zero_point_dtype_cast ,
1004- - 128 ,
1005- 127 ,
1006- indices ,
1007- )
1015+ def get_embedding_byte_replacement (has_nonzero_zero_point ):
1016+ def embedding_byte_replacement (indices , int_data , group_size , scale , zero_point ):
1017+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1018+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1019+ return torch .ops .quantized_decomposed .embedding_byte .default (
1020+ int_data ,
1021+ scale ,
1022+ zero_point_dtype_cast ,
1023+ - 128 ,
1024+ 127 ,
1025+ indices ,
1026+ )
1027+ return embedding_byte_replacement
10081028
10091029 def embedding_byte_dtype_pattern (
10101030 indices , int_data , group_size , scale , zero_point , output_dtype
@@ -1021,34 +1041,40 @@ def embedding_byte_dtype_pattern(
10211041 )
10221042 return torch .ops .aten .embedding .default (dq , indices )
10231043
1024- def embedding_byte_dtype_replacement (
1025- indices , int_data , group_size , scale , zero_point , output_dtype
1026- ):
1027- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1028- return torch .ops .quantized_decomposed .embedding_byte .dtype (
1029- int_data ,
1030- scale ,
1031- zero_point_dtype_cast ,
1032- - 128 ,
1033- 127 ,
1034- indices ,
1035- dtype = output_dtype ,
1036- )
1044+ def get_embedding_byte_dtype_replacement (has_nonzero_zero_point ):
1045+ def embedding_byte_dtype_replacement (
1046+ indices , int_data , group_size , scale , zero_point , output_dtype
1047+ ):
1048+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1049+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1050+ return torch .ops .quantized_decomposed .embedding_byte .dtype (
1051+ int_data ,
1052+ scale ,
1053+ zero_point_dtype_cast ,
1054+ - 128 ,
1055+ 127 ,
1056+ indices ,
1057+ dtype = output_dtype ,
1058+ )
1059+ return embedding_byte_dtype_replacement
10371060
10381061 def embedding_2bit_pattern (indices , int_data , group_size , scale , zero_point ):
10391062 dq = torch .ops .torchao .dequantize_affine .default (
10401063 int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 2 , 1
10411064 )
10421065 return torch .ops .aten .embedding .default (dq , indices )
10431066
1044- def embedding_2bit_replacement (indices , int_data , group_size , scale , zero_point ):
1045- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1046- int_data , 2
1047- )
1048- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1049- return torch .ops .quantized_decomposed .embedding_2bit .default (
1050- packed_int_data , scale , zero_point_dtype_cast , - 2 , 1 , indices
1051- )
1067+ def get_embedding_2bit_replacement (has_nonzero_zero_point ):
1068+ def embedding_2bit_replacement (indices , int_data , group_size , scale , zero_point ):
1069+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1070+ int_data , 2
1071+ )
1072+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1073+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1074+ return torch .ops .quantized_decomposed .embedding_2bit .default (
1075+ packed_int_data , scale , zero_point_dtype_cast , - 2 , 1 , indices
1076+ )
1077+ return embedding_2bit_replacement
10521078
10531079 def embedding_2bit_dtype_pattern (
10541080 indices , int_data , group_size , scale , zero_point , output_dtype
@@ -1065,37 +1091,43 @@ def embedding_2bit_dtype_pattern(
10651091 )
10661092 return torch .ops .aten .embedding .default (dq , indices )
10671093
1068- def embedding_2bit_dtype_replacement (
1069- indices , int_data , group_size , scale , zero_point , output_dtype
1070- ):
1071- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1072- int_data , 2
1073- )
1074- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1075- return torch .ops .quantized_decomposed .embedding_2bit .dtype (
1076- packed_int_data ,
1077- scale ,
1078- zero_point_dtype_cast ,
1079- - 2 ,
1080- 1 ,
1081- indices ,
1082- dtype = output_dtype ,
1083- )
1094+ def get_embedding_2bit_dtype_replacement (has_nonzero_zero_point ):
1095+ def embedding_2bit_dtype_replacement (
1096+ indices , int_data , group_size , scale , zero_point , output_dtype
1097+ ):
1098+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1099+ int_data , 2
1100+ )
1101+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1102+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1103+ return torch .ops .quantized_decomposed .embedding_2bit .dtype (
1104+ packed_int_data ,
1105+ scale ,
1106+ zero_point_dtype_cast ,
1107+ - 2 ,
1108+ 1 ,
1109+ indices ,
1110+ dtype = output_dtype ,
1111+ )
1112+ return embedding_2bit_dtype_replacement
10841113
10851114 def embedding_4bit_pattern (indices , int_data , group_size , scale , zero_point ):
10861115 dq = torch .ops .torchao .dequantize_affine .default (
10871116 int_data , [1 , group_size ], scale , zero_point , torch .int8 , - 8 , 7
10881117 )
10891118 return torch .ops .aten .embedding .default (dq , indices )
10901119
1091- def embedding_4bit_replacement (indices , int_data , group_size , scale , zero_point ):
1092- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1093- int_data , 4
1094- )
1095- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1096- return torch .ops .quantized_decomposed .embedding_4bit .default (
1097- packed_int_data , scale , zero_point_dtype_cast , - 8 , 7 , indices
1098- )
1120+ def get_embedding_4bit_replacement (has_nonzero_zero_point ):
1121+ def embedding_4bit_replacement (indices , int_data , group_size , scale , zero_point ):
1122+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1123+ int_data , 4
1124+ )
1125+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1126+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1127+ return torch .ops .quantized_decomposed .embedding_4bit .default (
1128+ packed_int_data , scale , zero_point_dtype_cast , - 8 , 7 , indices
1129+ )
1130+ return embedding_4bit_replacement
10991131
11001132 def embedding_4bit_dtype_pattern (
11011133 indices , int_data , group_size , scale , zero_point , output_dtype
@@ -1112,53 +1144,86 @@ def embedding_4bit_dtype_pattern(
11121144 )
11131145 return torch .ops .aten .embedding .default (dq , indices )
11141146
1115- def embedding_4bit_dtype_replacement (
1116- indices , int_data , group_size , scale , zero_point , output_dtype
1117- ):
1118- packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1119- int_data , 4
1120- )
1121- zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1122- return torch .ops .quantized_decomposed .embedding_4bit .dtype (
1123- packed_int_data ,
1124- scale ,
1125- zero_point_dtype_cast ,
1126- - 8 ,
1127- 7 ,
1128- indices ,
1129- dtype = output_dtype ,
1130- )
1147+ def get_embedding_4bit_dtype_replacement (has_nonzero_zero_point ):
1148+ def embedding_4bit_dtype_replacement (
1149+ indices , int_data , group_size , scale , zero_point , output_dtype
1150+ ):
1151+ packed_int_data = torch .ops .quant_fusion ._pack_embedding_weight .default (
1152+ int_data , 4
1153+ )
1154+ zero_point_dtype_cast = torch .ops .aten .to .dtype (zero_point , scale .dtype )
1155+ zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
1156+ return torch .ops .quantized_decomposed .embedding_4bit .dtype (
1157+ packed_int_data ,
1158+ scale ,
1159+ zero_point_dtype_cast ,
1160+ - 8 ,
1161+ 7 ,
1162+ indices ,
1163+ dtype = output_dtype ,
1164+ )
1165+ return embedding_4bit_dtype_replacement
11311166
11321167 return [
11331168 (
11341169 _trace_and_lower_to_edge_ops (embedding_byte_pattern ),
1135- _trace_and_lower_to_edge_ops (embedding_byte_replacement ),
1136- [],
1170+ _trace_and_lower_to_edge_ops (get_embedding_byte_replacement (False )),
1171+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1172+ ),
1173+ (
1174+ _trace_and_lower_to_edge_ops (embedding_byte_pattern ),
1175+ _trace_and_lower_to_edge_ops (get_embedding_byte_replacement (True )),
1176+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11371177 ),
11381178 (
11391179 _trace_and_lower_to_edge_ops (embedding_byte_dtype_pattern ),
1140- _trace_and_lower_to_edge_ops (embedding_byte_dtype_replacement ),
1141- [],
1180+ _trace_and_lower_to_edge_ops (get_embedding_byte_dtype_replacement (False )),
1181+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1182+ ),
1183+ (
1184+ _trace_and_lower_to_edge_ops (embedding_byte_dtype_pattern ),
1185+ _trace_and_lower_to_edge_ops (get_embedding_byte_dtype_replacement (True )),
1186+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11421187 ),
11431188 (
11441189 _trace_and_lower_to_edge_ops (embedding_2bit_pattern ),
1145- _trace_and_lower_to_edge_ops (embedding_2bit_replacement ),
1146- [],
1190+ _trace_and_lower_to_edge_ops (get_embedding_2bit_replacement (False )),
1191+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1192+ ),
1193+ (
1194+ _trace_and_lower_to_edge_ops (embedding_2bit_pattern ),
1195+ _trace_and_lower_to_edge_ops (get_embedding_2bit_replacement (True )),
1196+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11471197 ),
11481198 (
11491199 _trace_and_lower_to_edge_ops (embedding_2bit_dtype_pattern ),
1150- _trace_and_lower_to_edge_ops (embedding_2bit_dtype_replacement ),
1151- [],
1200+ _trace_and_lower_to_edge_ops (get_embedding_2bit_dtype_replacement (False )),
1201+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1202+ ),
1203+ (
1204+ _trace_and_lower_to_edge_ops (embedding_2bit_dtype_pattern ),
1205+ _trace_and_lower_to_edge_ops (get_embedding_2bit_dtype_replacement (True )),
1206+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11521207 ),
11531208 (
11541209 _trace_and_lower_to_edge_ops (embedding_4bit_pattern ),
1155- _trace_and_lower_to_edge_ops (embedding_4bit_replacement ),
1156- [],
1210+ _trace_and_lower_to_edge_ops (get_embedding_4bit_replacement (has_nonzero_zero_point = False )),
1211+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1212+ ),
1213+ (
1214+ _trace_and_lower_to_edge_ops (embedding_4bit_pattern ),
1215+ _trace_and_lower_to_edge_ops (get_embedding_4bit_replacement (has_nonzero_zero_point = True )),
1216+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11571217 ),
11581218 (
11591219 _trace_and_lower_to_edge_ops (embedding_4bit_dtype_pattern ),
1160- _trace_and_lower_to_edge_ops (embedding_4bit_dtype_replacement ),
1161- [],
1220+ _trace_and_lower_to_edge_ops (get_embedding_4bit_dtype_replacement (has_nonzero_zero_point = False )),
1221+ [get_embedding_replacement_filter (has_nonzero_zero_point = False )],
1222+ ),
1223+ (
1224+ _trace_and_lower_to_edge_ops (embedding_4bit_dtype_pattern ),
1225+ _trace_and_lower_to_edge_ops (get_embedding_4bit_dtype_replacement (has_nonzero_zero_point = True )),
1226+ [get_embedding_replacement_filter (has_nonzero_zero_point = True )],
11621227 ),
11631228 ]
11641229
@@ -1445,7 +1510,7 @@ def replacement(x, x_scale, x_zero_point, x_qmin, x_qmax):
14451510"""
14461511
14471512
1448- def get_quant_patterns_and_replacements () -> (
1513+ def get_quant_patterns_and_replacements (node_value_dict ) -> (
14491514 List [Tuple [Callable , Callable , List [Callable ]]]
14501515):
14511516
@@ -1457,6 +1522,6 @@ def get_quant_patterns_and_replacements() -> (
14571522 * _get_slice_patterns_and_replacements (),
14581523 # *_get_fixed_qparams_ops_patterns_and_replacements(),
14591524 * _get_embedding_ops_patterns_and_replacements (),
1460- * _get_embedding_ops_patterns_and_replacements_torchao (),
1525+ * _get_embedding_ops_patterns_and_replacements_torchao (node_value_dict ),
14611526 ]
14621527 )
0 commit comments