From ae3264e6aaa3bdbcd5e283cb925e87ab7e2895bb Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 5 Aug 2025 14:22:23 -0700 Subject: [PATCH] [pt2e] Avoid getting model device once per node **Summary:** Previously, we call `assert_and_get_unqiue_device` once per node in both prepare and convert. This is expensive and unnecessary since the model device is the same across all nodes, so we should just call this once in the beginning and reuse the same model device across all the nodes. torchao version of https://github.com/pytorch/pytorch/pull/159901 Note: The prepare path is not completely done yet, since we are blocked on the pytorch PR on being merged. It's different from convert since it still calls utility functions from `torch.ao.quantization.fx`. **Test Plan:** ``` python test/quantization/pt2e/test_quantize_pt2e.py ``` --- torchao/quantization/pt2e/convert.py | 32 ++++++++++++++++++++++----- torchao/quantization/pt2e/observer.py | 12 ++++++++-- torchao/quantization/pt2e/prepare.py | 26 +++++++++++++++++++--- torchao/quantization/pt2e/utils.py | 9 ++++++-- 4 files changed, 66 insertions(+), 13 deletions(-) diff --git a/torchao/quantization/pt2e/convert.py b/torchao/quantization/pt2e/convert.py index 99516ac4c3..0b74341131 100644 --- a/torchao/quantization/pt2e/convert.py +++ b/torchao/quantization/pt2e/convert.py @@ -49,9 +49,7 @@ ) from torch.ao.quantization.fx.utils import ( _get_module, - assert_and_get_unique_device, collect_producer_nodes, - create_getattr_from_value, graph_module_from_producer_nodes, node_arg_is_weight, ) @@ -73,7 +71,11 @@ from torchao.quantization.pt2e import FROM_NODE_KEY from torchao.quantization.pt2e.observer import _is_activation_post_process -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 +from torchao.quantization.pt2e.utils import create_getattr_from_value +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, + _assert_and_get_unique_device, +) if TORCH_VERSION_AT_LEAST_2_6: from torch.fx.traceback import NodeSource, NodeSourceAction @@ -132,6 +134,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed( modules: dict[str, torch.nn.Module], node_name_to_scope: dict[str, tuple[str, type]], node_name_to_qconfig: dict[str, QConfigAny], + model_device: Optional[torch.device] = None, ) -> None: """Replace activation_post_process module call node with quantize and dequantize node working with decomposed Tensor @@ -260,7 +263,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node): # sure that the default overload can be used. # TODO: maybe need more complex attr name here qparam_node = create_getattr_from_value( - model, graph, module_path + prefix + key, value_or_node + model, + graph, + module_path + prefix + key, + value_or_node, + model_device, ) quantize_op_inputs.append(qparam_node) else: @@ -407,6 +414,7 @@ def _replace_observer_with_quantize_dequantize_node( modules: dict[str, torch.nn.Module], node_name_to_scope: dict[str, tuple[str, type]], node_name_to_qconfig: dict[str, QConfigAny], + model_device: Optional[torch.device] = None, ) -> None: """Replace activation_post_process module call node with quantize and dequantize node @@ -487,7 +495,11 @@ def _replace_observer_with_quantize_dequantize_node( # For scale and zero_point values we register them as buffers in the root module. # TODO: maybe need more complex attr name here qparam_node = create_getattr_from_value( - model, graph, module_path + prefix + key, value_or_node + model, + graph, + module_path + prefix + key, + value_or_node, + model_device, ) quantize_op_inputs.append(qparam_node) else: @@ -785,6 +797,7 @@ def convert_weighted_module( backend_config: BackendConfig, is_decomposed: bool = False, is_reference: bool = False, + model_device: Optional[torch.device] = None, ) -> None: """Convert a weighted module to reference quantized module in the model If the QConfig of a QAT module is not set, the module will still be converted to @@ -873,7 +886,10 @@ def convert_weighted_module( is_ptq = weight_post_process is None if is_ptq: weight_post_process = qconfig.weight() # type: ignore[union-attr, operator] - device = assert_and_get_unique_device(float_module) + if model_device is not None: + device = model_device + else: + device = _assert_and_get_unique_device(float_module) if device: weight_post_process.to(device) @@ -1076,6 +1092,7 @@ def convert( root_module_classes = tuple(root_module_to_quantized_reference_module.keys()) qat_module_classes = get_qat_module_classes(backend_config) fused_module_classes = get_fused_module_classes(backend_config) + model_device = _assert_and_get_unique_device(model) for node in list(model.graph.nodes): if node.op == "placeholder": @@ -1123,6 +1140,7 @@ def convert( modules, node_name_to_scope, node_name_to_qconfig, + model_device, ) else: _replace_observer_with_quantize_dequantize_node( @@ -1131,6 +1149,7 @@ def convert( modules, node_name_to_scope, node_name_to_qconfig, + model_device, ) elif isinstance(mod, DeQuantStub): _replace_observer_or_dequant_stub_with_dequantize_node( @@ -1160,6 +1179,7 @@ def convert( backend_config, is_decomposed, is_reference, + model_device, ) # remove deadcode after converting observers to quant/dequant ops diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py index 4115040669..f89e3c5b1a 100644 --- a/torchao/quantization/pt2e/observer.py +++ b/torchao/quantization/pt2e/observer.py @@ -1915,10 +1915,18 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node): else: scale, zero_point = self.calculate_qparams() scale_node = create_getattr_from_value( - model, model.graph, "_scale", scale + model, + model.graph, + "_scale", + scale, + scale.device, ) zero_point_node = create_getattr_from_value( - model, model.graph, "_zero_point", zero_point + model, + model.graph, + "_zero_point", + zero_point, + zero_point.device, ) q_node = model.graph.call_function( diff --git a/torchao/quantization/pt2e/prepare.py b/torchao/quantization/pt2e/prepare.py index d8f5b99fc5..7e3bf06443 100644 --- a/torchao/quantization/pt2e/prepare.py +++ b/torchao/quantization/pt2e/prepare.py @@ -38,7 +38,7 @@ SharedQuantizationSpec, ) from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY -from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, _assert_and_get_unique_device # TODO: make pt2e folder private? __all__ = [ @@ -409,6 +409,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( named_modules: dict[str, torch.nn.Module], obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, + model_device: Optional[torch.device] = None, ) -> Argument: """ Given a `node` and an `arg`, inserts an input observer between @@ -427,6 +428,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( named_modules, obs_or_fq_map, is_qat, + model_device, ) new_arg_to_return.append(new_inner_arg) return type(arg)(new_arg_to_return) @@ -479,6 +481,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( return maybe_obs_node assert isinstance(model.graph, Graph) + # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901 new_arg = _insert_obs_or_fq( arg, input_edge_obs_or_fq, model, named_modules, model.graph ) @@ -492,6 +495,7 @@ def _maybe_insert_input_observers_for_node( named_modules: dict[str, torch.nn.Module], obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, + model_device: Optional[torch.device] = None, ) -> None: """ If needed, inserts observers to the input args and kwargs of `node`. @@ -518,6 +522,7 @@ def _maybe_insert_input_observers_for_node( named_modules, obs_or_fq_map, is_qat, + model_device, ) new_args.append(new_arg) @@ -542,9 +547,11 @@ def _maybe_insert_output_observer_for_node( graph: Graph, obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, + model_device: Optional[torch.device] = None, ) -> Optional[Node]: if node in obs_or_fq_map: output_act_obs_or_fq = obs_or_fq_map[node] + # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901 new_output = _insert_obs_or_fq( node, output_act_obs_or_fq, model, named_modules, graph ) @@ -565,6 +572,7 @@ def _maybe_insert_input_and_output_observers_for_node( model: torch.fx.GraphModule, obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize], is_qat: bool, + model_device: Optional[torch.device] = None, ): this_node_quantization_annotation = ( node.meta[Q_ANNOTATION_KEY] if Q_ANNOTATION_KEY in node.meta else None @@ -580,6 +588,7 @@ def _maybe_insert_input_and_output_observers_for_node( named_modules, obs_or_fq_map, is_qat, + model_device, ) output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor) @@ -588,7 +597,13 @@ def _maybe_insert_input_and_output_observers_for_node( # this returns the new observer node if it was needed maybe_output_obs_node = _maybe_insert_output_observer_for_node( - node, model, named_modules, model.graph, obs_or_fq_map, is_qat + node, + model, + named_modules, + model.graph, + obs_or_fq_map, + is_qat, + model_device, ) if maybe_output_obs_node is None: @@ -636,11 +651,16 @@ def prepare( ) if obs_or_fq_callback: obs_or_fq_callback(model, obs_or_fq_map) + model_device = _assert_and_get_unique_device(model) for node in nodes_before_observation: # TODO: simplify logic for inserting observers _maybe_insert_input_and_output_observers_for_node( - node, model, obs_or_fq_map, is_qat + node, + model, + obs_or_fq_map, + is_qat, + model_device, ) model = GraphModule(model, model.graph) diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py index dc5f802fb8..3e1e77e506 100644 --- a/torchao/quantization/pt2e/utils.py +++ b/torchao/quantization/pt2e/utils.py @@ -525,7 +525,11 @@ def get_attr_name(i: int): def create_getattr_from_value( - module: torch.nn.Module, graph: Graph, prefix: str, value: Any + module: torch.nn.Module, + graph: Graph, + prefix: str, + value: Any, + device: Optional[torch.device] = None, ) -> Node: """ Given a value of any type, creates a getattr node corresponding to the value and @@ -533,7 +537,8 @@ def create_getattr_from_value( """ get_new_attr_name = get_new_attr_name_with_prefix(prefix) attr_name = get_new_attr_name(module) - device = _assert_and_get_unique_device(module) + if device is None: + device = _assert_and_get_unique_device(module) new_value = ( value.detach().clone() if isinstance(value, torch.Tensor)