Skip to content

Commit 4a08e16

Browse files
committed
[pt2e] Avoid getting model device once per node
**Summary:** Previously, we called `assert_and_get_unique_device` once per node in convert. This is expensive and unnecessary since the model device is the same across all nodes, so we should just call it once at the beginning and reuse the same model device across all the nodes. This is the torchao version of pytorch/pytorch#159901. **Test Plan:** `python test/quantization/pt2e/test_quantize_pt2e.py`
1 parent 418593c commit 4a08e16

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

torchao/quantization/pt2e/convert.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,7 @@
4949
)
5050
from torch.ao.quantization.fx.utils import (
5151
_get_module,
52-
assert_and_get_unique_device,
5352
collect_producer_nodes,
54-
create_getattr_from_value,
5553
graph_module_from_producer_nodes,
5654
node_arg_is_weight,
5755
)
@@ -73,7 +71,11 @@
7371

7472
from torchao.quantization.pt2e import FROM_NODE_KEY
7573
from torchao.quantization.pt2e.observer import _is_activation_post_process
76-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
74+
from torchao.quantization.pt2e.utils import create_getattr_from_value
75+
from torchao.utils import (
76+
TORCH_VERSION_AT_LEAST_2_6,
77+
_assert_and_get_unique_device,
78+
)
7779

7880
if TORCH_VERSION_AT_LEAST_2_6:
7981
from torch.fx.traceback import NodeSource, NodeSourceAction
@@ -132,6 +134,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
132134
modules: dict[str, torch.nn.Module],
133135
node_name_to_scope: dict[str, tuple[str, type]],
134136
node_name_to_qconfig: dict[str, QConfigAny],
137+
model_device: Optional[torch.device] = None,
135138
) -> None:
136139
"""Replace activation_post_process module call node with quantize and
137140
dequantize node working with decomposed Tensor
@@ -260,7 +263,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node):
260263
# sure that the default overload can be used.
261264
# TODO: maybe need more complex attr name here
262265
qparam_node = create_getattr_from_value(
263-
model, graph, module_path + prefix + key, value_or_node
266+
model,
267+
graph,
268+
module_path + prefix + key,
269+
value_or_node,
270+
model_device,
264271
)
265272
quantize_op_inputs.append(qparam_node)
266273
else:
@@ -407,6 +414,7 @@ def _replace_observer_with_quantize_dequantize_node(
407414
modules: dict[str, torch.nn.Module],
408415
node_name_to_scope: dict[str, tuple[str, type]],
409416
node_name_to_qconfig: dict[str, QConfigAny],
417+
model_device: Optional[torch.device] = None,
410418
) -> None:
411419
"""Replace activation_post_process module call node with quantize and
412420
dequantize node
@@ -487,7 +495,11 @@ def _replace_observer_with_quantize_dequantize_node(
487495
# For scale and zero_point values we register them as buffers in the root module.
488496
# TODO: maybe need more complex attr name here
489497
qparam_node = create_getattr_from_value(
490-
model, graph, module_path + prefix + key, value_or_node
498+
model,
499+
graph,
500+
module_path + prefix + key,
501+
value_or_node,
502+
model_device,
491503
)
492504
quantize_op_inputs.append(qparam_node)
493505
else:
@@ -785,6 +797,7 @@ def convert_weighted_module(
785797
backend_config: BackendConfig,
786798
is_decomposed: bool = False,
787799
is_reference: bool = False,
800+
model_device: Optional[torch.device] = None,
788801
) -> None:
789802
"""Convert a weighted module to reference quantized module in the model
790803
If the QConfig of a QAT module is not set, the module will still be converted to
@@ -873,7 +886,10 @@ def convert_weighted_module(
873886
is_ptq = weight_post_process is None
874887
if is_ptq:
875888
weight_post_process = qconfig.weight() # type: ignore[union-attr, operator]
876-
device = assert_and_get_unique_device(float_module)
889+
if model_device is not None:
890+
device = model_device
891+
else:
892+
device = _assert_and_get_unique_device(float_module)
877893
if device:
878894
weight_post_process.to(device)
879895

@@ -1076,6 +1092,7 @@ def convert(
10761092
root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
10771093
qat_module_classes = get_qat_module_classes(backend_config)
10781094
fused_module_classes = get_fused_module_classes(backend_config)
1095+
model_device = _assert_and_get_unique_device(model)
10791096

10801097
for node in list(model.graph.nodes):
10811098
if node.op == "placeholder":
@@ -1123,6 +1140,7 @@ def convert(
11231140
modules,
11241141
node_name_to_scope,
11251142
node_name_to_qconfig,
1143+
model_device,
11261144
)
11271145
else:
11281146
_replace_observer_with_quantize_dequantize_node(
@@ -1131,6 +1149,7 @@ def convert(
11311149
modules,
11321150
node_name_to_scope,
11331151
node_name_to_qconfig,
1152+
model_device,
11341153
)
11351154
elif isinstance(mod, DeQuantStub):
11361155
_replace_observer_or_dequant_stub_with_dequantize_node(
@@ -1160,6 +1179,7 @@ def convert(
11601179
backend_config,
11611180
is_decomposed,
11621181
is_reference,
1182+
model_device,
11631183
)
11641184

11651185
# remove deadcode after converting observers to quant/dequant ops

torchao/quantization/pt2e/observer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,7 +1885,9 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
18851885
)
18861886

18871887
from torchao.quantization.pt2e.utils import create_getattr_from_value
1888+
from torchao.utils import _assert_and_get_unique_device
18881889

1890+
model_device = _assert_and_get_unique_device(model)
18891891
with model.graph.inserting_before(observer_node):
18901892
assert self.block_size is not None, "Expecting block_size to be populated"
18911893
assert self.original_dtype is not None, (
@@ -1915,10 +1917,18 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
19151917
else:
19161918
scale, zero_point = self.calculate_qparams()
19171919
scale_node = create_getattr_from_value(
1918-
model, model.graph, "_scale", scale
1920+
model,
1921+
model.graph,
1922+
"_scale",
1923+
scale,
1924+
model_device,
19191925
)
19201926
zero_point_node = create_getattr_from_value(
1921-
model, model.graph, "_zero_point", zero_point
1927+
model,
1928+
model.graph,
1929+
"_zero_point",
1930+
zero_point,
1931+
model_device,
19221932
)
19231933

19241934
q_node = model.graph.call_function(

torchao/quantization/pt2e/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,15 +525,20 @@ def get_attr_name(i: int):
525525

526526

527527
def create_getattr_from_value(
528-
module: torch.nn.Module, graph: Graph, prefix: str, value: Any
528+
module: torch.nn.Module,
529+
graph: Graph,
530+
prefix: str,
531+
value: Any,
532+
device: Optional[torch.device] = None,
529533
) -> Node:
530534
"""
531535
Given a value of any type, creates a getattr node corresponding to the value and
532536
registers the value as a buffer to the module.
533537
"""
534538
get_new_attr_name = get_new_attr_name_with_prefix(prefix)
535539
attr_name = get_new_attr_name(module)
536-
device = _assert_and_get_unique_device(module)
540+
if device is None:
541+
device = _assert_and_get_unique_device(module)
537542
new_value = (
538543
value.detach().clone()
539544
if isinstance(value, torch.Tensor)

0 commit comments

Comments
 (0)