|
7 | 7 |
|
8 | 8 |
|
9 | 9 | from logging import Logger |
| 10 | +from unittest.mock import MagicMock |
10 | 11 |
|
11 | 12 | import pandas as pd |
12 | 13 | from ax.adapter.registry import Generators |
| 14 | +from ax.core.arm import Arm |
13 | 15 | from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose |
14 | 16 | from ax.core.data import Data |
| 17 | +from ax.core.derived_metric import DerivedMetric |
| 18 | +from ax.core.experiment import Experiment |
15 | 19 | from ax.core.trial_status import TrialStatus |
16 | 20 | from ax.exceptions.core import DataRequiredError, UserInputError |
17 | 21 | from ax.exceptions.generation_strategy import MaxParallelismReachedException |
|
24 | 28 | from ax.generation_strategy.transition_criterion import ( |
25 | 29 | AutoTransitionAfterGen, |
26 | 30 | AuxiliaryExperimentCheck, |
| 31 | + FreshLILOLabelCheck, |
27 | 32 | IsSingleObjective, |
28 | 33 | MaxGenerationParallelism, |
29 | 34 | MaxTrialsAwaitingData, |
30 | 35 | MinTrials, |
31 | 36 | ) |
| 37 | +from ax.utils.common.constants import Keys |
| 38 | +from ax.utils.common.hash_utils import compute_lilo_input_hash |
32 | 39 | from ax.utils.common.logger import get_logger |
33 | 40 | from ax.utils.common.testutils import TestCase |
34 | 41 | from ax.utils.testing.core_stubs import ( |
|
41 | 48 | logger: Logger = get_logger(__name__) |
42 | 49 |
|
43 | 50 |
|
| 51 | +def _mock_node(trials_from_node: set[int]) -> MagicMock: |
| 52 | + """Create a mock GenerationNode with a specified trials_from_node set.""" |
| 53 | + node = MagicMock() |
| 54 | + node.trials_from_node = trials_from_node |
| 55 | + return node |
| 56 | + |
| 57 | + |
44 | 58 | class TestTransitionCriterion(TestCase): |
45 | 59 | def setUp(self) -> None: |
46 | 60 | super().setUp() |
@@ -614,3 +628,189 @@ def test_max_generation_parallelism_block_error(self) -> None: |
614 | 628 | experiment=self.experiment, |
615 | 629 | trials_from_node={0, 1, 2}, |
616 | 630 | ) |
| 631 | + |
| 632 | + def test_fresh_lilo_label_check(self) -> None: |
| 633 | + """Verify FreshLILOLabelCheck counts only hash-fresh trials.""" |
| 634 | + exp = get_branin_experiment() |
| 635 | + |
| 636 | + # Register a DerivedMetric with pairwise name. |
| 637 | + pairwise_metric = DerivedMetric( |
| 638 | + name=Keys.PAIRWISE_PREFERENCE_QUERY.value, |
| 639 | + input_metric_names=["branin"], |
| 640 | + ) |
| 641 | + exp.add_tracking_metric(pairwise_metric) |
| 642 | + |
| 643 | + criterion = FreshLILOLabelCheck( |
| 644 | + threshold=2, |
| 645 | + transition_to="next_node", |
| 646 | + only_in_statuses=[TrialStatus.COMPLETED], |
| 647 | + ) |
| 648 | + |
| 649 | + # Helper to create and complete a trial with data. |
| 650 | + def _add_trial(idx: int, exp: Experiment = exp) -> None: |
| 651 | + trial = exp.new_batch_trial() |
| 652 | + trial.add_arm( |
| 653 | + Arm(name=f"{idx}_0", parameters={"x1": float(idx), "x2": 0.0}) |
| 654 | + ) |
| 655 | + trial.mark_running(no_runner_required=True) |
| 656 | + trial.mark_completed() |
| 657 | + exp.attach_data( |
| 658 | + Data( |
| 659 | + df=pd.DataFrame( |
| 660 | + [ |
| 661 | + { |
| 662 | + "trial_index": idx, |
| 663 | + "arm_name": f"{idx}_0", |
| 664 | + "metric_name": "branin", |
| 665 | + "metric_signature": "branin", |
| 666 | + "mean": float(idx), |
| 667 | + "sem": 0.1, |
| 668 | + } |
| 669 | + ] |
| 670 | + ) |
| 671 | + ) |
| 672 | + ) |
| 673 | + |
| 674 | + # Create 3 trials, stamp first 2 with current hash. |
| 675 | + for i in range(3): |
| 676 | + _add_trial(i) |
| 677 | + |
| 678 | + current_hash = compute_lilo_input_hash(exp, ["branin"]) |
| 679 | + trials_from_node = {0, 1, 2} |
| 680 | + |
| 681 | + with self.subTest("no_hashes_none_count"): |
| 682 | + # No hash stamps → no trials counted (only LILO trials with |
| 683 | + # a matching hash contribute). |
| 684 | + count = criterion.num_contributing_to_threshold(exp, trials_from_node) |
| 685 | + self.assertEqual(count, 0) |
| 686 | + |
| 687 | + # Stamp trials 0 and 1 with the current hash. |
| 688 | + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash |
| 689 | + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash |
| 690 | + |
| 691 | + with self.subTest("fresh_hashes_count"): |
| 692 | + count = criterion.num_contributing_to_threshold(exp, trials_from_node) |
| 693 | + # Trials 0, 1 (fresh hash). Trial 2 (no hash → excluded). |
| 694 | + self.assertEqual(count, 2) |
| 695 | + |
| 696 | + # Make trial 1 stale. |
| 697 | + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = "stale_hash" |
| 698 | + |
| 699 | + with self.subTest("stale_hash_excluded"): |
| 700 | + count = criterion.num_contributing_to_threshold(exp, trials_from_node) |
| 701 | + # Trial 0 (fresh). Trial 1 (stale) and trial 2 (no hash) excluded. |
| 702 | + self.assertEqual(count, 1) |
| 703 | + self.assertFalse(criterion.is_met(exp, _mock_node(trials_from_node))) |
| 704 | + |
| 705 | + # Make trial 0 stale too. |
| 706 | + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = "another_stale" |
| 707 | + |
| 708 | + with self.subTest("not_enough_fresh"): |
| 709 | + count = criterion.num_contributing_to_threshold(exp, trials_from_node) |
| 710 | + # All stamped trials are stale, trial 2 has no hash → 0. |
| 711 | + self.assertEqual(count, 0) |
| 712 | + self.assertFalse(criterion.is_met(exp, _mock_node(trials_from_node))) |
| 713 | + |
| 714 | + with self.subTest("data_change_invalidates"): |
| 715 | + # Add new data — changes the current hash, making ALL stamped |
| 716 | + # trials stale. |
| 717 | + _add_trial(3) |
| 718 | + trials_from_node.add(3) |
| 719 | + count = criterion.num_contributing_to_threshold(exp, trials_from_node) |
| 720 | + # Trials 0, 1 stale. Trials 2, 3 have no hash → excluded. |
| 721 | + self.assertEqual(count, 0) |
| 722 | + |
| 723 | + def test_fresh_lilo_label_check_require_sufficient(self) -> None: |
| 724 | + """Verify require_sufficient flag controls is_met direction.""" |
| 725 | + exp = get_branin_experiment() |
| 726 | + |
| 727 | + pairwise_metric = DerivedMetric( |
| 728 | + name=Keys.PAIRWISE_PREFERENCE_QUERY.value, |
| 729 | + input_metric_names=["branin"], |
| 730 | + ) |
| 731 | + exp.add_tracking_metric(pairwise_metric) |
| 732 | + |
| 733 | + # Create 2 completed trials with data. |
| 734 | + for i in range(2): |
| 735 | + trial = exp.new_batch_trial() |
| 736 | + trial.add_arm(Arm(name=f"{i}_0", parameters={"x1": float(i), "x2": 0.0})) |
| 737 | + trial.mark_running(no_runner_required=True) |
| 738 | + trial.mark_completed() |
| 739 | + exp.attach_data( |
| 740 | + Data( |
| 741 | + df=pd.DataFrame( |
| 742 | + [ |
| 743 | + { |
| 744 | + "trial_index": i, |
| 745 | + "arm_name": f"{i}_0", |
| 746 | + "metric_name": "branin", |
| 747 | + "metric_signature": "branin", |
| 748 | + "mean": float(i), |
| 749 | + "sem": 0.1, |
| 750 | + } |
| 751 | + ] |
| 752 | + ) |
| 753 | + ) |
| 754 | + ) |
| 755 | + |
| 756 | + current_hash = compute_lilo_input_hash(exp, ["branin"]) |
| 757 | + # Stamp both trials as fresh. |
| 758 | + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash |
| 759 | + exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash |
| 760 | + trials_from_node = {0, 1} |
| 761 | + |
| 762 | + sufficient = FreshLILOLabelCheck( |
| 763 | + threshold=2, |
| 764 | + transition_to="MBG", |
| 765 | + require_sufficient=True, |
| 766 | + only_in_statuses=[TrialStatus.COMPLETED], |
| 767 | + ) |
| 768 | + insufficient = FreshLILOLabelCheck( |
| 769 | + threshold=2, |
| 770 | + transition_to="LILO", |
| 771 | + require_sufficient=False, |
| 772 | + only_in_statuses=[TrialStatus.COMPLETED], |
| 773 | + ) |
| 774 | + |
| 775 | + with self.subTest("sufficient_met_when_enough_fresh"): |
| 776 | + # 2 fresh >= threshold 2 → require_sufficient=True is met. |
| 777 | + self.assertTrue(sufficient.is_met(exp, _mock_node(trials_from_node))) |
| 778 | + |
| 779 | + with self.subTest("insufficient_not_met_when_enough_fresh"): |
| 780 | + # 2 fresh >= threshold 2 → require_sufficient=False is NOT met. |
| 781 | + self.assertFalse(insufficient.is_met(exp, _mock_node(trials_from_node))) |
| 782 | + |
| 783 | + # Make trial 0 stale → only 1 fresh trial. |
| 784 | + exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = "stale" |
| 785 | + |
| 786 | + with self.subTest("sufficient_not_met_when_stale"): |
| 787 | + # 1 fresh < threshold 2 → require_sufficient=True is NOT met. |
| 788 | + self.assertFalse(sufficient.is_met(exp, _mock_node(trials_from_node))) |
| 789 | + |
| 790 | + with self.subTest("insufficient_met_when_stale"): |
| 791 | + # 1 fresh < threshold 2 → require_sufficient=False IS met. |
| 792 | + self.assertTrue(insufficient.is_met(exp, _mock_node(trials_from_node))) |
| 793 | + |
| 794 | + def test_fresh_lilo_label_check_non_lilo_fallback(self) -> None: |
| 795 | + """Non-LILO experiment: require_sufficient=True always met, |
| 796 | + require_sufficient=False never met.""" |
| 797 | + exp = get_branin_experiment() |
| 798 | + # No pairwise DerivedMetric registered — non-LILO experiment. |
| 799 | + trials_from_node: set[int] = set() |
| 800 | + |
| 801 | + sufficient = FreshLILOLabelCheck( |
| 802 | + threshold=32, |
| 803 | + transition_to="MBG", |
| 804 | + require_sufficient=True, |
| 805 | + ) |
| 806 | + insufficient = FreshLILOLabelCheck( |
| 807 | + threshold=32, |
| 808 | + transition_to="LILO", |
| 809 | + require_sufficient=False, |
| 810 | + ) |
| 811 | + |
| 812 | + with self.subTest("non_lilo_sufficient_always_met"): |
| 813 | + self.assertTrue(sufficient.is_met(exp, _mock_node(trials_from_node))) |
| 814 | + |
| 815 | + with self.subTest("non_lilo_insufficient_never_met"): |
| 816 | + self.assertFalse(insufficient.is_met(exp, _mock_node(trials_from_node))) |
0 commit comments