
Commit d39d929

metascroy authored and facebook-github-bot committed
Embedding quant unification (#14622)
Summary: Pull Request resolved: #14622 Differential Revision: D83318725
1 parent 684b5fd commit d39d929

File tree

5 files changed: +209 −118 lines


.ci/scripts/test_llama.sh

Lines changed: 1 addition & 1 deletion

@@ -237,7 +237,7 @@ if [[ "${CUSTOM}" == "ON" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} model.use_sdpa_with_kv_cache=true"
 fi
 if [[ "${QE}" == "ON" ]]; then
-  EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,1024\""
+  EXPORT_ARGS="${EXPORT_ARGS} quantization.embedding_quantize=\"8,768\""
 fi
 if [[ "${MPS}" == "ON" ]]; then
   EXPORT_ARGS="${EXPORT_ARGS} backend.mps.enabled=true model.enable_dynamic_shape=false debug.verbose=true"

examples/models/llama/source_transformation/quantize.py

Lines changed: 15 additions & 22 deletions

@@ -595,19 +595,12 @@ def __init__(
 
     @torch.no_grad()
     def create_quantized_state_dict(self, packed=False) -> Dict:
+        from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_, MappingType
+        from torchao.quantization.granularity import PerGroup
+
         cur_state_dict = self.mod.state_dict()
 
-        if self.bitwidth == 2:
-            range_min = -2
-            range_max = 1
-        elif self.bitwidth == 4:
-            range_min = -8
-            range_max = 7
-        elif self.bitwidth == 8:
-            range_min = -128
-            range_max = 127
-        else:
-            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
+        assert self.bitwidth in [2, 4, 8], f"Unsupported bitwidth {self.bitwidth}"
 
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, nn.Embedding):
@@ -619,18 +612,18 @@ def create_quantized_state_dict(self, packed=False) -> Dict:
                 print(
                     f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                 )
-                weight, scales, _ = dynamically_quantize_per_channel(
-                    (
-                        mod.weight.to(dtype=self.precision)
-                        if self.precision
-                        else mod.weight
-                    ),
-                    range_min,
-                    range_max,
-                    torch.int8,
-                    self.group_size,
-                    scales_dtype=mod.weight.dtype,
+                tmp_model = nn.Embedding(mod.weight.shape[0], mod.weight.shape[1])
+                if self.precision:
+                    tmp_model = tmp_model.to(dtype=self.precision)
+                tmp_model.weight = nn.Parameter(mod.weight)
+                config = IntxWeightOnlyConfig(
+                    weight_dtype=getattr(torch, f"int{self.bitwidth}"),
+                    granularity=PerGroup(self.group_size),
+                    mapping_type=MappingType.SYMMETRIC,
                 )
+                quantize_(tmp_model, config, lambda m, fqn: isinstance(m, nn.Embedding))
+                weight = tmp_model.weight.qdata
+                scales = tmp_model.weight.scale
 
                 if packed:
                     if self.bitwidth == 2:
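For context on the replacement above: instead of the hand-rolled dynamically_quantize_per_channel call, the embedding weight is now quantized through torchao, and the integer data and scales are read back off the quantized tensor. A minimal standalone sketch of that flow, using only the torchao APIs that appear in the diff; the toy sizes are assumptions for illustration.

import torch
import torch.nn as nn
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_, MappingType
from torchao.quantization.granularity import PerGroup

# Toy embedding: 32 rows of width 64, so a group size of 32 divides each row.
emb = nn.Embedding(32, 64)

config = IntxWeightOnlyConfig(
    weight_dtype=torch.int4,             # 2-, 4-, or 8-bit weights, as in the PR
    granularity=PerGroup(32),            # group-wise scales within each row
    mapping_type=MappingType.SYMMETRIC,  # symmetric -> zero_point is all zeros
)
quantize_(emb, config, lambda m, fqn: isinstance(m, nn.Embedding))

# The updated ExecuTorch path reads the quantized payload back out like this.
int_weight = emb.weight.qdata
scales = emb.weight.scale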

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 152 additions & 86 deletions

@@ -986,25 +986,46 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
     ]
 
 
-def _get_embedding_ops_patterns_and_replacements_torchao() -> (  # noqa C901
+def _get_embedding_ops_patterns_and_replacements_torchao(node_value_dict) -> (  # noqa C901
     List[Tuple[Callable, Callable, List[Callable]]]
 ):
+    assert node_value_dict is not None, "node_value_dict cannot be None"
+    def get_embedding_replacement_filter(has_nonzero_zero_point):
+        def _filter(match, original_graph, pattern_graph):
+            def get_val(name):
+                node = [n for n in match.nodes_map if n.name == name][0]
+                val = match.nodes_map[node]
+                if isinstance(val, torch.fx.Node) and val.target in node_value_dict:
+                    return node_value_dict[val.target]
+                return val
+
+            zero_point = get_val("zero_point")
+            all_zero = (zero_point == 0).all().item()
+            if has_nonzero_zero_point:
+                return not all_zero
+            else:
+                return all_zero
+        return _filter
+
     def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):
         dq = torch.ops.torchao.dequantize_affine.default(
             int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_byte.default(
-            int_data,
-            scale,
-            zero_point_dtype_cast,
-            -128,
-            127,
-            indices,
-        )
+    def get_embedding_byte_replacement(has_nonzero_zero_point):
+        def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_byte.default(
+                int_data,
+                scale,
+                zero_point_dtype_cast,
+                -128,
+                127,
+                indices,
+            )
+        return embedding_byte_replacement
 
     def embedding_byte_dtype_pattern(
         indices, int_data, group_size, scale, zero_point, output_dtype
@@ -1021,34 +1042,40 @@ def embedding_byte_dtype_pattern(
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_byte_dtype_replacement(
-        indices, int_data, group_size, scale, zero_point, output_dtype
-    ):
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_byte.dtype(
-            int_data,
-            scale,
-            zero_point_dtype_cast,
-            -128,
-            127,
-            indices,
-            dtype=output_dtype,
-        )
+    def get_embedding_byte_dtype_replacement(has_nonzero_zero_point):
+        def embedding_byte_dtype_replacement(
+            indices, int_data, group_size, scale, zero_point, output_dtype
+        ):
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_byte.dtype(
+                int_data,
+                scale,
+                zero_point_dtype_cast,
+                -128,
+                127,
+                indices,
+                dtype=output_dtype,
+            )
+        return embedding_byte_dtype_replacement
 
     def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point):
         dq = torch.ops.torchao.dequantize_affine.default(
             int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
-            int_data, 2
-        )
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_2bit.default(
-            packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices
-        )
+    def get_embedding_2bit_replacement(has_nonzero_zero_point):
+        def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
+            packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+                int_data, 2
+            )
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_2bit.default(
+                packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices
+            )
+        return embedding_2bit_replacement
 
     def embedding_2bit_dtype_pattern(
         indices, int_data, group_size, scale, zero_point, output_dtype
@@ -1065,37 +1092,43 @@ def embedding_2bit_dtype_pattern(
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_2bit_dtype_replacement(
-        indices, int_data, group_size, scale, zero_point, output_dtype
-    ):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
-            int_data, 2
-        )
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_2bit.dtype(
-            packed_int_data,
-            scale,
-            zero_point_dtype_cast,
-            -2,
-            1,
-            indices,
-            dtype=output_dtype,
-        )
+    def get_embedding_2bit_dtype_replacement(has_nonzero_zero_point):
+        def embedding_2bit_dtype_replacement(
+            indices, int_data, group_size, scale, zero_point, output_dtype
+        ):
+            packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+                int_data, 2
+            )
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_2bit.dtype(
+                packed_int_data,
+                scale,
+                zero_point_dtype_cast,
+                -2,
+                1,
+                indices,
+                dtype=output_dtype,
+            )
+        return embedding_2bit_dtype_replacement
 
     def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point):
         dq = torch.ops.torchao.dequantize_affine.default(
             int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
-            int_data, 4
-        )
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_4bit.default(
-            packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices
-        )
+    def get_embedding_4bit_replacement(has_nonzero_zero_point):
+        def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
+            packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+                int_data, 4
+            )
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_4bit.default(
+                packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices
+            )
+        return embedding_4bit_replacement
 
     def embedding_4bit_dtype_pattern(
         indices, int_data, group_size, scale, zero_point, output_dtype
@@ -1112,53 +1145,86 @@ def embedding_4bit_dtype_pattern(
         )
         return torch.ops.aten.embedding.default(dq, indices)
 
-    def embedding_4bit_dtype_replacement(
-        indices, int_data, group_size, scale, zero_point, output_dtype
-    ):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
-            int_data, 4
-        )
-        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_4bit.dtype(
-            packed_int_data,
-            scale,
-            zero_point_dtype_cast,
-            -8,
-            7,
-            indices,
-            dtype=output_dtype,
-        )
+    def get_embedding_4bit_dtype_replacement(has_nonzero_zero_point):
+        def embedding_4bit_dtype_replacement(
+            indices, int_data, group_size, scale, zero_point, output_dtype
+        ):
+            packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+                int_data, 4
+            )
+            zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+            zero_point_dtype_cast = zero_point_dtype_cast if has_nonzero_zero_point else None
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+                packed_int_data,
+                scale,
+                zero_point_dtype_cast,
+                -8,
+                7,
+                indices,
+                dtype=output_dtype,
+            )
+        return embedding_4bit_dtype_replacement
 
     return [
         (
             _trace_and_lower_to_edge_ops(embedding_byte_pattern),
-            _trace_and_lower_to_edge_ops(embedding_byte_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_byte_replacement(False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_byte_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_byte_replacement(True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
        (
             _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern),
-            _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_byte_dtype_replacement(False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_byte_dtype_replacement(True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
        (
             _trace_and_lower_to_edge_ops(embedding_2bit_pattern),
-            _trace_and_lower_to_edge_ops(embedding_2bit_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_2bit_replacement(False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_2bit_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_2bit_replacement(True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
        (
             _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),
-            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_2bit_dtype_replacement(False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_2bit_dtype_replacement(True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
        (
             _trace_and_lower_to_edge_ops(embedding_4bit_pattern),
-            _trace_and_lower_to_edge_ops(embedding_4bit_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_4bit_replacement(has_nonzero_zero_point=False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_4bit_replacement(has_nonzero_zero_point=True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
        (
             _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),
-            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement),
-            [],
+            _trace_and_lower_to_edge_ops(get_embedding_4bit_dtype_replacement(has_nonzero_zero_point=False)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=False)],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(get_embedding_4bit_dtype_replacement(has_nonzero_zero_point=True)),
+            [get_embedding_replacement_filter(has_nonzero_zero_point=True)],
        ),
     ]
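Each embedding pattern above is now registered twice with complementary filters: when the matched zero_point is all zeros (symmetric quantization) the replacement passes zero_point=None to the fused op, otherwise the cast zero_point is kept, so exactly one of the two registrations fires per match. The (pattern, replacement, filters) triples follow the shape expected by torch.fx's subgraph rewriter; the sketch below shows one way such triples can be applied, which is an assumption about the surrounding fusion pass rather than code from this diff.

import torch
from torch.fx import subgraph_rewriter

# Hedged sketch: rewrite a GraphModule with filtered patterns. A candidate match
# is replaced only if every filter in match_filters returns True for it.
def apply_patterns(gm: torch.fx.GraphModule, patterns_and_replacements):
    for pattern, replacement, match_filters in patterns_and_replacements:
        subgraph_rewriter.replace_pattern_with_filters(
            gm, pattern, replacement, match_filters
        )
    gm.recompile()
    return gm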

@@ -1445,7 +1511,7 @@ def replacement(x, x_scale, x_zero_point, x_qmin, x_qmax):
     """
 
 
-def get_quant_patterns_and_replacements() -> (
+def get_quant_patterns_and_replacements(node_value_dict) -> (
     List[Tuple[Callable, Callable, List[Callable]]]
 ):

@@ -1457,6 +1523,6 @@ def get_quant_patterns_and_replacements() -> (
             *_get_slice_patterns_and_replacements(),
             # *_get_fixed_qparams_ops_patterns_and_replacements(),
             *_get_embedding_ops_patterns_and_replacements(),
-            *_get_embedding_ops_patterns_and_replacements_torchao(),
+            *_get_embedding_ops_patterns_and_replacements_torchao(node_value_dict),
         ]
     )
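Both get_quant_patterns_and_replacements and the torchao embedding helper now take node_value_dict, which the zero_point filter uses to resolve a matched torch.fx.Node to its concrete tensor before testing (zero_point == 0). The diff does not show how callers build that mapping; the helper below is purely a hypothetical illustration of one way to do it, mapping each get_attr target in a graph to the tensor it refers to.

import torch
from torch.fx import GraphModule

# Hypothetical helper (not part of this PR): collect concrete tensors for
# get_attr nodes so a match filter can evaluate e.g. (zero_point == 0).all().
def build_node_value_dict(gm: GraphModule) -> dict:
    node_value_dict = {}
    for node in gm.graph.nodes:
        if node.op == "get_attr":
            value = gm
            for atom in node.target.split("."):  # resolve dotted attribute paths
                value = getattr(value, atom)
            node_value_dict[node.target] = value
    return node_value_dict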
