
Commit c685371

Gasoonjia authored and facebook-github-bot committed
Replace debug handle with from_node to trace operator transformation
Summary:
X-link: pytorch/ao#2339

This diff replaces the debug handle with the `from_node` infrastructure, which is a first-class citizen in the exported program and is used to trace node-level transformations. To simplify the process, we reuse the existing debug handle infrastructure by generating debug handles from the `from_node` info via hashing. After this change, users no longer need to invoke `generate_numeric_debug_handle` for debugging, and the original pipeline still works under the current scenario.

Reviewed By: jerryzh168

Differential Revision: D76168997
1 parent 120eb85 commit c685371
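To illustrate the hashing idea described in the summary (this sketch is not the code from this commit): a debug handle can be derived deterministically from a node's provenance chain, so two nodes that trace back to the same source get the same handle. The `NodeSource` structure and field names below are simplified stand-ins for the `from_node` metadata, chosen for illustration only.

# Minimal sketch, assuming a simplified provenance record; not the commit's implementation.
import hashlib
from dataclasses import dataclass
from typing import Optional


@dataclass
class NodeSource:
    name: str                              # name of the originating node (assumed field)
    graph_id: int                          # id of the graph it came from (assumed field)
    parent: Optional["NodeSource"] = None  # provenance chain, if any


def debug_handle_from_source(source: NodeSource) -> int:
    """Hash the full provenance chain into a deterministic integer handle."""
    parts = []
    cur: Optional[NodeSource] = source
    while cur is not None:
        parts.append(f"{cur.name}:{cur.graph_id}")
        cur = cur.parent
    digest = hashlib.sha256("/".join(parts).encode()).hexdigest()
    # Truncate to 63 bits so the handle fits in a signed 64-bit integer.
    return int(digest, 16) & ((1 << 63) - 1)


# Nodes sharing the same provenance get the same handle.
assert debug_handle_from_source(NodeSource("conv2d", 1)) == debug_handle_from_source(
    NodeSource("conv2d", 1)
)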

1 file changed: +62 −37 lines changed


backends/xnnpack/test/quantizer/test_pt2e_quantization.py

Lines changed: 62 additions & 37 deletions
@@ -723,11 +723,22 @@ def test_save_load(self) -> None:
 instantiate_parametrized_tests(TestQuantizePT2E)


+#TODO: deduplicate with TestNumericDebugger under torchao
 class TestNumericDebugger(TestCase):
-    def _extract_debug_handles(self, model) -> Dict[str, int]:
-        debug_handle_map: Dict[str, int] = {}
+    def _assert_each_node_has_debug_handle(self, model) -> None:
+        def _assert_node_has_debug_handle(node):
+            self.assertTrue(
+                CUSTOM_KEY in node.meta
+                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
+                f"Node {node} doesn't have debug handle",
+            )
+
+        bfs_trace_with_node_process(model, _assert_node_has_debug_handle)

-        def _extract_debug_handles_from_node(node: torch.fx.Node) -> None:
+    def _extract_debug_handles(self, model) -> dict[str, int]:
+        debug_handle_map: dict[str, int] = {}
+
+        def _extract_debug_handles_from_node(node):
             nonlocal debug_handle_map
             if (
                 CUSTOM_KEY in node.meta
@@ -738,32 +749,55 @@ def _extract_debug_handles_from_node(node: torch.fx.Node) -> None:
                 ]

         bfs_trace_with_node_process(model, _extract_debug_handles_from_node)
+
         return debug_handle_map

-    def _assert_each_node_has_debug_handle(self, model) -> None:
-        def _assert_node_has_debug_handle(node: torch.fx.Node) -> None:
-            self.assertTrue(
-                CUSTOM_KEY in node.meta
-                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
-                f"Node {node} doesn't have debug handle",
-            )
+    def _extract_debug_handles_with_prev_decomp_op(self, model) -> dict[str, int]:
+        prev_decomp_op_to_debug_handle_map: dict[str, int] = {}

-        bfs_trace_with_node_process(model, _assert_node_has_debug_handle)
+        def _extract_debug_handles_with_prev_decomp_op_from_node(node):
+            nonlocal prev_decomp_op_to_debug_handle_map
+            if (
+                CUSTOM_KEY in node.meta
+                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
+            ):
+                prev_decomp_op = str(node.meta.get("nn_module_stack"))
+                debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
+                if prev_decomp_op not in prev_decomp_op_to_debug_handle_map:
+                    prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle
+                else:
+                    assert (
+                        prev_decomp_op_to_debug_handle_map[prev_decomp_op]
+                        == debug_handle
+                    ), f"Node {node} has different debug handle {debug_handle}"
+                    "than previous node sharing the same decomp op {prev_decomp_op}"
+
+        bfs_trace_with_node_process(
+            model, _extract_debug_handles_with_prev_decomp_op_from_node
+        )
+        return prev_decomp_op_to_debug_handle_map

-    def test_quantize_pt2e_preserve_handle(self) -> None:
+    def test_quantize_pt2e_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
         ep = export_for_training(m, example_inputs, strict=True)
-        generate_numeric_debug_handle(ep)
+        # generate_numeric_debug_handle(ep)
         m = ep.module()

         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(is_per_channel=False)
         )
-        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
+        m = prepare_pt2e(m, quantizer)
         debug_handle_map = self._extract_debug_handles(m)
+        node_name_equip_with_output_observer = [
+            "conv2d",
+            "conv1d",
+            "squeeze",
+        ]
         res_counter = Counter(debug_handle_map.values())
-        repeated_debug_handle_ids = [1, 2, 3]
+        repeated_debug_handle_ids = [
+            debug_handle_map[n_name] for n_name in node_name_equip_with_output_observer
+        ]
         # 3 ids were repeated because we copy over the id from node to its output observer
         # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
         for dh_id in repeated_debug_handle_ids:
@@ -776,22 +810,23 @@ def test_quantize_pt2e_preserve_handle(self) -> None:
         res_counter = Counter(debug_handle_map.values())
         # same set of ids where repeated, because we copy over the id from observer/fake_quant to
         # dequantize node
-        repeated_debug_handle_ids = [1, 2, 3]
+        repeated_debug_handle_ids = [
+            debug_handle_map[n_name] for n_name in node_name_equip_with_output_observer
+        ]
         for dh_id in repeated_debug_handle_ids:
             self.assertEqual(res_counter[dh_id], 2)

-    def test_extract_results_from_loggers(self) -> None:
+    def test_extract_results_from_loggers(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
         ep = export_for_training(m, example_inputs, strict=True)
-        generate_numeric_debug_handle(ep)
         m = ep.module()
-        m_ref_logger = prepare_for_propagation_comparison(m)  # pyre-ignore[6]
+        m_ref_logger = prepare_for_propagation_comparison(m)

         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(is_per_channel=False)
         )
-        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
+        m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
         m = convert_pt2e(m)
         m_quant_logger = prepare_for_propagation_comparison(m)
@@ -800,29 +835,22 @@ def test_extract_results_from_loggers(self) -> None:
         m_quant_logger(*example_inputs)
         ref_results = extract_results_from_loggers(m_ref_logger)
         quant_results = extract_results_from_loggers(m_quant_logger)
-        comparison_results = compare_results(
-            ref_results,
-            quant_results,  # pyre-ignore[6]
-        )
+        comparison_results = compare_results(ref_results, quant_results)
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:
-                self.assertGreaterEqual(
-                    node_summary.results[0].sqnr,
-                    35,  # pyre-ignore[6]
-                )
+                self.assertGreaterEqual(node_summary.results[0].sqnr, 35)

-    def test_extract_results_from_loggers_list_output(self) -> None:
+    def test_extract_results_from_loggers_list_output(self):
         m = TestHelperModules.Conv2dWithSplit()
         example_inputs = m.example_inputs()
         ep = export_for_training(m, example_inputs, strict=True)
-        generate_numeric_debug_handle(ep)
         m = ep.module()
-        m_ref_logger = prepare_for_propagation_comparison(m)  # pyre-ignore[6]
+        m_ref_logger = prepare_for_propagation_comparison(m)

         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(is_per_channel=False)
         )
-        m = prepare_pt2e(m, quantizer)  # pyre-ignore[6]
+        m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
         m = convert_pt2e(m)
         m_quant_logger = prepare_for_propagation_comparison(m)
@@ -831,15 +859,12 @@ def test_extract_results_from_loggers_list_output(self) -> None:
         m_quant_logger(*example_inputs)
         ref_results = extract_results_from_loggers(m_ref_logger)
         quant_results = extract_results_from_loggers(m_quant_logger)
-        comparison_results = compare_results(
-            ref_results,
-            quant_results,  # pyre-ignore[6]
-        )
+        comparison_results = compare_results(ref_results, quant_results)
         for node_summary in comparison_results.values():
             if len(node_summary.results) > 0:
                 sqnr = node_summary.results[0].sqnr
                 if isinstance(sqnr, list):
                     for sqnr_i in sqnr:
                         self.assertGreaterEqual(sqnr_i, 35)
                 else:
-                    self.assertGreaterEqual(sqnr, 35)  # pyre-ignore[6]
+                    self.assertGreaterEqual(sqnr, 35)
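For orientation (not part of the diff above): a minimal sketch of the workflow the updated tests exercise, showing that there is no longer a `generate_numeric_debug_handle` call between export and `prepare_pt2e`. The `DemoModel` module and the import paths are assumptions for illustration; they are not taken from this commit.

# Sketch only; import paths and the demo model are assumptions, not from this diff.
import torch
from torch.export import export_for_training
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class DemoModel(torch.nn.Module):  # stand-in model for illustration
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


example_inputs = (torch.randn(1, 3, 16, 16),)
ep = export_for_training(DemoModel(), example_inputs, strict=True)
# No generate_numeric_debug_handle(ep) call is needed anymore.
m = ep.module()

quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=False)
)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate
m = convert_pt2e(m)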
