|  | 
| 17 | 17 | 
 | 
| 18 | 18 | from unittest.mock import patch | 
| 19 | 19 | 
 | 
|  | 20 | +import pandas as pd | 
|  | 21 | + | 
| 20 | 22 | import torch | 
| 21 | 23 | import torch.fx | 
| 22 | 24 | 
 | 
| @@ -578,6 +580,75 @@ def test_get_runtime_intermediate_outputs(self): | 
| 578 | 580 |                 self.assertIn((key,), runtime_outputs) | 
| 579 | 581 |                 self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE) | 
| 580 | 582 | 
 | 
|  | 583 | +    def test_calculate_numeric_gap(self): | 
|  | 584 | +        # Create a context manager to patch functions called by Inspector.__init__ | 
|  | 585 | +        with patch.object( | 
|  | 586 | +            _inspector, "parse_etrecord", return_value=None | 
|  | 587 | +        ), patch.object( | 
|  | 588 | +            _inspector, "gen_etdump_object", return_value=None | 
|  | 589 | +        ), patch.object( | 
|  | 590 | +            EventBlock, "_gen_from_etdump" | 
|  | 591 | +        ), patch.object( | 
|  | 592 | +            _inspector, "gen_graphs_from_etrecord" | 
|  | 593 | +        ): | 
|  | 594 | +            # Call the constructor of Inspector | 
|  | 595 | +            inspector_instance = Inspector( | 
|  | 596 | +                etdump_path=ETDUMP_PATH, | 
|  | 597 | +                etrecord=ETRECORD_PATH, | 
|  | 598 | +            ) | 
|  | 599 | + | 
|  | 600 | +            aot_intermediate_outputs = { | 
|  | 601 | +                (0,): torch.tensor([1.0, 2.0, 3.0]), | 
|  | 602 | +                (1,): torch.tensor([4.0, 5.0, 6.0]), | 
|  | 603 | +            } | 
|  | 604 | + | 
|  | 605 | +            runtime_intermediate_outputs = { | 
|  | 606 | +                (0,): torch.tensor([2.0, 1.0, 4.0]), | 
|  | 607 | +                (1,): torch.tensor([3.0, 6.0, 5.0]), | 
|  | 608 | +            } | 
|  | 609 | + | 
|  | 610 | +            inspector_instance._aot_intermediate_outputs = aot_intermediate_outputs | 
|  | 611 | +            inspector_instance._get_runtime_intermediate_outputs = ( | 
|  | 612 | +                lambda: runtime_intermediate_outputs | 
|  | 613 | +            ) | 
|  | 614 | + | 
|  | 615 | +            df = inspector_instance.calculate_numeric_gap(distance="L1") | 
|  | 616 | +            self.assertIsInstance(df, pd.DataFrame) | 
|  | 617 | +            self.assertEqual(len(df), 2) | 
|  | 618 | +            cols = set(df.columns) | 
|  | 619 | +            expected_cols = { | 
|  | 620 | +                "aot_debug_handle", | 
|  | 621 | +                "aot_intermediate_output", | 
|  | 622 | +                "runtime_debug_handle", | 
|  | 623 | +                "runtime_intermediate_output", | 
|  | 624 | +                "gap", | 
|  | 625 | +            } | 
|  | 626 | +            self.assertEqual(cols, expected_cols) | 
|  | 627 | +            founded_aot_debug_handle = set(df["aot_debug_handle"]) | 
|  | 628 | +            self.assertEqual( | 
|  | 629 | +                founded_aot_debug_handle, set(aot_intermediate_outputs.keys()) | 
|  | 630 | +            ) | 
|  | 631 | +            for _, row in df.iterrows(): | 
|  | 632 | +                aot_debuh_handle = row["aot_debug_handle"] | 
|  | 633 | +                # aot_intermediate_output should equal aot_intermediate_outputs[h] | 
|  | 634 | +                self.assertTrue( | 
|  | 635 | +                    torch.allclose( | 
|  | 636 | +                        row["aot_intermediate_output"], | 
|  | 637 | +                        aot_intermediate_outputs[aot_debuh_handle], | 
|  | 638 | +                    ) | 
|  | 639 | +                ) | 
|  | 640 | +                # runtime_debug_hanlde equals aot_debug_handle at this case | 
|  | 641 | +                self.assertEqual(row["runtime_debug_handle"], aot_debuh_handle) | 
|  | 642 | +                # runtime_intermediate_output should equal runtime_intermediate_outputs[h] | 
|  | 643 | +                self.assertTrue( | 
|  | 644 | +                    torch.allclose( | 
|  | 645 | +                        row["runtime_intermediate_output"], | 
|  | 646 | +                        runtime_intermediate_outputs[aot_debuh_handle], | 
|  | 647 | +                    ) | 
|  | 648 | +                ) | 
|  | 649 | +                # gap should equal 3.0 | 
|  | 650 | +                self.assertEqual(row["gap"], 3.0) | 
|  | 651 | + | 
| 581 | 652 |     def _gen_random_float_list(self) -> List[float]: | 
| 582 | 653 |         return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)] | 
| 583 | 654 | 
 | 
|  | 
0 commit comments