diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py
index 9c454f4339f..3dd612e650e 100644
--- a/backends/cadence/aot/compiler.py
+++ b/backends/cadence/aot/compiler.py
@@ -38,7 +38,7 @@
 )
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import to_edge
+from executorch.exir.program._program import _transform, to_edge
 
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e
@@ -145,22 +145,22 @@ def convert_pt2(
 # fused model, to be able to get reference numerics.
 # If this does not apply, please use quantize_pt2 instead.
 def fuse_pt2(
-    converted_graph_module: torch.fx.GraphModule,
+    converted_program: ExportedProgram,
     quantizer: CadenceQuantizer,
-) -> torch.fx.GraphModule:
+) -> ExportedProgram:
     """
-    Fuse a converted graph module using the given quantizer.
+    Fuse a converted exported program using the given quantizer.
     The quantizer must be the same as the one used to convert the model.
     If you do not expect that behavior, please use quantize_pt2 instead,
     which will instantiate a default quantizer for you if needed.
-    Returns a GraphModule with the fused model.
+    Returns an ExportedProgram with the fused model.
     """
     # Get patterns and apply fusion of dq -> op -> q to qop
     # pyre-ignore[16]: no attribute
     patterns = [q.pattern for q in quantizer.quantizers]
-    QuantFusion(patterns)(converted_graph_module)
+    fused_program = _transform(converted_program, QuantFusion(patterns))
 
-    return converted_graph_module
+    return fused_program
 
 
 # Note: quantizer is not optional here to force the user to supply a quantizer
@@ -210,7 +210,7 @@ def quantize_pt2(
     If calibration data is provided, it will be used to calibrate the model.
     If not, the inputs will be used for calibration instead, which is useful
     for unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the quantized model.
+    Returns an ExportedProgram with the quantized model.
     Note: this function should not be called directly in general. Please
     use quantize_and_export_to_executorch for most needs.
     """
@@ -227,16 +227,15 @@ def quantize_pt2(
         dump_graphs=dump_graphs,
     )
 
-    # Get fused model
-    fused_gm = fuse_pt2(converted_gm, quantizer)
+    # Apply quant fusion to the exported program
+    program = torch.export.export(converted_gm, inputs, strict=True)
+    fused_program = fuse_pt2(program, quantizer)
 
     if dump_graphs:
         logging.info("Graph after quantization and fusion:")
-        logging.info(fused_gm.graph.print_tabular())
+        logging.info(fused_program.graph_module.graph.print_tabular())
 
-    program = torch.export.export(fused_gm, inputs, strict=True)
-
-    return program
+    return fused_program
 
 
 TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py
index 6af7a88fdc2..20719322e82 100644
--- a/backends/cadence/aot/export_example.py
+++ b/backends/cadence/aot/export_example.py
@@ -63,11 +63,10 @@ def export_model(
     # Get reference outputs from converted model
     ref_outputs = converted_model(*example_inputs)
 
-    # Quantize the model (note: quantizer needs to be the same as
-    # the one used in prepare_and_convert_pt2)
-    quantized_model = fuse_pt2(converted_model, quantizer)
+    ep = torch.export.export(converted_model, example_inputs, strict=True)
 
-    ep = torch.export.export(quantized_model, example_inputs, strict=True)
+    # Fuse the quantized patterns on the exported program (note: quantizer needs to be the same as the one used in prepare_and_convert_pt2)
+    ep = fuse_pt2(ep, quantizer)
 
     # Get edge program after Cadence specific passes
     exec_prog: ExecutorchProgramManager = _lower_ep_to_cadence_gen_etrecord(
diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index e2818f725ef..7093ef19c3d 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -33,6 +33,7 @@
 )
 from executorch.backends.cadence.aot.quantizer.utils import (
     check_out_zero_point_is_min_range,
+    copy_node_metadata,
     create_zero_bias_int32,
     find_sequential_partitions_aten,
     get_conv_args,
@@ -159,6 +160,20 @@ def get_args_and_kwargs_layer_norm(
             ),
             {"dtype": torch.float32},
         )
+        if len(inputs_inputs) > 0:
+            if "val" in inputs_inputs[0].meta:
+                fake_mode = inputs_inputs[0].meta["val"].fake_mode
+                if fake_mode is not None:
+                    with fake_mode:
+                        fake_weight = torch.full(
+                            other_inputs[0], 1, dtype=torch.float32
+                        )
+                        weight.meta["val"] = fake_weight
+                else:
+                    weight.meta["val"] = torch.full(
+                        other_inputs[0], 1, dtype=torch.float32
+                    )
+            copy_node_metadata(weight, inputs_inputs[0])
 
     bias = other_inputs[2] if len(other_inputs) > 2 else None
 
@@ -171,6 +186,18 @@ def get_args_and_kwargs_layer_norm(
             ),
             {"dtype": torch.float32},
         )
+        if len(inputs_inputs) > 0:
+            if "val" in inputs_inputs[0].meta:
+                fake_mode = inputs_inputs[0].meta["val"].fake_mode
+                if fake_mode is not None:
+                    with fake_mode:
+                        fake_bias = torch.full(other_inputs[0], 0, dtype=torch.float32)
+                        bias.meta["val"] = fake_bias
+                else:
+                    bias.meta["val"] = torch.full(
+                        other_inputs[0], 0, dtype=torch.float32
+                    )
+            copy_node_metadata(bias, inputs_inputs[0])
 
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + [scale, zero_point])
@@ -346,6 +373,16 @@ def get_args_and_kwargs_softmax(
         ),
         {"dtype": torch.int32},
     )
+    if len(inputs_inputs) > 0:
+        if "val" in inputs_inputs[0].meta:
+            fake_mode = inputs_inputs[0].meta["val"].fake_mode
+            if fake_mode is not None:
+                with fake_mode:
+                    fake_mask = torch.full(mask_shape, 0.0, dtype=torch.int32)
+                    mask_tensor.meta["val"] = fake_mask
+            else:
+                mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
+        copy_node_metadata(mask_tensor, inputs_inputs[0])
     # Make the scale and zero_point tensors
     in_scale = dequants_inputs[0].args[1]
     in_zero_point = dequants_inputs[0].args[2]
@@ -395,10 +432,13 @@ def get_args_and_kwargs_mixed_w8a32_conv(
         torch.ops.aten.permute.default,
         (other_inputs[0], [0, 2, 1]),  # NCL -> NLC
     )
+    copy_node_metadata(transposed_inputs, other_inputs[0])
+
     transposed_weights = graph_module.graph.call_function(
         torch.ops.aten.permute.default,
         (weights_inputs[0], [2, 0, 1]),  # NCL -> LNC
     )
+    copy_node_metadata(transposed_weights, weights_inputs[0])
 
     args = (
         transposed_inputs,
@@ -582,6 +622,26 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     torch.ops.aten.transpose.int,
                     (weights_inputs[0], 0, 1),
                 )
+                if "val" in weights_inputs[0].meta:
+                    original_val = weights_inputs[0].meta["val"]
+                    fake_mode = original_val.fake_mode
+                    if fake_mode is not None:
+                        with fake_mode:
+                            transposed_val = torch.ops.aten.transpose.int(
+                                original_val, 0, 1
+                            )
+                            transposed_weights.meta["val"] = transposed_val
+                    else:
+                        transposed_shape = list(original_val.shape)
+                        transposed_shape[0], transposed_shape[1] = (
+                            transposed_shape[1],
+                            transposed_shape[0],
+                        )
+                        transposed_weights.meta["val"] = torch.zeros(
+                            transposed_shape, dtype=original_val.dtype
+                        )
+                    copy_node_metadata(transposed_weights, weights_inputs[0])
+
                 # Call linear with transposed weight
                 args, kwargs = get_args_and_kwargs_linear(
                     graph_module,
@@ -654,6 +714,19 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
 
         legalize_graph(graph_module)
         graph_module.graph.eliminate_dead_code()
+        nodes_list = list(graph_module.graph.nodes)
+
+        if len(nodes_list) > 0 and nodes_list[-1].op != "output":
+            output_nodes = [n for n in nodes_list if n.op == "output"]
+            output_arg = output_nodes[0].args[0]
+            original_meta = output_nodes[0].meta.copy()
+
+            for out_node in output_nodes:
+                graph_module.graph.erase_node(out_node)
+
+            new_output_node = graph_module.graph.output(output_arg)
+            new_output_node.meta.update(original_meta)
+
         graph_module.recompile()
         return PassResult(graph_module, True)
 
diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py
index 68fc6740cb4..dfc31bfac8c 100644
--- a/backends/cadence/aot/quantizer/utils.py
+++ b/backends/cadence/aot/quantizer/utils.py
@@ -24,6 +24,12 @@
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
 
 
+def copy_node_metadata(dest_node: fx.Node, src_node: fx.Node) -> None:
+    for key in ["nn_module_stack", "stack_trace", "source_fn_stack"]:
+        if key in src_node.meta and src_node.meta[key]:
+            dest_node.meta[key] = src_node.meta[key]
+
+
 def quantize_tensor_multiplier(
     requantize_scale_tensor: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -114,15 +120,45 @@ def create_zero_bias_int32(
     """
    Creates a zero bias tensor with the shape of weight[0]
     """
-    attr_node = getattr(graph_module, weight_node.target)
+    try:
+        attr_node = getattr(graph_module, weight_node.target)
+    except AttributeError:
+        if "val" in weight_node.meta:
+            attr_node = weight_node.meta["val"]
+        else:
+            param_dict = dict(graph_module.named_parameters())
+            if weight_node.target in param_dict:
+                attr_node = param_dict[weight_node.target]
+            else:
+                buffer_dict = dict(graph_module.named_buffers())
+                if weight_node.target in buffer_dict:
+                    attr_node = buffer_dict[weight_node.target]
+                else:
+                    raise AttributeError(
+                        f"Could not find weight tensor for node {weight_node.target}. "
+                        f"Node metadata keys: {list(weight_node.meta.keys())}"
+                    )
+
     weight_shape = list(attr_node.shape)
     bias_shape = weight_shape[0]
-    return graph_module.graph.call_function(
+    new_node = graph_module.graph.call_function(
         torch.ops.aten.full.default,
         ([bias_shape], 0.0),
         {"dtype": torch.int32},
     )
+    if "val" in weight_node.meta:
+        fake_mode = weight_node.meta["val"].fake_mode
+        if fake_mode is not None:
+            with fake_mode:
+                fake_bias = torch.zeros([bias_shape], dtype=torch.int32)
+                new_node.meta["val"] = fake_bias
+        else:
+            new_node.meta["val"] = torch.zeros([bias_shape], dtype=torch.int32)
+    copy_node_metadata(new_node, weight_node)
+
+    return new_node
+
 
 
 def get_bias_qparams(
     obs_or_fqs: List[ObserverOrFakeQuantize],
diff --git a/util/activation_memory_profiler.py b/util/activation_memory_profiler.py
index 80e4fac56e2..caf4dc1380b 100644
--- a/util/activation_memory_profiler.py
+++ b/util/activation_memory_profiler.py
@@ -41,9 +41,10 @@ def _get_module_hierarchy(node: torch.fx.Node) -> str:
     Get the module hierarchy of the given node.
     """
     module_stack = node.meta.get("nn_module_stack")
-    if module_stack is not None:
+    if module_stack is not None and module_stack:
         module_values_list = list(module_stack.values())
-        return module_values_list[-1][0]
+        if module_values_list:
+            return module_values_list[-1][0]
     return ""
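
With this change, fuse_pt2 consumes and returns a torch.export ExportedProgram (fusion is applied through _transform) instead of mutating a GraphModule in place, so callers export first and fuse afterwards, as export_example.py now does. The sketch below illustrates that caller-side flow; it is a minimal example, and the prepare/convert step, the quantizer instance, and the exact import paths are assumptions drawn from the surrounding Cadence AOT code rather than guaranteed by this patch.

import torch
from torch.export.exported_program import ExportedProgram

# Assumed import path; fuse_pt2 is the function modified in backends/cadence/aot/compiler.py above.
from executorch.backends.cadence.aot.compiler import fuse_pt2


def export_and_fuse(
    converted_model: torch.nn.Module,
    example_inputs: tuple,
    quantizer,  # must be the same CadenceQuantizer used to prepare/convert the model
) -> ExportedProgram:
    # Export first: fuse_pt2 now operates on an ExportedProgram, not a GraphModule.
    ep = torch.export.export(converted_model, example_inputs, strict=True)
    # QuantFusion runs as a transform over the exported program; a new ExportedProgram
    # is returned with dq -> op -> q patterns fused into quantized ops.
    return fuse_pt2(ep, quantizer)

Because fusion now happens on an exported program, nodes created by the pass (dummy weights and biases, transposed weights, the rebuilt output node) need "val" fake-tensor metadata and module-stack metadata to keep the ExportedProgram consistent; that is what copy_node_metadata and the meta["val"] handling added in fusion_pass.py and utils.py provide.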