diff --git a/torchao/quantization/pt2e/convert.py b/torchao/quantization/pt2e/convert.py
index 99516ac4c3..0b74341131 100644
--- a/torchao/quantization/pt2e/convert.py
+++ b/torchao/quantization/pt2e/convert.py
@@ -49,9 +49,7 @@
 )
 from torch.ao.quantization.fx.utils import (
     _get_module,
-    assert_and_get_unique_device,
     collect_producer_nodes,
-    create_getattr_from_value,
     graph_module_from_producer_nodes,
     node_arg_is_weight,
 )
@@ -73,7 +71,11 @@
 
 from torchao.quantization.pt2e import FROM_NODE_KEY
 from torchao.quantization.pt2e.observer import _is_activation_post_process
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
+from torchao.quantization.pt2e.utils import create_getattr_from_value
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
+    _assert_and_get_unique_device,
+)
 
 if TORCH_VERSION_AT_LEAST_2_6:
     from torch.fx.traceback import NodeSource, NodeSourceAction
@@ -132,6 +134,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node working with decomposed Tensor
@@ -260,7 +263,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node):
             # sure that the default overload can be used.
             # TODO: maybe need more complex attr name here
             qparam_node = create_getattr_from_value(
-                model, graph, module_path + prefix + key, value_or_node
+                model,
+                graph,
+                module_path + prefix + key,
+                value_or_node,
+                model_device,
             )
             quantize_op_inputs.append(qparam_node)
         else:
@@ -407,6 +414,7 @@ def _replace_observer_with_quantize_dequantize_node(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node
@@ -487,7 +495,11 @@
             # For scale and zero_point values we register them as buffers in the root module.
             # TODO: maybe need more complex attr name here
             qparam_node = create_getattr_from_value(
-                model, graph, module_path + prefix + key, value_or_node
+                model,
+                graph,
+                module_path + prefix + key,
+                value_or_node,
+                model_device,
             )
             quantize_op_inputs.append(qparam_node)
         else:
@@ -785,6 +797,7 @@ def convert_weighted_module(
     backend_config: BackendConfig,
     is_decomposed: bool = False,
     is_reference: bool = False,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Convert a weighted module to reference quantized module in the model
     If the QConfig of a QAT module is not set, the module will still be converted to
@@ -873,7 +886,10 @@
     is_ptq = weight_post_process is None
     if is_ptq:
         weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
-        device = assert_and_get_unique_device(float_module)
+        if model_device is not None:
+            device = model_device
+        else:
+            device = _assert_and_get_unique_device(float_module)
         if device:
             weight_post_process.to(device)
 
@@ -1076,6 +1092,7 @@ def convert(
     root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
     qat_module_classes = get_qat_module_classes(backend_config)
     fused_module_classes = get_fused_module_classes(backend_config)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in list(model.graph.nodes):
         if node.op == "placeholder":
@@ -1123,6 +1140,7 @@
                         modules,
                         node_name_to_scope,
                         node_name_to_qconfig,
+                        model_device,
                     )
                 else:
                     _replace_observer_with_quantize_dequantize_node(
@@ -1131,6 +1149,7 @@
                         modules,
                         node_name_to_scope,
                         node_name_to_qconfig,
+                        model_device,
                     )
             elif isinstance(mod, DeQuantStub):
                 _replace_observer_or_dequant_stub_with_dequantize_node(
@@ -1160,6 +1179,7 @@
                     backend_config,
                     is_decomposed,
                     is_reference,
+                    model_device,
                 )
 
     # remove deadcode after converting observers to quant/dequant ops
diff --git a/torchao/quantization/pt2e/observer.py b/torchao/quantization/pt2e/observer.py
index 4115040669..f89e3c5b1a 100644
--- a/torchao/quantization/pt2e/observer.py
+++ b/torchao/quantization/pt2e/observer.py
@@ -1915,10 +1915,18 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
         else:
             scale, zero_point = self.calculate_qparams()
             scale_node = create_getattr_from_value(
-                model, model.graph, "_scale", scale
+                model,
+                model.graph,
+                "_scale",
+                scale,
+                scale.device,
             )
             zero_point_node = create_getattr_from_value(
-                model, model.graph, "_zero_point", zero_point
+                model,
+                model.graph,
+                "_zero_point",
+                zero_point,
+                zero_point.device,
             )
 
             q_node = model.graph.call_function(
diff --git a/torchao/quantization/pt2e/prepare.py b/torchao/quantization/pt2e/prepare.py
index d8f5b99fc5..7e3bf06443 100644
--- a/torchao/quantization/pt2e/prepare.py
+++ b/torchao/quantization/pt2e/prepare.py
@@ -38,7 +38,7 @@
     SharedQuantizationSpec,
 )
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, _assert_and_get_unique_device
 
 # TODO: make pt2e folder private?
 __all__ = [
@@ -409,6 +409,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Argument:
     """
     Given a `node` and an `arg`, inserts an input observer between
@@ -427,6 +428,7 @@
                 named_modules,
                 obs_or_fq_map,
                 is_qat,
+                model_device,
             )
             new_arg_to_return.append(new_inner_arg)
         return type(arg)(new_arg_to_return)
@@ -479,6 +481,7 @@
             return maybe_obs_node
 
     assert isinstance(model.graph, Graph)
+    # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
    new_arg = _insert_obs_or_fq(
         arg, input_edge_obs_or_fq, model, named_modules, model.graph
     )
@@ -492,6 +495,7 @@ def _maybe_insert_input_observers_for_node(
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """
     If needed, inserts observers to the input args and kwargs of `node`.
@@ -518,6 +522,7 @@
             named_modules,
             obs_or_fq_map,
             is_qat,
+            model_device,
         )
         new_args.append(new_arg)
 
@@ -542,9 +547,11 @@ def _maybe_insert_output_observer_for_node(
     graph: Graph,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Optional[Node]:
     if node in obs_or_fq_map:
         output_act_obs_or_fq = obs_or_fq_map[node]
+        # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
         new_output = _insert_obs_or_fq(
             node, output_act_obs_or_fq, model, named_modules, graph
         )
@@ -565,6 +572,7 @@ def _maybe_insert_input_and_output_observers_for_node(
     model: torch.fx.GraphModule,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ):
     this_node_quantization_annotation = (
         node.meta[Q_ANNOTATION_KEY] if Q_ANNOTATION_KEY in node.meta else None
@@ -580,6 +588,7 @@
         named_modules,
         obs_or_fq_map,
         is_qat,
+        model_device,
     )
 
     output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
@@ -588,7 +597,13 @@
 
     # this returns the new observer node if it was needed
     maybe_output_obs_node = _maybe_insert_output_observer_for_node(
-        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
+        node,
+        model,
+        named_modules,
+        model.graph,
+        obs_or_fq_map,
+        is_qat,
+        model_device,
     )
 
     if maybe_output_obs_node is None:
@@ -636,11 +651,16 @@ def prepare(
     )
     if obs_or_fq_callback:
         obs_or_fq_callback(model, obs_or_fq_map)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in nodes_before_observation:
         # TODO: simplify logic for inserting observers
        _maybe_insert_input_and_output_observers_for_node(
-            node, model, obs_or_fq_map, is_qat
+            node,
+            model,
+            obs_or_fq_map,
+            is_qat,
+            model_device,
         )
 
     model = GraphModule(model, model.graph)
diff --git a/torchao/quantization/pt2e/utils.py b/torchao/quantization/pt2e/utils.py
index dc5f802fb8..3e1e77e506 100644
--- a/torchao/quantization/pt2e/utils.py
+++ b/torchao/quantization/pt2e/utils.py
@@ -525,7 +525,11 @@ def get_attr_name(i: int):
 
 
 def create_getattr_from_value(
-    module: torch.nn.Module, graph: Graph, prefix: str, value: Any
+    module: torch.nn.Module,
+    graph: Graph,
+    prefix: str,
+    value: Any,
+    device: Optional[torch.device] = None,
 ) -> Node:
     """
     Given a value of any type, creates a getattr node corresponding to the value and
@@ -533,7 +537,8 @@
     """
     get_new_attr_name = get_new_attr_name_with_prefix(prefix)
     attr_name = get_new_attr_name(module)
-    device = _assert_and_get_unique_device(module)
+    if device is None:
+        device = _assert_and_get_unique_device(module)
     new_value = (
         value.detach().clone()
         if isinstance(value, torch.Tensor)
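
Context: the patch computes the model's device once per pass (`model_device = _assert_and_get_unique_device(model)` in `convert()` and `prepare()`) and threads it through the observer-replacement helpers, so `create_getattr_from_value` only falls back to its own per-call device scan when `device is None`. Below is a minimal sketch of the behavior these call sites rely on from `_assert_and_get_unique_device`; it is inferred from how the result is used here, not copied from the shipped implementation in torchao/utils.py, which may differ in details.

    from typing import Optional

    import torch

    def unique_device_sketch(module: torch.nn.Module) -> Optional[torch.device]:
        # Sketch only: gather the devices of every parameter and buffer.
        devices = {p.device for p in module.parameters()} | {
            b.device for b in module.buffers()
        }
        # The pt2e flow assumes a single-device model.
        assert len(devices) <= 1, f"expected at most one device, got {devices}"
        # Returns None for a module with no parameters or buffers, which is
        # why callers guard with `if device:` before calling `.to(device)`.
        return next(iter(devices)) if devices else None

Caching this result matters because the replacement helpers run once per observer node, and each uncached call walks all parameters and buffers of the root module.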