
Commit 4040192

[pt2e] Avoid getting model device once per node
**Summary:** Previously, we called `assert_and_get_unique_device` once per node in both prepare and convert. This is expensive and unnecessary: the model device is the same across all nodes, so we should call it once at the beginning and reuse the same model device for all nodes. This is the torchao version of pytorch/pytorch#159901.

Note: The prepare path is not completely done yet, since we are blocked on the PyTorch PR being merged. It differs from convert in that it still calls utility functions from `torch.ao.quantization.fx`.

**Test Plan:**
```
python test/quantization/pt2e/test_quantize_pt2e.py
```
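In sketch form, the fix hoists the device lookup out of the per-node loop. `get_unique_device` below is a simplified stand-in for torchao's `_assert_and_get_unique_device`, which scans every parameter and buffer and asserts they agree on a device:

```python
from typing import Optional
import torch

def get_unique_device(model: torch.nn.Module) -> Optional[torch.device]:
    # Simplified stand-in: collect the device of every parameter and
    # buffer, and require that the model lives on at most one device.
    devices = {p.device for p in model.parameters()} | {
        b.device for b in model.buffers()
    }
    assert len(devices) <= 1, f"expected at most one device, got {devices}"
    return next(iter(devices)) if devices else None

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
nodes = ["n0", "n1", "n2"]  # stand-in for FX graph nodes

# Before: one full parameter/buffer scan per node.
for node in nodes:
    device = get_unique_device(model)

# After: the device is the same for every node, so look it up once.
model_device = get_unique_device(model)
for node in nodes:
    device = model_device  # reused; no per-node rescan
```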
1 parent 418593c commit 4040192

4 files changed: +68 -13 lines

torchao/quantization/pt2e/convert.py

Lines changed: 26 additions & 6 deletions
```diff
@@ -49,9 +49,7 @@
 )
 from torch.ao.quantization.fx.utils import (
     _get_module,
-    assert_and_get_unique_device,
     collect_producer_nodes,
-    create_getattr_from_value,
     graph_module_from_producer_nodes,
     node_arg_is_weight,
 )
@@ -73,7 +71,11 @@
 
 from torchao.quantization.pt2e import FROM_NODE_KEY
 from torchao.quantization.pt2e.observer import _is_activation_post_process
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
+from torchao.quantization.pt2e.utils import create_getattr_from_value
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_6,
+    _assert_and_get_unique_device,
+)
 
 if TORCH_VERSION_AT_LEAST_2_6:
     from torch.fx.traceback import NodeSource, NodeSourceAction
@@ -132,6 +134,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node working with decomposed Tensor
@@ -260,7 +263,11 @@ def add_quantize_dequantize_node_info(qdq_node, original_node):
             # sure that the default overload can be used.
             # TODO: maybe need more complex attr name here
             qparam_node = create_getattr_from_value(
-                model, graph, module_path + prefix + key, value_or_node
+                model,
+                graph,
+                module_path + prefix + key,
+                value_or_node,
+                model_device,
             )
             quantize_op_inputs.append(qparam_node)
         else:
@@ -407,6 +414,7 @@ def _replace_observer_with_quantize_dequantize_node(
     modules: dict[str, torch.nn.Module],
     node_name_to_scope: dict[str, tuple[str, type]],
     node_name_to_qconfig: dict[str, QConfigAny],
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Replace activation_post_process module call node with quantize and
     dequantize node
@@ -487,7 +495,11 @@ def _replace_observer_with_quantize_dequantize_node(
             # For scale and zero_point values we register them as buffers in the root module.
             # TODO: maybe need more complex attr name here
             qparam_node = create_getattr_from_value(
-                model, graph, module_path + prefix + key, value_or_node
+                model,
+                graph,
+                module_path + prefix + key,
+                value_or_node,
+                model_device,
             )
             quantize_op_inputs.append(qparam_node)
         else:
@@ -785,6 +797,7 @@ def convert_weighted_module(
     backend_config: BackendConfig,
     is_decomposed: bool = False,
     is_reference: bool = False,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """Convert a weighted module to reference quantized module in the model
     If the QConfig of a QAT module is not set, the module will still be converted to
@@ -873,7 +886,10 @@
     is_ptq = weight_post_process is None
     if is_ptq:
         weight_post_process = qconfig.weight()  # type: ignore[union-attr, operator]
-        device = assert_and_get_unique_device(float_module)
+        if model_device is not None:
+            device = model_device
+        else:
+            device = _assert_and_get_unique_device(float_module)
         if device:
             weight_post_process.to(device)
 
@@ -1076,6 +1092,7 @@ def convert(
     root_module_classes = tuple(root_module_to_quantized_reference_module.keys())
     qat_module_classes = get_qat_module_classes(backend_config)
     fused_module_classes = get_fused_module_classes(backend_config)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in list(model.graph.nodes):
         if node.op == "placeholder":
@@ -1123,6 +1140,7 @@
                     modules,
                     node_name_to_scope,
                     node_name_to_qconfig,
+                    model_device,
                 )
             else:
                 _replace_observer_with_quantize_dequantize_node(
@@ -1131,6 +1149,7 @@
                     modules,
                     node_name_to_scope,
                     node_name_to_qconfig,
+                    model_device,
                 )
         elif isinstance(mod, DeQuantStub):
             _replace_observer_or_dequant_stub_with_dequantize_node(
@@ -1160,6 +1179,7 @@
                 backend_config,
                 is_decomposed,
                 is_reference,
+                model_device,
             )
 
     # remove deadcode after converting observers to quant/dequant ops
```
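The saving scales with model size times node count; `convert_weighted_module` keeps a per-module fallback for callers that pass no device, but the common case replaces one scan per node with a single scan. A rough illustration of the removed cost (hypothetical micro-benchmark, not from the commit):

```python
import time
import torch

# Stand-in for the device scan _assert_and_get_unique_device performs.
def scan_devices(m: torch.nn.Module) -> set:
    return {p.device for p in m.parameters()} | {b.device for b in m.buffers()}

model = torch.nn.Sequential(*[torch.nn.Linear(256, 256) for _ in range(100)])
num_nodes = 1_000  # stand-in for the number of FX graph nodes visited

start = time.perf_counter()
for _ in range(num_nodes):  # old behavior: one scan per node
    scan_devices(model)
per_node_total = time.perf_counter() - start

start = time.perf_counter()
scan_devices(model)  # new behavior: one scan for the whole pass
single_scan = time.perf_counter() - start

print(f"per-node scans: {per_node_total:.3f}s total; single scan: {single_scan:.6f}s")
```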

torchao/quantization/pt2e/observer.py

Lines changed: 12 additions & 2 deletions
```diff
@@ -1885,7 +1885,9 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
         )
 
         from torchao.quantization.pt2e.utils import create_getattr_from_value
+        from torchao.utils import _assert_and_get_unique_device
 
+        model_device = _assert_and_get_unique_device(model)
         with model.graph.inserting_before(observer_node):
             assert self.block_size is not None, "Expecting block_size to be populated"
             assert self.original_dtype is not None, (
@@ -1915,10 +1917,18 @@ def convert(self, model: torch.fx.GraphModule, observer_node: Node):
             else:
                 scale, zero_point = self.calculate_qparams()
             scale_node = create_getattr_from_value(
-                model, model.graph, "_scale", scale
+                model,
+                model.graph,
+                "_scale",
+                scale,
+                model_device,
             )
             zero_point_node = create_getattr_from_value(
-                model, model.graph, "_zero_point", zero_point
+                model,
+                model.graph,
+                "_zero_point",
+                zero_point,
+                model_device,
             )
 
             q_node = model.graph.call_function(
```

torchao/quantization/pt2e/prepare.py

Lines changed: 23 additions & 3 deletions
```diff
@@ -38,7 +38,7 @@
     SharedQuantizationSpec,
 )
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, _assert_and_get_unique_device
 
 # TODO: make pt2e folder private?
 __all__ = [
@@ -409,6 +409,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Argument:
     """
     Given a `node` and an `arg`, inserts an input observer between
@@ -427,6 +428,7 @@
             named_modules,
             obs_or_fq_map,
             is_qat,
+            model_device,
         )
         new_arg_to_return.append(new_inner_arg)
     return type(arg)(new_arg_to_return)
@@ -479,6 +481,7 @@
         return maybe_obs_node
 
     assert isinstance(model.graph, Graph)
+    # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
     new_arg = _insert_obs_or_fq(
         arg, input_edge_obs_or_fq, model, named_modules, model.graph
     )
@@ -492,6 +495,7 @@ def _maybe_insert_input_observers_for_node(
     named_modules: dict[str, torch.nn.Module],
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> None:
     """
     If needed, inserts observers to the input args and kwargs of `node`.
@@ -518,6 +522,7 @@
             named_modules,
             obs_or_fq_map,
             is_qat,
+            model_device,
         )
         new_args.append(new_arg)
 
@@ -542,9 +547,11 @@ def _maybe_insert_output_observer_for_node(
     graph: Graph,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ) -> Optional[Node]:
     if node in obs_or_fq_map:
         output_act_obs_or_fq = obs_or_fq_map[node]
+        # TODO: pass in model_device here after https://github.com/pytorch/pytorch/pull/159901
         new_output = _insert_obs_or_fq(
             node, output_act_obs_or_fq, model, named_modules, graph
         )
@@ -565,6 +572,7 @@ def _maybe_insert_input_and_output_observers_for_node(
     model: torch.fx.GraphModule,
     obs_or_fq_map: dict[EdgeOrNode, ObserverOrFakeQuantize],
     is_qat: bool,
+    model_device: Optional[torch.device] = None,
 ):
     this_node_quantization_annotation = (
         node.meta[Q_ANNOTATION_KEY] if Q_ANNOTATION_KEY in node.meta else None
@@ -580,6 +588,7 @@
         named_modules,
         obs_or_fq_map,
         is_qat,
+        model_device,
     )
 
     output_is_a_tensor = "val" in node.meta and isinstance(node.meta["val"], FakeTensor)
@@ -588,7 +597,13 @@
 
     # this returns the new observer node if it was needed
     maybe_output_obs_node = _maybe_insert_output_observer_for_node(
-        node, model, named_modules, model.graph, obs_or_fq_map, is_qat
+        node,
+        model,
+        named_modules,
+        model.graph,
+        obs_or_fq_map,
+        is_qat,
+        model_device,
     )
 
     if maybe_output_obs_node is None:
@@ -636,11 +651,16 @@ def prepare(
     )
     if obs_or_fq_callback:
         obs_or_fq_callback(model, obs_or_fq_map)
+    model_device = _assert_and_get_unique_device(model)
 
     for node in nodes_before_observation:
         # TODO: simplify logic for inserting observers
         _maybe_insert_input_and_output_observers_for_node(
-            node, model, obs_or_fq_map, is_qat
+            node,
+            model,
+            obs_or_fq_map,
+            is_qat,
+            model_device,
         )
 
     model = GraphModule(model, model.graph)
```

torchao/quantization/pt2e/utils.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -525,15 +525,20 @@ def get_attr_name(i: int):
 
 
 def create_getattr_from_value(
-    module: torch.nn.Module, graph: Graph, prefix: str, value: Any
+    module: torch.nn.Module,
+    graph: Graph,
+    prefix: str,
+    value: Any,
+    device: Optional[torch.device] = None,
 ) -> Node:
     """
     Given a value of any type, creates a getattr node corresponding to the value and
     registers the value as a buffer to the module.
     """
     get_new_attr_name = get_new_attr_name_with_prefix(prefix)
     attr_name = get_new_attr_name(module)
-    device = _assert_and_get_unique_device(module)
+    if device is None:
+        device = _assert_and_get_unique_device(module)
     new_value = (
         value.detach().clone()
         if isinstance(value, torch.Tensor)
```
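With the new trailing parameter, callers that already know the model device skip the per-call scan, while existing call sites keep the old behavior. A usage sketch against a toy traced module (assumes a checkout containing this patch):

```python
import torch
from torch.fx import symbolic_trace

from torchao.quantization.pt2e.utils import create_getattr_from_value
from torchao.utils import _assert_and_get_unique_device

gm = symbolic_trace(torch.nn.Linear(4, 4))

# Old call shape still works: device is None, so the helper scans the
# module's parameters/buffers itself.
scale_node = create_getattr_from_value(gm, gm.graph, "_scale", torch.tensor(0.1))

# New fast path: compute the device once up front and pass it through,
# so repeated calls (one per quantized node) do no rescanning.
model_device = _assert_and_get_unique_device(gm)
zp_node = create_getattr_from_value(
    gm, gm.graph, "_zero_point", torch.tensor(0), model_device
)
```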
