Skip to content

Commit 8a43fe7

Browse files
Gasoonjia authored and facebook-github-bot committed
Replace debug handle with from_node to trace operator transformation (pytorch#11532)
Summary: Pull Request resolved: pytorch#11532 X-link: pytorch/ao#2339 This diff replaces the debug handle with the `from_node` infrastructure, which is a first-class citizen in the exported program and is used to trace node-level transformations by recording every ancestor of a given node. N6213836 is a demonstration of how the `from_node` infra records the node transformation after unlifting and re-exporting the exported graph. To simplify the process, we reuse the debug handle infrastructure by generating each debug handle from a hash of the node's greatest ancestor node. After this change, users no longer need to invoke `generate_numeric_debug_handle` for debugging. The original pipeline will also still work under the current scenario. Reviewed By: jerryzh168 Differential Revision: D76168997
1 parent db96aba commit 8a43fe7

File tree

1 file changed

+64
-48
lines changed

1 file changed

+64
-48
lines changed

backends/xnnpack/test/quantizer/test_pt2e_quantization.py

Lines changed: 64 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-unsafe
88

99
from collections import Counter
10-
from typing import Dict, Tuple
10+
from typing import Tuple
1111

1212
import torch
1313
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
@@ -36,13 +36,13 @@
3636
from torchao.quantization.pt2e import (
3737
allow_exported_model_train_eval,
3838
compare_results,
39-
CUSTOM_KEY,
4039
extract_results_from_loggers,
41-
generate_numeric_debug_handle,
42-
NUMERIC_DEBUG_HANDLE_KEY,
40+
FROM_NODE_KEY,
4341
prepare_for_propagation_comparison,
4442
)
4543

44+
from torchao.quantization.pt2e._numeric_debugger import _generate_debug_handle_from_node
45+
4646
from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
4747
from torchao.quantization.pt2e.quantize_pt2e import (
4848
convert_pt2e,
@@ -723,47 +723,72 @@ def test_save_load(self) -> None:
723723
instantiate_parametrized_tests(TestQuantizePT2E)
724724

725725

726+
# TODO: deduplicate with TestNumericDebugger under torchao
726727
class TestNumericDebugger(TestCase):
727-
def _extract_debug_handles(self, model) -> Dict[str, int]:
728-
debug_handle_map: Dict[str, int] = {}
728+
def _assert_each_node_has_debug_handle(self, model) -> None:
    """Fail the test if any visited node of ``model`` lacks ``FROM_NODE_KEY``.

    The check is applied to every node that ``bfs_trace_with_node_process``
    hands to the callback; a node passes when ``FROM_NODE_KEY`` is a key of
    its ``meta`` mapping.
    """

    def _check_from_node_present(visited):
        # from_node info is what debug handles are derived from, so its
        # absence is reported per node.
        self.assertIn(
            FROM_NODE_KEY,
            visited.meta,
            msg=f"Node {visited} doesn't have from_node info",
        )

    bfs_trace_with_node_process(model, _check_from_node_present)
737+
738+
def _extract_debug_handles(self, model) -> dict[str, int]:
    """Collect the debug handle of every visited node in ``model``.

    Returns:
        A dict keyed by ``str(node)`` whose values are the handles produced
        by ``_generate_debug_handle_from_node``. Nodes for which that helper
        returns ``None`` are left out.
    """
    handles_by_node_name: dict[str, int] = {}

    def _record_handle(visited):
        handle = _generate_debug_handle_from_node(visited)
        if handle is None:
            return
        handles_by_node_name[str(visited)] = handle

    bfs_trace_with_node_process(model, _record_handle)
    return handles_by_node_name
742749

743-
def _assert_each_node_has_debug_handle(self, model) -> None:
744-
def _assert_node_has_debug_handle(node: torch.fx.Node) -> None:
745-
self.assertTrue(
746-
CUSTOM_KEY in node.meta
747-
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
748-
f"Node {node} doesn't have debug handle",
749-
)
750+
def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]:
    """Map each node's stringified ``nn_module_stack`` to its debug handle.

    Every node visited by ``bfs_trace_with_node_process`` whose ``meta``
    contains ``FROM_NODE_KEY`` is keyed by
    ``str(node.meta.get("nn_module_stack"))``. The first node seen for a key
    records its handle; every later node with the same key must yield an
    equal handle.

    Returns:
        The mapping from stringified ``nn_module_stack`` to debug handle.

    Raises:
        AssertionError: if two nodes sharing a key produce different handles.
    """
    handle_by_decomp_op: dict[str, int] = {}

    def _record_or_check(node):
        if FROM_NODE_KEY not in node.meta:
            return
        # NOTE(review): nodes without "nn_module_stack" all share the key
        # "None" -- confirm that grouping is intended.
        prev_decomp_op = str(node.meta.get("nn_module_stack"))
        debug_handle = _generate_debug_handle_from_node(node)
        if prev_decomp_op not in handle_by_decomp_op:
            handle_by_decomp_op[prev_decomp_op] = debug_handle
            return
        # Both message fragments are f-strings passed as one argument. In the
        # previous form the second fragment was a separate no-op statement
        # after the assert, so the message was truncated and never showed
        # the decomp op.
        self.assertEqual(
            handle_by_decomp_op[prev_decomp_op],
            debug_handle,
            f"Node {node} has different debug handle {debug_handle} "
            f"than previous node sharing the same decomp op {prev_decomp_op}",
        )

    bfs_trace_with_node_process(model, _record_or_check)
    return handle_by_decomp_op
752771

753-
def test_quantize_pt2e_preserve_handle(self) -> None:
772+
def test_quantize_pt2e_preserve_handle(self):
754773
m = TestHelperModules.Conv2dThenConv1d()
755774
example_inputs = m.example_inputs()
756775
ep = export_for_training(m, example_inputs, strict=True)
757-
generate_numeric_debug_handle(ep)
758776
m = ep.module()
759777

760778
quantizer = XNNPACKQuantizer().set_global(
761779
get_symmetric_quantization_config(is_per_channel=False)
762780
)
763-
m = prepare_pt2e(m, quantizer) # pyre-ignore[6]
781+
m = prepare_pt2e(m, quantizer)
764782
debug_handle_map = self._extract_debug_handles(m)
783+
node_name_equip_with_output_observer = [
784+
"conv2d",
785+
"conv1d",
786+
"squeeze",
787+
]
765788
res_counter = Counter(debug_handle_map.values())
766-
repeated_debug_handle_ids = [1, 2, 3]
789+
repeated_debug_handle_ids = [
790+
debug_handle_map[n_name] for n_name in node_name_equip_with_output_observer
791+
]
767792
# 3 ids were repeated because we copy over the id from node to its output observer
768793
# torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
769794
for dh_id in repeated_debug_handle_ids:
@@ -776,22 +801,23 @@ def test_quantize_pt2e_preserve_handle(self) -> None:
776801
res_counter = Counter(debug_handle_map.values())
777802
# same set of ids were repeated, because we copy over the id from observer/fake_quant to
778803
# dequantize node
779-
repeated_debug_handle_ids = [1, 2, 3]
804+
repeated_debug_handle_ids = [
805+
debug_handle_map[n_name] for n_name in node_name_equip_with_output_observer
806+
]
780807
for dh_id in repeated_debug_handle_ids:
781808
self.assertEqual(res_counter[dh_id], 2)
782809

783-
def test_extract_results_from_loggers(self) -> None:
810+
def test_extract_results_from_loggers(self):
784811
m = TestHelperModules.Conv2dThenConv1d()
785812
example_inputs = m.example_inputs()
786813
ep = export_for_training(m, example_inputs, strict=True)
787-
generate_numeric_debug_handle(ep)
788814
m = ep.module()
789-
m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6]
815+
m_ref_logger = prepare_for_propagation_comparison(m)
790816

791817
quantizer = XNNPACKQuantizer().set_global(
792818
get_symmetric_quantization_config(is_per_channel=False)
793819
)
794-
m = prepare_pt2e(m, quantizer) # pyre-ignore[6]
820+
m = prepare_pt2e(m, quantizer)
795821
m(*example_inputs)
796822
m = convert_pt2e(m)
797823
m_quant_logger = prepare_for_propagation_comparison(m)
@@ -800,29 +826,22 @@ def test_extract_results_from_loggers(self) -> None:
800826
m_quant_logger(*example_inputs)
801827
ref_results = extract_results_from_loggers(m_ref_logger)
802828
quant_results = extract_results_from_loggers(m_quant_logger)
803-
comparison_results = compare_results(
804-
ref_results,
805-
quant_results, # pyre-ignore[6]
806-
)
829+
comparison_results = compare_results(ref_results, quant_results)
807830
for node_summary in comparison_results.values():
808831
if len(node_summary.results) > 0:
809-
self.assertGreaterEqual(
810-
node_summary.results[0].sqnr,
811-
35, # pyre-ignore[6]
812-
)
832+
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
813833

814-
def test_extract_results_from_loggers_list_output(self) -> None:
834+
def test_extract_results_from_loggers_list_output(self):
815835
m = TestHelperModules.Conv2dWithSplit()
816836
example_inputs = m.example_inputs()
817837
ep = export_for_training(m, example_inputs, strict=True)
818-
generate_numeric_debug_handle(ep)
819838
m = ep.module()
820-
m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6]
839+
m_ref_logger = prepare_for_propagation_comparison(m)
821840

822841
quantizer = XNNPACKQuantizer().set_global(
823842
get_symmetric_quantization_config(is_per_channel=False)
824843
)
825-
m = prepare_pt2e(m, quantizer) # pyre-ignore[6]
844+
m = prepare_pt2e(m, quantizer)
826845
m(*example_inputs)
827846
m = convert_pt2e(m)
828847
m_quant_logger = prepare_for_propagation_comparison(m)
@@ -831,15 +850,12 @@ def test_extract_results_from_loggers_list_output(self) -> None:
831850
m_quant_logger(*example_inputs)
832851
ref_results = extract_results_from_loggers(m_ref_logger)
833852
quant_results = extract_results_from_loggers(m_quant_logger)
834-
comparison_results = compare_results(
835-
ref_results,
836-
quant_results, # pyre-ignore[6]
837-
)
853+
comparison_results = compare_results(ref_results, quant_results)
838854
for node_summary in comparison_results.values():
839855
if len(node_summary.results) > 0:
840856
sqnr = node_summary.results[0].sqnr
841857
if isinstance(sqnr, list):
842858
for sqnr_i in sqnr:
843859
self.assertGreaterEqual(sqnr_i, 35)
844860
else:
845-
self.assertGreaterEqual(sqnr, 35) # pyre-ignore[6]
861+
self.assertGreaterEqual(sqnr, 35)

0 commit comments

Comments
 (0)