Commit 50add41

metascroy authored and facebook-github-bot committed
Embedding quant unification
Differential Revision: D83318725
1 parent 2283294 commit 50add41

3 files changed: +182 additions, -111 deletions

examples/models/llama/source_transformation/quantize.py

Lines changed: 15 additions & 22 deletions
@@ -595,19 +595,12 @@ def __init__(
 
     @torch.no_grad()
     def create_quantized_state_dict(self, packed=False) -> Dict:
+        from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_, MappingType
+        from torchao.quantization.granularity import PerGroup
+
         cur_state_dict = self.mod.state_dict()
 
-        if self.bitwidth == 2:
-            range_min = -2
-            range_max = 1
-        elif self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
+        assert self.bitwidth in [2, 4, 8], f"Unsupported bitwidth {self.bitwidth}"
 
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, nn.Embedding):
@@ -619,18 +612,18 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                 print(
                     f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
-                weight, scales, _ = dynamically_quantize_per_channel(
-                    (
-                        mod.weight.to(dtype=self.precision)
-                        if self.precision
-                        else mod.weight
-                    ),
-                    range_min,
-                    range_max,
-                    torch.int8,
-                    self.group_size,
-                    scales_dtype=mod.weight.dtype,
+                tmp_model = nn.Embedding(mod.weight.shape[0], mod.weight.shape[1])
+                if self.precision:
+                    tmp_model = tmp_model.to(dtype=self.precision)
+                tmp_model.weight = nn.Parameter(mod.weight)
+                config = IntxWeightOnlyConfig(
+                    weight_dtype=getattr(torch, f"int{self.bitwidth}"),
+                    granularity=PerGroup(self.group_size),
+                    mapping_type=MappingType.SYMMETRIC,
                 )
+                quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding))
+                weight = tmp_model.weight.qdata
+                scales = tmp_model.weight.scale
 
                 if packed:
                     if self.bitwidth == 2:
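Note: the rewritten create_quantized_state_dict delegates range selection and quantization to torchao instead of hard-coding qmin/qmax per bitwidth. A minimal standalone sketch of that flow follows; the bitwidth, group size, and embedding shape are illustrative, and it assumes a torchao build where the quantized weight exposes qdata and scale, as the diff relies on:

# Standalone sketch of the new torchao quantization path (illustrative values).
import torch
import torch.nn as nn
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, MappingType, quantize_

bitwidth, group_size = 4, 32          # illustrative settings
emb = nn.Embedding(128, 64)           # illustrative vocab size / embedding dim

# Quantize a throwaway copy so the original module stays untouched.
tmp_model = nn.Embedding(emb.weight.shape[0], emb.weight.shape[1])
tmp_model.weight = nn.Parameter(emb.weight)
config = IntxWeightOnlyConfig(
    weight_dtype=getattr(torch, f"int{bitwidth}"),
    granularity=PerGroup(group_size),
    mapping_type=MappingType.SYMMETRIC,   # symmetric mapping => zero points are all zero
)
quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding))

weight = tmp_model.weight.qdata   # integer weight data consumed by the packing code
scales = tmp_model.weight.scale   # per-group scales
print(weight.shape, scales.shape)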

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 151 additions & 86 deletions
@@ -986,25 +986,45 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
     ]
 
 
-def _get_embedding_ops_patterns_and_replacements_torchao() -> (  # noqa C901
+def _get_embedding_ops_patterns_and_replacements_torchao(node_value_dict) -> (  # noqa C901
     List[Tuple[Callable, Callable, List[Callable]]]
 ):
+    def get_embedding_replacement_filter(has_nonzero_zero_point):
+        def _filter(match, original_graph, pattern_graph):
+            def get_val(name):
+                node = [n for n in match.nodes_map if n.name == name][0]
+                val = match.nodes_map[node]
+                if isinstance(val, torch.fx.Node) and val.target in node_value_dict:
+                    return node_value_dict[val.target]
+                return val
+
+            zero_point = get_val("zero_point")
+            all_zero = (zero_point == 0).all().item()
+            if has_nonzero_zero_point:
+                return not all_zero
+            else:
+                return all_zero
+        return _filter
+
     def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):
         dq = torch.ops.torchao.dequantize_affine.default(
             int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_byte.default(
-            int_data,
-            scale,
-            zero_point_dtype_cast,
-            -128,
-            127,
-            indices,
-        )
+    def get_embedding_byte_replacement(has_nonzero_zero_point):
+        def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_byte.default(
+                int_data,
+                scale,
+                zero_point_dtype_cast,
+                -128,
+                127,
+                indices,
+            )
+        return embedding_byte_replacement
 
     def embedding_byte_dtype_pattern(
         indices, int_data, group_size, scale, zero_point, output_dtype
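Note: get_embedding_replacement_filter closes over node_value_dict so it can look up the concrete tensor behind the matched zero_point node and route each match to either the zero-point-free or the zero-point-carrying replacement. A rough sketch of the (pattern, replacement, [match_filter]) contract these tuples assume, using torch.fx's subgraph rewriter; how the ExecuTorch quant-fusion pass actually applies the tuples may differ, and the toy pattern below is illustrative only:

# Toy illustration of the (pattern, replacement, [match_filter]) contract.
# The filter signature mirrors _filter above: it receives the match plus the
# original and pattern graphs and returns True to accept the match.
import torch
from torch.fx import symbolic_trace
from torch.fx.subgraph_rewriter import replace_pattern_with_filters

def pattern(x, scale):
    return x * scale

def replacement(x, scale):
    return torch.mul(x, scale)

def accept_all(match, original_graph, pattern_graph):
    # The embedding filters instead fetch the matched zero_point value
    # (via node_value_dict) and check (zero_point == 0).all().
    return True

def model(x, scale):
    return x * scale

gm = symbolic_trace(model)
replace_pattern_with_filters(gm, pattern, replacement, match_filters=[accept_all])
print(gm.code)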
@@ -1021,34 +1041,40 @@ def embedding_byte_dtype_pattern(
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_byte_dtype_replacement(
-        indices, int_data, group_size, scale, zero_point, output_dtype
-    ):
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_byte.dtype(
-            int_data,
-            scale,
-            zero_point_dtype_cast,
-            -128,
-            127,
-            indices,
-            dtype=output_dtype,
-        )
+    def get_embedding_byte_dtype_replacement(has_nonzero_zero_point):
+        def embedding_byte_dtype_replacement(
+            indices, int_data, group_size, scale, zero_point, output_dtype
+        ):
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                int_data,
+                scale,
+                zero_point_dtype_cast,
+                -128,
+                127,
+                indices,
+                dtype=output_dtype,
+            )
+        return embedding_byte_dtype_replacement
 
     def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point):
         dq = torch.ops.torchao.dequantize_affine.default(
             int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
-            int_data, 2
-        )
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_2bit.default(
-            packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices
-        )
+    def get_embedding_2bit_replacement(has_nonzero_zero_point):
+        def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
+            packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+                int_data, 2
+            )
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_2bit.default(
+                packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices
+            )
+        return embedding_2bit_replacement
 
     def embedding_2bit_dtype_pattern(
         indices, int_data, group_size, scale, zero_point, output_dtype
@@ -1065,37 +1091,43 @@ def embedding_2bit_dtype_pattern(
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_2bit_dtype_replacement(
-        indices, int_data, group_size, scale, zero_point, output_dtype
-    ):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
-            int_data, 2
-        )
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_2bit.dtype(
-            packed_int_data,
-            scale,
-            zero_point_dtype_cast,
-            -2,
-            1,
-            indices,
-            dtype=output_dtype,
-        )
+    def get_embedding_2bit_dtype_replacement(has_nonzero_zero_point):
+        def embedding_2bit_dtype_replacement(
+            indices, int_data, group_size, scale, zero_point, output_dtype
+        ):
+            packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+                int_data, 2
+            )
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_2bit.dtype(
+                packed_int_data,
+                scale,
+                zero_point_dtype_cast,
+                -2,
+                1,
+                indices,
+                dtype=output_dtype,
+            )
+        return embedding_2bit_dtype_replacement
 
     def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point):
         dq = torch.ops.torchao.dequantize_affine.default(
             int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
-            int_data, 4
-        )
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_4bit.default(
-            packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices
-        )
+    def get_embedding_4bit_replacement(has_nonzero_zero_point):
+        def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
+            packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+                int_data, 4
+            )
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_4bit.default(
+                packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices
+            )
+        return embedding_4bit_replacement
 
     def embedding_4bit_dtype_pattern(
         indices, int_data, group_size, scale, zero_point, output_dtype
@@ -1112,53 +1144,86 @@ def embedding_4bit_dtype_pattern(
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_4bit_dtype_replacement(
-        indices, int_data, group_size, scale, zero_point, output_dtype
-    ):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
-            int_data, 4
-        )
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-            packed_int_data,
-            scale,
-            zero_point_dtype_cast,
-            -8,
-            7,
-            indices,
-            dtype=output_dtype,
-        )
+    def get_embedding_4bit_dtype_replacement(has_nonzero_zero_point):
+        def embedding_4bit_dtype_replacement(
+            indices, int_data, group_size, scale, zero_point, output_dtype
+        ):
+            packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+                int_data, 4
+            )
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                packed_int_data,
+                scale,
+                zero_point_dtype_cast,
+                -8,
+                7,
+                indices,
+                dtype=output_dtype,
+            )
+        return embedding_4bit_dtype_replacement
 
     return [
         (
             _trace_and_lower_to_edge_ops(embedding_byte_pattern),
-            _trace_and_lower_to_edge_ops(embedding_byte_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_byte_replacement(False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_byte_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_byte_replacement(True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
         ),
         (
             _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern),
-            _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_byte_dtype_replacement(False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_byte_dtype_replacement(True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
         (
             _trace_and_lower_to_edge_ops(embedding_2bit_pattern),
-            _trace_and_lower_to_edge_ops(embedding_2bit_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_2bit_replacement(False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_2bit_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_2bit_replacement(True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
         (
             _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),
-            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_2bit_dtype_replacement(False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_2bit_dtype_replacement(True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
         (
             _trace_and_lower_to_edge_ops(embedding_4bit_pattern),
-            _trace_and_lower_to_edge_ops(embedding_4bit_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_4bit_replacement(has_nonzero_zero_point=False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_4bit_replacement(has_nonzero_zero_point=True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
         (
             _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),
-            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_4bit_dtype_replacement(has_nonzero_zero_point=False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_4bit_dtype_replacement(has_nonzero_zero_point=True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
     ]

@@ -1445,7 +1510,7 @@ def replacement(x, x_scale, x_zero_point, x_qmin, x_qmax):
     """
 
 
-def get_quant_patterns_and_replacements() -> (
+def get_quant_patterns_and_replacements(node_value_dict) -> (
     List[Tuple[Callable, Callable, List[Callable]]]
 ):

@@ -1457,6 +1522,6 @@ def get_quant_patterns_and_replacements() -> (
             *_get_slice_patterns_and_replacements(),
             # *_get_fixed_qparams_ops_patterns_and_replacements(),
             *_get_embedding_ops_patterns_and_replacements(),
-            *_get_embedding_ops_patterns_and_replacements_torchao(),
+            *_get_embedding_ops_patterns_and_replacements_torchao(node_value_dict),
         ]
     )
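Note: get_quant_patterns_and_replacements now threads node_value_dict down to the torchao embedding patterns so the match filters can read constant zero_point tensors at fusion time. The commit does not show how the dict is built; below is a hypothetical sketch of one way a caller could populate it (the helper name and approach are assumptions, not part of this change):

# Hypothetical helper (not in this commit): map get_attr targets of a
# GraphModule to their tensors so match filters can inspect constant values
# such as zero_point.
import torch

def build_node_value_dict(gm: torch.fx.GraphModule) -> dict:
    node_value_dict = {}
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            obj = gm
            for atom in str(node.target).split("."):  # resolve dotted paths like "emb.zero_point"
                obj = getattr(obj, atom)
            if isinstance(obj, torch.Tensor):
                node_value_dict[node.target] = obj
    return node_value_dict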
