diff --git a/exir/TARGETS b/exir/TARGETS
index 8539311e27d..9b121cfbb77 100644
--- a/exir/TARGETS
+++ b/exir/TARGETS
@@ -16,6 +16,7 @@ python_library(
         "//caffe2:torch",
         "//executorch/exir/operator:convert",
         "//executorch/extension/pytree:pylib",
+        "//pytorch/ao:torchao",
     ],
 )
 
diff --git a/exir/passes/_quant_patterns_and_replacements.py b/exir/passes/_quant_patterns_and_replacements.py
index 5c2a0541538..2d8880cbde6 100644
--- a/exir/passes/_quant_patterns_and_replacements.py
+++ b/exir/passes/_quant_patterns_and_replacements.py
@@ -22,6 +22,56 @@
     "get_quant_patterns_and_replacements",
 ]
 
+
+from torch import Tensor
+from torch.library import custom_op
+
+
+@custom_op("quant_fusion::_pack_embedding_weight", mutates_args=())
+def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:
+    num_embeddings, embedding_dim = weight.shape
+
+    if bitwidth == 2:
+        assert embedding_dim % 4 == 0, "embedding_dim must be divisible by 4"
+        weight_range_shifted = weight.add(2).view(torch.uint8)
+        weight_view = weight_range_shifted.view(num_embeddings, embedding_dim // 4, 4)
+        weight_0 = weight_view[:, :, 0]
+        weight_1 = weight_view[:, :, 1] << 2
+        weight_2 = weight_view[:, :, 2] << 4
+        weight_3 = weight_view[:, :, 3] << 6
+        packed_weight = weight_0 | weight_1 | weight_2 | weight_3
+        return packed_weight
+    elif bitwidth == 4:
+        assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2"
+        weight_range_shifted = weight.add(8).view(torch.uint8)
+        weight_view = weight_range_shifted.view(
+            weight.shape[0], weight.shape[1] // 2, 2
+        )
+        weight_even = weight_view[:, :, 0] << 4
+        weight_odd = weight_view[:, :, 1]
+        packed_weight = weight_even | weight_odd
+        return packed_weight
+    elif bitwidth == 8:
+        return weight
+
+    raise RuntimeError(f"Unsupported bitwidth {bitwidth}")
+
+
+# Use register_fake to add a ``FakeTensor`` kernel for the operator
+@_pack_embedding_weight.register_fake
+def _(weight, bit_width):
+    assert bit_width in [2, 4, 8]
+    num_embeddings, embedding_dim = weight.shape
+    values_per_byte = 8 // bit_width
+    assert embedding_dim % values_per_byte == 0
+    return torch.empty(
+        num_embeddings,
+        embedding_dim // values_per_byte,
+        dtype=torch.uint8,
+        device=weight.device,
+    )
+
+
 # TODO: extending an existing library that is defined in OSS might be a bit
 # confusing, we can investigate if it is possible to define a new library
@@ -69,9 +119,10 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
     assert (
         weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
     ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
-    assert (
-        weight_zero_points is None or weight_zero_points.dim() == 1
-    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
+    assert weight_zero_points is None or weight_zero_points.dim() in [
+        1,
+        2,
+    ], f"Expecting weight_zero_points tensor to be None or have dim() in [1, 2], but found {weight_zero_points.dim()}"
     assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
         0
     ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
@@ -234,6 +285,21 @@ def embedding_2bit(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
+@register_fake("quantized_decomposed::embedding_2bit")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+):
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 4
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices)
+
+
 @register_fake("quantized_decomposed::embedding_2bit.out")
 def embedding_2bit_out_meta(
     weight: torch.Tensor,
@@ -296,6 +362,22 @@ def embedding_2bit_dtype(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
+@register_fake("quantized_decomposed::embedding_2bit.dtype")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 4
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices).to(dtype)
+
+
 @register_fake("quantized_decomposed::embedding_2bit.dtype_out")
 def embedding_2bit_dtype_out_meta(
     weight: torch.Tensor,
@@ -378,6 +460,21 @@ def embedding_4bit(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
+@register_fake("quantized_decomposed::embedding_4bit")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+):
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 2
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices)
+
+
 @register_fake("quantized_decomposed::embedding_4bit.out")
 def embedding_4bit_out_meta(
     weight: torch.Tensor,
@@ -438,6 +535,22 @@ def embedding_4bit_dtype(
     return torch.ops.aten.embedding.default(weight, indices)
 
 
+@register_fake("quantized_decomposed::embedding_4bit.dtype")
+def _(
+    weight: torch.Tensor,
+    weight_scales: torch.Tensor,
+    weight_zero_points: Optional[torch.Tensor],
+    weight_quant_min: int,
+    weight_quant_max: int,
+    indices: torch.Tensor,
+    dtype: Optional[torch.dtype],
+) -> torch.Tensor:
+    num_embeddings, packed_embedding_dim = weight.shape
+    embedding_dim = packed_embedding_dim * 2
+    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
+    return embedding(indices).to(dtype)
+
+
 @register_fake("quantized_decomposed::embedding_4bit.dtype_out")
 def embedding_4bit_dtype_out_meta(
     weight: torch.Tensor,
@@ -873,6 +986,186 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
     ]
 
 
+def _get_embedding_ops_patterns_and_replacements_torchao() -> ( # noqa C901
+    List[Tuple[Callable, Callable, List[Callable]]]
+):
+    def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127
+        )
+        return torch.ops.aten.embedding.default(dq, indices)
+
+    def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
+        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
+        return torch.ops.quantized_decomposed.embedding_byte.default(
+            int_data,
+            scale,
+            zero_point_dtype_cast,
+            -128,
+            127,
+            indices,
+        )
+
+    def embedding_byte_dtype_pattern(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data,
+            [1, group_size],
+            scale,
+            zero_point,
+            torch.int8,
+            -128,
+            127,
"INT", + output_dtype, + ) + return torch.ops.aten.embedding.default(dq, indices) + + def embedding_byte_dtype_replacement( + indices, int_data, group_size, scale, zero_point, output_dtype + ): + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + return torch.ops.quantized_decomposed.embedding_byte.dtype( + int_data, + scale, + zero_point_dtype_cast, + -128, + 127, + indices, + dtype=output_dtype, + ) + + def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point): + dq = torch.ops.torchao.dequantize_affine.default( + int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1 + ) + return torch.ops.aten.embedding.default(dq, indices) + + def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point): + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( + int_data, 2 + ) + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + return torch.ops.quantized_decomposed.embedding_2bit.default( + packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices + ) + + def embedding_2bit_dtype_pattern( + indices, int_data, group_size, scale, zero_point, output_dtype + ): + dq = torch.ops.torchao.dequantize_affine.default( + int_data, + [1, group_size], + scale, + zero_point, + torch.int8, + -2, + 1, + "INT", + output_dtype, + ) + return torch.ops.aten.embedding.default(dq, indices) + + def embedding_2bit_dtype_replacement( + indices, int_data, group_size, scale, zero_point, output_dtype + ): + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( + int_data, 2 + ) + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + return torch.ops.quantized_decomposed.embedding_2bit.dtype( + packed_int_data, + scale, + zero_point_dtype_cast, + -2, + 1, + indices, + dtype=output_dtype, + ) + + def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point): + dq = torch.ops.torchao.dequantize_affine.default( + int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7 + ) + return torch.ops.aten.embedding.default(dq, indices) + + def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point): + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( + int_data, 4 + ) + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + return torch.ops.quantized_decomposed.embedding_4bit.default( + packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices + ) + + def embedding_4bit_dtype_pattern( + indices, int_data, group_size, scale, zero_point, output_dtype + ): + dq = torch.ops.torchao.dequantize_affine.default( + int_data, + [1, group_size], + scale, + zero_point, + torch.int8, + -8, + 7, + "INT", + output_dtype, + ) + return torch.ops.aten.embedding.default(dq, indices) + + def embedding_4bit_dtype_replacement( + indices, int_data, group_size, scale, zero_point, output_dtype + ): + packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default( + int_data, 4 + ) + zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype) + return torch.ops.quantized_decomposed.embedding_4bit.dtype( + packed_int_data, + scale, + zero_point_dtype_cast, + -8, + 7, + indices, + dtype=output_dtype, + ) + + return [ + ( + _trace_and_lower_to_edge_ops(embedding_byte_pattern), + _trace_and_lower_to_edge_ops(embedding_byte_replacement), + [], + ), + ( + _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern), + _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement), + [], + ), + ( + 
+            _trace_and_lower_to_edge_ops(embedding_2bit_pattern),
+            _trace_and_lower_to_edge_ops(embedding_2bit_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_pattern),
+            _trace_and_lower_to_edge_ops(embedding_4bit_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement),
+            [],
+        ),
+    ]
+
+
 def _get_embedding_ops_patterns_and_replacements() -> (
     List[Tuple[Callable, Callable, List[Callable]]]
 ):
@@ -1167,5 +1460,6 @@ def get_quant_patterns_and_replacements() -> (
             *_get_slice_patterns_and_replacements(),
             # *_get_fixed_qparams_ops_patterns_and_replacements(),
             *_get_embedding_ops_patterns_and_replacements(),
+            *_get_embedding_ops_patterns_and_replacements_torchao(),
         ]
     )
diff --git a/exir/passes/quant_fusion_pass.py b/exir/passes/quant_fusion_pass.py
index 2b4dd22ef40..0509c5e9820 100644
--- a/exir/passes/quant_fusion_pass.py
+++ b/exir/passes/quant_fusion_pass.py
@@ -90,6 +90,18 @@ def _get_qparams(node):
         model.graph.erase_node(qnode)
 
 
+def _remove_dtype_getattr_nodes(model: GraphModule) -> None:
+    for n in model.graph.nodes:
+        if n.op == "call_function" and n.target == getattr:
+            if isinstance(n.args[0], torch.fx.Node) and n.args[1] == "dtype":
+                dtype = n.args[0].meta["val"].dtype
+                n.replace_all_uses_with(dtype)
+                model.graph.erase_node(n)
+    model.graph.eliminate_dead_code()
+    model.graph.lint()
+    model.recompile()
+
+
 class QuantFusionPass(ExportPass):
     def __init__(self, _fix_node_meta_val=False):
         super().__init__()
@@ -123,6 +135,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     torch.fx.Node, lambda x: x.meta["val"], (n.args, n.kwargs)
                 )
                 n.meta["val"] = n.target(*args, **kwargs)
+        _remove_dtype_getattr_nodes(graph_module)
         graph_module.graph.lint()
         graph_module.graph.eliminate_dead_code()
         return PassResult(graph_module, True)
diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS
index 63f76656a03..a8cdc31555a 100644
--- a/exir/tests/TARGETS
+++ b/exir/tests/TARGETS
@@ -298,6 +298,8 @@ python_unittest(
         "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/passes:quant_fusion_pass",
+        "//pytorch/ao:torchao",
+        "//executorch/exir/passes:constant_prop_pass",
     ],
 )
 
diff --git a/exir/tests/test_quant_fusion_pass.py b/exir/tests/test_quant_fusion_pass.py
index a339ad97811..4697c5f2dc0 100644
--- a/exir/tests/test_quant_fusion_pass.py
+++ b/exir/tests/test_quant_fusion_pass.py
@@ -11,6 +11,7 @@
 import torch
 from executorch import exir
 from executorch.exir import EdgeCompileConfig, to_edge
+from executorch.exir.passes.constant_prop_pass import constant_prop_pass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from torch.ao.quantization import (  # @manual
@@ -30,6 +31,8 @@
 from torch.nn import functional as F
 from torch.testing import FileCheck
 
+from torchao.quantization.granularity import PerAxis, PerGroup
+from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
 
 
 class TestQuantFusionPass(unittest.TestCase):
@@ -373,3 +376,96 @@ def forward(self, indices):
         # ).run(
         #     m.dump_graph_module().code
         # )
+
+    def test_embedding_torchao(self) -> None:
+        for bit_width in [2, 4, 8]:
+            for use_dtype_variant in [True, False]:
+                for test_per_group in [True, False]:
+                    self._test_embedding_torchao(
+                        bit_width, use_dtype_variant, test_per_group
+                    )
+
+    def _test_embedding_torchao(
+        self, bit_width: int, use_dtype_variant: bool, test_per_group: bool
+    ) -> None:
+        assert bit_width in [2, 4, 8]
+        embedding_suffix = f"{bit_width}bit" if bit_width < 8 else "byte"
+        if use_dtype_variant:
+            embedding_suffix = f"{embedding_suffix}_dtype"
+
+        indices = torch.tensor([1, 2, 3], dtype=torch.int64)
+        model = torch.nn.Sequential(
+            *[torch.nn.Embedding(10, 64), torch.nn.Linear(64, 8)]
+        )
+        example_inputs = (indices,)
+
+        # torchao adds a dtype cast to match the embedding's original weight type;
+        # this does not happen for float32 because it is the default dtype
+        model = model.to(torch.float16) if use_dtype_variant else model
+
+        # quantize the model
+        granularity = PerGroup(32) if test_per_group else PerAxis(0)
+        quantize_(
+            model,
+            IntxWeightOnlyConfig(
+                weight_dtype=getattr(torch, f"int{bit_width}"), granularity=granularity
+            ),
+            lambda m, fqn: isinstance(m, torch.nn.Embedding),
+        )
+        expected_outputs = model(*example_inputs)
+
+        compile_config = EdgeCompileConfig(
+            _check_ir_validity=False,
+            _use_edge_ops=True,
+        )
+        m = to_edge(
+            export(model, example_inputs, strict=True), compile_config=compile_config
+        )
+
+        # Before pass, we see torchao dequantize and embedding ops
+        FileCheck().check_count(
+            "executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default",
+            1,
+            exactly=True,
+        ).check_count(
+            "executorch_exir_dialects_edge__ops_aten_embedding_default",
+            1,
+            exactly=True,
+        ).run(
+            m.exported_program().graph_module.code
+        )
+
+        m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])
+
+        # After pass, we see packing op and quantized embedding op, but no torchao dequantize op
+        FileCheck().check_count(
+            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
+            1 if bit_width < 8 else 0,
+            exactly=True,
+        ).check_count(
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
+            1,
+            exactly=True,
+        ).check_not(
+            "executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default"
+        ).run(
+            m.exported_program().graph_module.code
+        )
+
+        constant_prop_pass(m.exported_program())
+
+        # After constant prop, we see quantized embedding op, but no packing op
+        FileCheck().check_count(
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
+            1,
+            exactly=True,
+        ).check_not(
+            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
+        ).run(
+            m.exported_program().graph_module.code
+        )
+
+        # Compare numerics
+        actual_outputs = m.exported_program().module()(*example_inputs)
+        self.assertTrue(torch.allclose(expected_outputs, actual_outputs))
+
+        # Can lower to executorch
+        exec_prog = m.to_executorch()  # noqa: F841
diff --git a/exir/tracer.py b/exir/tracer.py
index c749df510ad..c80d9368ccd 100644
--- a/exir/tracer.py
+++ b/exir/tracer.py
@@ -629,8 +629,7 @@ def _default_decomposition_table(
             torch.ops.aten.arange.start,
             torch.ops.aten.transpose,
         ]
-        # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e...
-        return get_decompositions(decomp_opset)
+        return get_decompositions(decomp_opset)  # pyre-fixme[7]
 
     decomps = default_decompositions()
     # Add edge specific decompositions
@@ -642,7 +641,27 @@ def _default_decomposition_table(
     additional_decomps = get_decompositions(additional_decomp_ops)
     decomps.update(additional_decomps)
     # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
-    return decomps
+
+    never_decompose = []
+    try:
+        # Do not decompose torchao quant primitives
+        # They have decompositions registered for inductor/CUDA, but in ExecuTorch we
+        # just pattern match them and lower to delegates
+        import torchao  # noqa: F401
+
+        never_decompose.extend(
+            [
+                torch.ops.torchao.quantize_affine.default,
+                torch.ops.torchao.dequantize_affine.default,
+                torch.ops.torchao.choose_qparams_affine.default,
+            ]
+        )
+    except:
+        pass
+
+    for op in never_decompose:
+        decomps.pop(op, None)
+    return decomps  # pyre-fixme[7]
 
 
 def dynamo_trace(