|
7 | 7 | # pyre-unsafe |
8 | 8 |
|
9 | 9 | from collections import Counter |
10 | | -from typing import Dict, Tuple |
| 10 | +from typing import Tuple |
11 | 11 |
|
12 | 12 | import torch |
13 | 13 | from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( |
|
31 | 31 | from torch.testing._internal.common_utils import ( |
32 | 32 | instantiate_parametrized_tests, |
33 | 33 | TemporaryFileName, |
34 | | - TestCase, |
35 | 34 | ) |
36 | 35 | from torchao.quantization.pt2e import ( |
37 | 36 | allow_exported_model_train_eval, |
38 | 37 | compare_results, |
39 | | - CUSTOM_KEY, |
40 | 38 | extract_results_from_loggers, |
41 | 39 | generate_numeric_debug_handle, |
42 | | - NUMERIC_DEBUG_HANDLE_KEY, |
43 | 40 | prepare_for_propagation_comparison, |
44 | 41 | ) |
45 | 42 |
|
46 | | -from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process |
47 | 43 | from torchao.quantization.pt2e.quantize_pt2e import ( |
48 | 44 | convert_pt2e, |
49 | 45 | prepare_pt2e, |
50 | 46 | prepare_qat_pt2e, |
51 | 47 | ) |
52 | 48 | from torchao.quantization.pt2e.quantizer import ComposableQuantizer, Quantizer |
53 | 49 | from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer |
54 | | -from torchao.testing.pt2e.utils import PT2EQuantizationTestCase |
| 50 | +from torchao.testing.pt2e.utils import PT2EQuantizationTestCase, PT2ENumericDebuggerTestCase |
55 | 51 |
|
56 | 52 |
|
57 | 53 | class TestQuantizePT2E(PT2EQuantizationTestCase): |
@@ -723,33 +719,7 @@ def test_save_load(self) -> None: |
723 | 719 | instantiate_parametrized_tests(TestQuantizePT2E) |
724 | 720 |
|
725 | 721 |
|
726 | | -class TestNumericDebugger(TestCase): |
727 | | - def _extract_debug_handles(self, model) -> Dict[str, int]: |
728 | | - debug_handle_map: Dict[str, int] = {} |
729 | | - |
730 | | - def _extract_debug_handles_from_node(node: torch.fx.Node) -> None: |
731 | | - nonlocal debug_handle_map |
732 | | - if ( |
733 | | - CUSTOM_KEY in node.meta |
734 | | - and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY] |
735 | | - ): |
736 | | - debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][ |
737 | | - NUMERIC_DEBUG_HANDLE_KEY |
738 | | - ] |
739 | | - |
740 | | - bfs_trace_with_node_process(model, _extract_debug_handles_from_node) |
741 | | - return debug_handle_map |
742 | | - |
743 | | - def _assert_each_node_has_debug_handle(self, model) -> None: |
744 | | - def _assert_node_has_debug_handle(node: torch.fx.Node) -> None: |
745 | | - self.assertTrue( |
746 | | - CUSTOM_KEY in node.meta |
747 | | - and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY], |
748 | | - f"Node {node} doesn't have debug handle", |
749 | | - ) |
750 | | - |
751 | | - bfs_trace_with_node_process(model, _assert_node_has_debug_handle) |
752 | | - |
| 722 | +class TestXNNPACKQuantizerNumericDebugger(PT2ENumericDebuggerTestCase): |
753 | 723 | def test_quantize_pt2e_preserve_handle(self) -> None: |
754 | 724 | m = TestHelperModules.Conv2dThenConv1d() |
755 | 725 | example_inputs = m.example_inputs() |
|
0 commit comments