Commit 21f5382

New embedding quant fusion
Summary: The diff adds new quant fusion passes that recognize 2-, 4-, and 8-bit quantized embeddings (per-group and per-channel) and fuse them into ExecuTorch kernels. This makes torchao's quantize_ integrate with ExecuTorch:

```
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)
# lower model to executorch
```

For the model to lower, we need to run QuantFusionPass. For sub-byte bit widths, we also need to run constant_prop_pass (see the new unit tests for examples). In follow-up diffs, we will enable these passes by default in to_executorch, before the memory-planning and out-variant passes.

Differential Revision: D73381542
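For a concrete picture of that flow, here is a minimal end-to-end sketch condensed from the new unit test in this diff (test_embedding_torchao); the toy model, the int4/PerGroup(32) config, and the example indices are illustrative choices only:

```python
import torch
from torch.export import export

from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

# Toy model: an embedding followed by a linear layer.
model = torch.nn.Sequential(torch.nn.Embedding(10, 64), torch.nn.Linear(64, 8))
example_inputs = (torch.tensor([1, 2, 3], dtype=torch.int64),)

# Quantize only the embedding weights (4-bit, groups of 32) with torchao.
quantize_(
    model,
    IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
    lambda m, fqn: isinstance(m, torch.nn.Embedding),
)

# Export and lower to the edge dialect.
m = to_edge(
    export(model, example_inputs, strict=True),
    compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True),
)

# Fuse the dequantize_affine + embedding pattern into quantized embedding ops.
m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])

# For sub-byte widths, constant-fold the weight packing op.
constant_prop_pass(m.exported_program())

# Lower to ExecuTorch.
exec_prog = m.to_executorch()
```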
1 parent ad1b154 commit 21f5382

6 files changed: +283 -4 lines changed


exir/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ python_library(
         "//caffe2:torch",
         "//executorch/exir/operator:convert",
         "//executorch/extension/pytree:pylib",
+        "//pytorch/ao:torchao",
     ],
 )

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 171 additions & 2 deletions
@@ -22,6 +22,51 @@
     "get_quant_patterns_and_replacements",
 ]
 
+
+from torch import Tensor
+from torch.library import custom_op
+@custom_op("quant_fusion::_pack_embedding_weight", mutates_args=())
+def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:
+    num_embeddings, embedding_dim = weight.shape
+
+    if bitwidth == 2:
+        assert embedding_dim % 4 == 0, "embedding_dim must be divisible by 4"
+        weight_range_shifted = weight.add(2).view(torch.uint8)
+        weight_view = weight_range_shifted.view(
+            num_embeddings, embedding_dim // 4, 4
+        )
+        weight_0 = weight_view[:, :, 0]
+        weight_1 = weight_view[:, :, 1] << 2
+        weight_2 = weight_view[:, :, 2] << 4
+        weight_3 = weight_view[:, :, 3] << 6
+        packed_weight = weight_0 + weight_1 + weight_2 + weight_3
+        return packed_weight
+    elif bitwidth == 4:
+        assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2"
+        weight_range_shifted = weight.add(8).view(torch.uint8)
+        weight_view = weight_range_shifted.view(
+            weight.shape[0], weight.shape[1] // 2, 2
+        )
+        weight_even = weight_view[:, :, 0] * 16  # left shift 4
+        weight_odd = weight_view[:, :, 1]
+        packed_weight = weight_even + weight_odd
+        return packed_weight
+    elif bitwidth == 8:
+        return weight
+
+    raise RuntimeError(f"Unsupported bitwidth {bitwidth}")
+
+
+# Use register_fake to add a ``FakeTensor`` kernel for the operator
+@_pack_embedding_weight.register_fake
+def _(weight, bit_width):
+    assert bit_width in [2, 4, 8]
+    num_embeddings, embedding_dim = weight.shape
+    values_per_byte = 8 // bit_width
+    assert embedding_dim % values_per_byte == 0
+    return torch.empty(num_embeddings, embedding_dim // values_per_byte, dtype=torch.uint8, device=weight.device)
+
+
 
 # TODO: extending an existing library that is defined in OSS might be a bit
 # confusing, we can investigate if it is possible to define a new library
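As an aside on the packing layout implemented above (not part of the diff): for 4-bit weights, each pair of values in [-8, 7] is shifted into [0, 15] and stored in one uint8 byte, with the even-indexed value in the high nibble. A standalone sketch of the same arithmetic on a toy row:

```python
import torch

# Toy 4-bit quantized weight: one embedding row of dim 4, values in [-8, 7].
weight = torch.tensor([[3, -2, -8, 7]], dtype=torch.int8)

# Shift into [0, 15] and reinterpret as uint8, mirroring weight.add(8).view(torch.uint8).
shifted = weight.add(8).view(torch.uint8).view(1, 2, 2)

# Even-indexed value goes in the high nibble (multiply by 16), odd-indexed value in the low nibble.
packed = shifted[:, :, 0] * 16 + shifted[:, :, 1]
print(packed)  # tensor([[182,  15]], dtype=torch.uint8)
```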

@@ -70,7 +115,7 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
         weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
     ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
     assert (
-        weight_zero_points is None or weight_zero_points.dim() == 1
+        weight_zero_points is None or weight_zero_points.dim() in [1, 2]
     ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
     assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
         0
@@ -233,6 +278,19 @@ def embedding_2bit(
     )
     return torch.ops.aten.embedding.default(weight, indices)
 
+@register_fake("quantized_decomposed::embedding_2bit")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+):
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 4
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices)
 
 @register_fake("quantized_decomposed::embedding_2bit.out")
 def embedding_2bit_out_meta(
@@ -253,7 +311,6 @@ def embedding_2bit_out_meta(
         indices,
     )
 
-
 @impl(quantized_decomposed_lib, "embedding_2bit.dtype", "CompositeExplicitAutograd")
 def embedding_2bit_dtype(
     weight: torch.Tensor,
@@ -295,6 +352,20 @@ def embedding_2bit_dtype(
     )
     return torch.ops.aten.embedding.default(weight, indices)
 
+@register_fake("quantized_decomposed::embedding_2bit.dtype")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 4
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices).to(dtype)
 
 @register_fake("quantized_decomposed::embedding_2bit.dtype_out")
 def embedding_2bit_dtype_out_meta(
@@ -377,6 +448,19 @@ def embedding_4bit(
     )
     return torch.ops.aten.embedding.default(weight, indices)
 
+@register_fake("quantized_decomposed::embedding_4bit")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+):
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 2
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices)
 
 @register_fake("quantized_decomposed::embedding_4bit.out")
 def embedding_4bit_out_meta(
@@ -437,6 +521,20 @@ def embedding_4bit_dtype(
     )
     return torch.ops.aten.embedding.default(weight, indices)
 
+@register_fake("quantized_decomposed::embedding_4bit.dtype")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 2
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices).to(dtype)
 
 @register_fake("quantized_decomposed::embedding_4bit.dtype_out")
 def embedding_4bit_dtype_out_meta(
@@ -872,6 +970,76 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
         )
     ]
 
+def _get_embedding_ops_patterns_and_replacements_torchao() -> List[Tuple[Callable, Callable, List[Callable]]]:
+    def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):
+        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127)
+        return torch.ops.aten.embedding.default(dq, indices)
+    def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_byte.default(
+            int_data,
+            scale,
+            zero_point_dtype_cast,
+            -128,
+            127,
+            indices,
+        )
+    def embedding_byte_dtype_pattern(indices, int_data, group_size, scale, zero_point, output_dtype):
+        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127, 'INT', output_dtype)
+        return torch.ops.aten.embedding.default(dq, indices)
+    def embedding_byte_dtype_replacement(indices, int_data, group_size, scale, zero_point, output_dtype):
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_byte.dtype(
+            int_data,
+            scale,
+            zero_point_dtype_cast,
+            -128,
+            127,
+            indices,
+            dtype=output_dtype
+        )
+
+    def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point):
+        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1)
+        return torch.ops.aten.embedding.default(dq, indices)
+    def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 2)
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_2bit.default(packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices)
+
+    def embedding_2bit_dtype_pattern(indices, int_data, group_size, scale, zero_point, output_dtype):
+        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1, 'INT', output_dtype)
+        return torch.ops.aten.embedding.default(dq, indices)
+    def embedding_2bit_dtype_replacement(indices, int_data, group_size, scale, zero_point, output_dtype):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 2)
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_2bit.dtype(packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices, dtype=output_dtype)
+
+    def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point):
+        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7)
+        return torch.ops.aten.embedding.default(dq, indices)
+    def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 4)
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_4bit.default(packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices)
+
+    def embedding_4bit_dtype_pattern(indices, int_data, group_size, scale, zero_point, output_dtype):
+        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7, 'INT', output_dtype)
+        return torch.ops.aten.embedding.default(dq, indices)
+    def embedding_4bit_dtype_replacement(indices, int_data, group_size, scale, zero_point, output_dtype):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 4)
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_4bit.dtype(packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices, dtype=output_dtype)
+
+    return [
+        (_trace_and_lower_to_edge_ops(embedding_byte_pattern), _trace_and_lower_to_edge_ops(embedding_byte_replacement), []),
+        (_trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern), _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement), []),
+        (_trace_and_lower_to_edge_ops(embedding_2bit_pattern), _trace_and_lower_to_edge_ops(embedding_2bit_replacement), []),
+        (_trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern), _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement), []),
+        (_trace_and_lower_to_edge_ops(embedding_4bit_pattern), _trace_and_lower_to_edge_ops(embedding_4bit_replacement), []),
+        (_trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern), _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement), []),
+    ]
+
 
 def _get_embedding_ops_patterns_and_replacements() -> (
     List[Tuple[Callable, Callable, List[Callable]]]
@@ -1167,5 +1335,6 @@ def get_quant_patterns_and_replacements() -> (
         *_get_slice_patterns_and_replacements(),
         # *_get_fixed_qparams_ops_patterns_and_replacements(),
         *_get_embedding_ops_patterns_and_replacements(),
+        *_get_embedding_ops_patterns_and_replacements_torchao(),
     ]
 )

exir/passes/quant_fusion_pass.py

Lines changed: 12 additions & 1 deletion
@@ -89,7 +89,17 @@ def _get_qparams(node):
     qnode.replace_all_uses_with(maybe_cat)
     model.graph.erase_node(qnode)
 
-
+def _remove_dtype_getattr_nodes(model: GraphModule) -> None:
+    for n in model.graph.nodes:
+        if n.op == "call_function" and n.target == getattr:
+            if isinstance(n.args[0], torch.fx.Node) and n.args[1] == "dtype":
+                dtype = n.args[0].meta["val"].dtype
+                n.replace_all_uses_with(dtype)
+                model.graph.erase_node(n)
+    model.graph.eliminate_dead_code()
+    model.graph.lint()
+    model.recompile()
+
 class QuantFusionPass(ExportPass):
     def __init__(self, _fix_node_meta_val=False):
         super().__init__()

@@ -123,6 +133,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
                 )
                 n.meta["val"] = n.target(*args, **kwargs)
+        _remove_dtype_getattr_nodes(graph_module)
         graph_module.graph.lint()
         graph_module.graph.eliminate_dead_code()
         return PassResult(graph_module, True)

exir/tests/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -298,6 +298,8 @@ python_unittest(
         "//caffe2:torch",
         "//executorch/exir:lib",
        "//executorch/exir/passes:quant_fusion_pass",
+        "//pytorch/ao:torchao",
+        "//executorch/exir/passes:constant_prop_pass",
     ],
 )

exir/tests/test_quant_fusion_pass.py

Lines changed: 78 additions & 1 deletion
@@ -12,6 +12,7 @@
 from executorch import exir
 from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
+from executorch.exir.passes.constant_prop_pass import constant_prop_pass
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from torch.ao.quantization import (  # @manual
     float_qparams_weight_only_qconfig,

@@ -30,7 +31,8 @@
 from torch.nn import functional as F
 
 from torch.testing import FileCheck
-
+from torchao.quantization.quant_api import quantize_, IntxWeightOnlyConfig
+from torchao.quantization.granularity import PerGroup, PerAxis
 
 class TestQuantFusionPass(unittest.TestCase):
     @classmethod

@@ -373,3 +375,78 @@ def forward(self, indices):
         # ).run(
         #     m.dump_graph_module().code
         # )
+
+    def test_embedding_torchao(self):
+        for bit_width, test_dtype_variant, test_per_group in zip([2, 4, 8], [True, False], [True, False]):
+            self._test_embedding_torchao(bit_width, test_dtype_variant, test_per_group)
+
+    def _test_embedding_torchao(self, bit_width: int, test_dtype_variant: bool, test_per_group: bool) -> None:
+        assert bit_width in [2, 4, 8]
+        embedding_suffix = f"{bit_width}bit" if bit_width < 8 else "byte"
+        if test_dtype_variant:
+            embedding_suffix = f"{embedding_suffix}_dtype"
+
+        indices = torch.tensor([1, 2, 3], dtype=torch.int64)
+        model = torch.nn.Sequential(*[torch.nn.Embedding(10, 64), torch.nn.Linear(64, 8)])
+        example_inputs = (indices,)
+
+        # torchao adds a dtype cast to match embeddings original weight type
+        # this does not happen for float32 because it is the default dtype
+        model = model.to(torch.float16) if test_dtype_variant else model
+
+        # quantize the model
+        granularity = PerGroup(32) if test_per_group else PerAxis(0)
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(weight_dtype=getattr(torch, f"int{bit_width}"), granularity=granularity),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding)
+        )
+        expected_outputs = model(*example_inputs)
+
+        compile_config = EdgeCompileConfig(
+            _check_ir_validity=False,
+            _use_edge_ops=True,
+        )
+        m = to_edge(
+            export(model, example_inputs, strict=True), compile_config=compile_config
+        )
+
+        # Before pass, we see torchao dequantize and embedding ops
+        FileCheck().check_count(
+            "executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default", 1, exactly=True
+        ).check_count(
+            "executorch_exir_dialects_edge__ops_aten_embedding_default", 1, exactly=True,
+        ).run(
+            m.exported_program().graph_module.code
+        )
+
+        m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
+
+        # After pass, we see packing op and quantized embedding op, but no torchao dequantize op
+        FileCheck().check_count(
+            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default", 1 if bit_width < 8 else 0, exactly=True
+        ).check_count(
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
+        ).check_not(
+            "executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default"
+        ).run(
+            m.exported_program().graph_module.code
+        )
+
+        constant_prop_pass(m.exported_program())
+
+        # After constant prop, we see quantized embedding op, but no packing op
+        FileCheck().check_count(
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
+        ).check_not(
+            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
+        ).run(
+            m.exported_program().graph_module.code
+        )
+
+        # Compare numerics
+        actual_outputs = m.exported_program().module()(*example_inputs)
+        self.assertTrue(torch.allclose(expected_outputs, actual_outputs))
+
+        # Can lower to executorch
+        exec_prog = m.to_executorch()
