Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 47 additions & 76 deletions backends/xnnpack/test/quantizer/test_pt2e_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@

# pyre-unsafe

import unittest

from collections import Counter
from typing import Dict, Tuple
from typing import Tuple

import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
Expand All @@ -33,27 +31,26 @@
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
TemporaryFileName,
TestCase,
)
from torchao.quantization.pt2e import (
allow_exported_model_train_eval,
compare_results,
CUSTOM_KEY,
extract_results_from_loggers,
generate_numeric_debug_handle,
NUMERIC_DEBUG_HANDLE_KEY,
FROM_NODE_KEY,
prepare_for_propagation_comparison,
)

from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
from torchao.quantization.pt2e.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer
from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer
from torchao.testing.pt2e.utils import PT2EQuantizationTestCase
from torchao.testing.pt2e.utils import (
PT2ENumericDebuggerTestCase,
PT2EQuantizationTestCase,
)


class TestQuantizePT2E(PT2EQuantizationTestCase):
Expand Down Expand Up @@ -495,7 +492,8 @@ def forward(self, x):
for n in m.graph.nodes:
if n.op == "get_attr" and "frozen_param" in n.target:
for key in n.meta:
self.assertEqual(n.meta[key], weight_meta[key])
if key != FROM_NODE_KEY:
self.assertEqual(n.meta[key], weight_meta[key])

def test_reentrant(self) -> None:
"""Test we can safely call quantization apis multiple times"""
Expand Down Expand Up @@ -725,76 +723,59 @@ def test_save_load(self) -> None:
instantiate_parametrized_tests(TestQuantizePT2E)


@unittest.skip("TODO: Reenable it after debug infrature finish update")
class TestNumericDebugger(TestCase):
def _extract_debug_handles(self, model) -> Dict[str, int]:
    """Collect the numeric debug handle stored on each node of ``model``.

    Every node handed to the callback by ``bfs_trace_with_node_process`` is
    inspected; a node contributes an entry only when its ``meta`` has a
    ``CUSTOM_KEY`` entry that itself contains ``NUMERIC_DEBUG_HANDLE_KEY``.

    Returns:
        A mapping from ``str(node)`` to the handle found on that node.
    """
    handles_by_node: Dict[str, int] = {}

    def _record_handle(node: torch.fx.Node) -> None:
        # Nodes without custom metadata carry no handle; skip them.
        if CUSTOM_KEY not in node.meta:
            return
        custom_meta = node.meta[CUSTOM_KEY]
        if NUMERIC_DEBUG_HANDLE_KEY not in custom_meta:
            return
        handles_by_node[str(node)] = custom_meta[NUMERIC_DEBUG_HANDLE_KEY]

    bfs_trace_with_node_process(model, _record_handle)
    return handles_by_node

def _assert_each_node_has_debug_handle(self, model) -> None:
    """Fail the test if any visited node of ``model`` lacks a debug handle.

    A node is considered to have a handle when ``node.meta`` has a
    ``CUSTOM_KEY`` entry that contains ``NUMERIC_DEBUG_HANDLE_KEY``. The
    check runs on every node that ``bfs_trace_with_node_process`` passes to
    the callback.
    """

    def _check_node(node: torch.fx.Node) -> None:
        has_handle = (
            CUSTOM_KEY in node.meta
            and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
        )
        failure_message = f"Node {node} doesn't have debug handle"
        self.assertTrue(has_handle, failure_message)

    bfs_trace_with_node_process(model, _check_node)
class TestXNNPACKQuantizerNumericDebugger(PT2ENumericDebuggerTestCase):

def test_quantize_pt2e_preserve_handle(self) -> None:
def test_quantize_pt2e_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()

quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_per_channel=False)
)
m = prepare_pt2e(m, quantizer) # pyre-ignore[6]
debug_handle_map = self._extract_debug_handles(m)
res_counter = Counter(debug_handle_map.values())
repeated_debug_handle_ids = [1, 2, 3]
# 3 ids were repeated because we copy over the id from node to its output observer
m = prepare_pt2e(m, quantizer)
from_node_source_map = self._extract_from_node_source(m)
node_name_equip_with_output_observer = [
"conv2d",
"conv1d",
"squeeze",
]
res_counter = Counter(from_node_source_map.values())
repeated_from_node_source = [
from_node_source_map[n_name]
for n_name in node_name_equip_with_output_observer
]
# 3 infos were repeated because we copy over the info from node to its output observer
# torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
for dh_id in repeated_debug_handle_ids:
self.assertEqual(res_counter[dh_id], 2)
for from_node_source in repeated_from_node_source:
self.assertEqual(res_counter[from_node_source], 2)

m(*example_inputs)
m = convert_pt2e(m)
self._assert_each_node_has_debug_handle(ep)
debug_handle_map = self._extract_debug_handles(m)
res_counter = Counter(debug_handle_map.values())
# same set of ids were repeated, because we copy over the id from observer/fake_quant to
# dequantize node
repeated_debug_handle_ids = [1, 2, 3]
for dh_id in repeated_debug_handle_ids:
self.assertEqual(res_counter[dh_id], 2)

def test_extract_results_from_loggers(self) -> None:
self._assert_each_node_has_from_node_source(m)
from_node_source_map = self._extract_from_node_source(m)
res_counter = Counter(from_node_source_map.values())
# same set of infos were repeated, because we copy over the info from observer/fake_quant to
# quantize/dequantize node
repeated_from_node_source = [
from_node_source_map[n_name]
for n_name in node_name_equip_with_output_observer
]
for from_node_source in repeated_from_node_source:
self.assertEqual(res_counter[from_node_source], 3)

def test_extract_results_from_loggers(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6]
m_ref_logger = prepare_for_propagation_comparison(m)

quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_per_channel=False)
)
m = prepare_pt2e(m, quantizer) # pyre-ignore[6]
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
m_quant_logger = prepare_for_propagation_comparison(m)
Expand All @@ -803,29 +784,22 @@ def test_extract_results_from_loggers(self) -> None:
m_quant_logger(*example_inputs)
ref_results = extract_results_from_loggers(m_ref_logger)
quant_results = extract_results_from_loggers(m_quant_logger)
comparison_results = compare_results(
ref_results,
quant_results, # pyre-ignore[6]
)
comparison_results = compare_results(ref_results, quant_results)
for node_summary in comparison_results.values():
if len(node_summary.results) > 0:
self.assertGreaterEqual(
node_summary.results[0].sqnr,
35, # pyre-ignore[6]
)
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)

def test_extract_results_from_loggers_list_output(self) -> None:
def test_extract_results_from_loggers_list_output(self):
m = TestHelperModules.Conv2dWithSplit()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
m_ref_logger = prepare_for_propagation_comparison(m) # pyre-ignore[6]
m_ref_logger = prepare_for_propagation_comparison(m)

quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_per_channel=False)
)
m = prepare_pt2e(m, quantizer) # pyre-ignore[6]
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
m_quant_logger = prepare_for_propagation_comparison(m)
Expand All @@ -834,15 +808,12 @@ def test_extract_results_from_loggers_list_output(self) -> None:
m_quant_logger(*example_inputs)
ref_results = extract_results_from_loggers(m_ref_logger)
quant_results = extract_results_from_loggers(m_quant_logger)
comparison_results = compare_results(
ref_results,
quant_results, # pyre-ignore[6]
)
comparison_results = compare_results(ref_results, quant_results)
for node_summary in comparison_results.values():
if len(node_summary.results) > 0:
sqnr = node_summary.results[0].sqnr
if isinstance(sqnr, list):
for sqnr_i in sqnr:
self.assertGreaterEqual(sqnr_i, 35)
else:
self.assertGreaterEqual(sqnr, 35) # pyre-ignore[6]
self.assertGreaterEqual(sqnr, 35)
Loading