Skip to content

Commit e2056d2

Browse files
ItsMrLin authored and meta-codesync[bot] committed
Add FreshLILOLabelCheck transition criterion (#4994)
Summary:
Pull Request resolved: #4994

Add a hash-aware transition criterion for LILO GS loops. `FreshLILOLabelCheck` counts only trials whose LILO input hash matches the current experiment state, ensuring transitions are gated on *fresh* labels (produced under current data + LLM messages).

The `require_sufficient` flag controls the transition direction:
- `require_sufficient=True` (LILO_LABELING -> MBG): is_met when fresh count >= threshold. "Enough fresh labels -- proceed to BO generation."
- `require_sufficient=False` (MBG -> LILO_LABELING): is_met when fresh count < threshold. "Labels are stale -- relabel before generating."

Non-LILO experiments (no pairwise DerivedMetric) short-circuit: `require_sufficient=True` -> always met, `require_sufficient=False` -> never met. This prevents false relabeling triggers on non-LILO experiments.

Reviewed By: saitcakmak

Differential Revision: D95284285

fbshipit-source-id: 457fdfa99d8a5f9f99345d3d9dc6a46d1debf8d1
1 parent 675464c commit e2056d2

File tree

3 files changed

+334
-0
lines changed

3 files changed

+334
-0
lines changed

ax/generation_strategy/tests/test_transition_criterion.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77

88

99
from logging import Logger
10+
from unittest.mock import MagicMock
1011

1112
import pandas as pd
1213
from ax.adapter.registry import Generators
14+
from ax.core.arm import Arm
1315
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
1416
from ax.core.data import Data
17+
from ax.core.derived_metric import DerivedMetric
18+
from ax.core.experiment import Experiment
1519
from ax.core.trial_status import TrialStatus
1620
from ax.exceptions.core import DataRequiredError, UserInputError
1721
from ax.exceptions.generation_strategy import MaxParallelismReachedException
@@ -24,11 +28,14 @@
2428
from ax.generation_strategy.transition_criterion import (
2529
AutoTransitionAfterGen,
2630
AuxiliaryExperimentCheck,
31+
FreshLILOLabelCheck,
2732
IsSingleObjective,
2833
MaxGenerationParallelism,
2934
MaxTrialsAwaitingData,
3035
MinTrials,
3136
)
37+
from ax.utils.common.constants import Keys
38+
from ax.utils.common.hash_utils import compute_lilo_input_hash
3239
from ax.utils.common.logger import get_logger
3340
from ax.utils.common.testutils import TestCase
3441
from ax.utils.testing.core_stubs import (
@@ -41,6 +48,13 @@
4148
logger: Logger = get_logger(__name__)
4249

4350

51+
def _mock_node(trials_from_node: set[int]) -> MagicMock:
52+
"""Create a mock GenerationNode with a specified trials_from_node set."""
53+
node = MagicMock()
54+
node.trials_from_node = trials_from_node
55+
return node
56+
57+
4458
class TestTransitionCriterion(TestCase):
4559
def setUp(self) -> None:
4660
super().setUp()
@@ -614,3 +628,189 @@ def test_max_generation_parallelism_block_error(self) -> None:
614628
experiment=self.experiment,
615629
trials_from_node={0, 1, 2},
616630
)
631+
632+
def test_fresh_lilo_label_check(self) -> None:
    """Verify FreshLILOLabelCheck counts only hash-fresh trials.

    Walks one experiment through several stamping states (no hashes,
    matching hashes, stale hashes, and a data change) and checks the
    count returned by ``num_contributing_to_threshold`` at each step.
    """
    exp = get_branin_experiment()

    # Register a DerivedMetric with pairwise name.
    pairwise_metric = DerivedMetric(
        name=Keys.PAIRWISE_PREFERENCE_QUERY.value,
        input_metric_names=["branin"],
    )
    exp.add_tracking_metric(pairwise_metric)

    criterion = FreshLILOLabelCheck(
        threshold=2,
        transition_to="next_node",
        only_in_statuses=[TrialStatus.COMPLETED],
    )

    # Helper to create and complete a trial with data. ``idx`` is expected
    # to equal the index of the trial that ``new_batch_trial`` creates.
    def _add_trial(idx: int, exp: Experiment = exp) -> None:
        trial = exp.new_batch_trial()
        trial.add_arm(
            Arm(name=f"{idx}_0", parameters={"x1": float(idx), "x2": 0.0})
        )
        trial.mark_running(no_runner_required=True)
        trial.mark_completed()
        exp.attach_data(
            Data(
                df=pd.DataFrame(
                    [
                        {
                            "trial_index": idx,
                            "arm_name": f"{idx}_0",
                            "metric_name": "branin",
                            "metric_signature": "branin",
                            "mean": float(idx),
                            "sem": 0.1,
                        }
                    ]
                )
            )
        )

    # Create 3 completed trials. None of them is stamped yet; trials 0 and
    # 1 receive the current hash further below.
    for i in range(3):
        _add_trial(i)

    current_hash = compute_lilo_input_hash(exp, ["branin"])
    trials_from_node = {0, 1, 2}

    with self.subTest("no_hashes_none_count"):
        # No hash stamps → no trials counted (only LILO trials with
        # a matching hash contribute).
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        self.assertEqual(count, 0)

    # Stamp trials 0 and 1 with the current hash.
    exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash
    exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash

    with self.subTest("fresh_hashes_count"):
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        # Trials 0, 1 (fresh hash). Trial 2 (no hash → excluded).
        self.assertEqual(count, 2)

    # Make trial 1 stale.
    exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = "stale_hash"

    with self.subTest("stale_hash_excluded"):
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        # Trial 0 (fresh). Trial 1 (stale) and trial 2 (no hash) excluded.
        self.assertEqual(count, 1)
        self.assertFalse(criterion.is_met(exp, _mock_node(trials_from_node)))

    # Make trial 0 stale too.
    exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = "another_stale"

    with self.subTest("not_enough_fresh"):
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        # All stamped trials are stale, trial 2 has no hash → 0.
        self.assertEqual(count, 0)
        self.assertFalse(criterion.is_met(exp, _mock_node(trials_from_node)))

    with self.subTest("data_change_invalidates"):
        # Re-stamp trials 0 and 1 with the hash that is still current.
        # Both were overwritten with stale placeholders above, so without
        # this step the final assertion would pass even if adding data
        # had no effect on freshness.
        exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = current_hash
        exp.trials[1]._properties[Keys.LILO_INPUT_HASH] = current_hash
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        self.assertEqual(count, 2)

        # Add new data — changes the current hash, making ALL stamped
        # trials stale.
        _add_trial(3)
        trials_from_node.add(3)
        count = criterion.num_contributing_to_threshold(exp, trials_from_node)
        # Trials 0, 1 stale. Trials 2, 3 have no hash → excluded.
        self.assertEqual(count, 0)
722+
723+
def test_fresh_lilo_label_check_require_sufficient(self) -> None:
    """Check that ``require_sufficient`` flips the direction of ``is_met``."""
    exp = get_branin_experiment()
    exp.add_tracking_metric(
        DerivedMetric(
            name=Keys.PAIRWISE_PREFERENCE_QUERY.value,
            input_metric_names=["branin"],
        )
    )

    # Two completed single-arm batch trials, each with one "branin" row.
    for trial_index in range(2):
        arm_name = f"{trial_index}_0"
        batch = exp.new_batch_trial()
        batch.add_arm(
            Arm(name=arm_name, parameters={"x1": float(trial_index), "x2": 0.0})
        )
        batch.mark_running(no_runner_required=True)
        batch.mark_completed()
        row = {
            "trial_index": trial_index,
            "arm_name": arm_name,
            "metric_name": "branin",
            "metric_signature": "branin",
            "mean": float(trial_index),
            "sem": 0.1,
        }
        exp.attach_data(Data(df=pd.DataFrame([row])))

    # Both trials start out fresh: stamp them with the current hash.
    fresh_hash = compute_lilo_input_hash(exp, ["branin"])
    for trial_index in (0, 1):
        exp.trials[trial_index]._properties[Keys.LILO_INPUT_HASH] = fresh_hash
    node_trials = {0, 1}

    sufficient = FreshLILOLabelCheck(
        threshold=2,
        transition_to="MBG",
        require_sufficient=True,
        only_in_statuses=[TrialStatus.COMPLETED],
    )
    insufficient = FreshLILOLabelCheck(
        threshold=2,
        transition_to="LILO",
        require_sufficient=False,
        only_in_statuses=[TrialStatus.COMPLETED],
    )

    def _met(criterion: FreshLILOLabelCheck) -> bool:
        # Each evaluation gets its own mock node, as in a real call site.
        return criterion.is_met(exp, _mock_node(node_trials))

    with self.subTest("sufficient_met_when_enough_fresh"):
        # 2 fresh >= threshold 2, so the "sufficient" direction fires.
        self.assertTrue(_met(sufficient))

    with self.subTest("insufficient_not_met_when_enough_fresh"):
        # 2 fresh >= threshold 2, so the "insufficient" direction does not.
        self.assertFalse(_met(insufficient))

    # Invalidate trial 0, leaving a single fresh trial.
    exp.trials[0]._properties[Keys.LILO_INPUT_HASH] = "stale"

    with self.subTest("sufficient_not_met_when_stale"):
        # 1 fresh < threshold 2, so the "sufficient" direction no longer fires.
        self.assertFalse(_met(sufficient))

    with self.subTest("insufficient_met_when_stale"):
        # 1 fresh < threshold 2, so the "insufficient" direction now fires.
        self.assertTrue(_met(insufficient))
793+
794+
def test_fresh_lilo_label_check_non_lilo_fallback(self) -> None:
    """Without a pairwise DerivedMetric, ``is_met`` mirrors the flag.

    ``require_sufficient=True`` is always met and ``require_sufficient=False``
    is never met, regardless of how many trials exist.
    """
    # Plain Branin experiment: no pairwise DerivedMetric is registered, and
    # the node owns no trials at all.
    exp = get_branin_experiment()
    no_trials: set[int] = set()

    sufficient = FreshLILOLabelCheck(
        threshold=32,
        transition_to="MBG",
        require_sufficient=True,
    )
    insufficient = FreshLILOLabelCheck(
        threshold=32,
        transition_to="LILO",
        require_sufficient=False,
    )

    with self.subTest("non_lilo_sufficient_always_met"):
        met = sufficient.is_met(exp, _mock_node(no_trials))
        self.assertTrue(met)

    with self.subTest("non_lilo_insufficient_never_met"):
        met = insufficient.is_met(exp, _mock_node(no_trials))
        self.assertFalse(met)

ax/generation_strategy/transition_criterion.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from ax.core.utils import get_trial_indices_with_required_metrics
1818
from ax.exceptions.core import DataRequiredError, UserInputError
1919
from ax.exceptions.generation_strategy import MaxParallelismReachedException
20+
from ax.utils.common.constants import Keys
21+
from ax.utils.common.hash_utils import get_current_lilo_hash
2022

2123
if TYPE_CHECKING:
2224
from ax.generation_strategy.generation_node import GenerationNode
@@ -644,6 +646,135 @@ def __init__(
644646
)
645647

646648

649+
class FreshLILOLabelCheck(TrialBasedCriterion):
    """Transition criterion based on the freshness of LILO preference labels.

    LILO (Language-in-the-Loop) trials are stamped with a hash of the
    experiment state (metric data + LLM messages) at labeling time.
    When the experiment state changes (new data arrives, or the user updates
    LLM messages), old labels become stale. This criterion gates transitions
    based on how many *fresh* labels exist.

    The ``require_sufficient`` flag controls the direction:

    - **``require_sufficient=True``** (LILO_LABELING -> MBG): ``is_met``
      when the number of fresh labels >= ``threshold``. "We have enough
      fresh labels -- proceed to BO generation."
    - **``require_sufficient=False``** (MBG -> LILO_LABELING): ``is_met``
      when the number of fresh labels < ``threshold``. "Labels are stale
      -- relabel before generating."

    **Non-LILO fallback** (no pairwise ``DerivedMetric`` on the experiment):
    ``require_sufficient=True`` -> always met (proceed normally).
    ``require_sufficient=False`` -> never met (never trigger relabeling).
    The fallback short-circuits *before* the count comparison so that a
    non-LILO experiment with fewer than ``threshold`` trials does not
    falsely trigger relabeling.

    Args:
        threshold: Number of fresh trials for the sufficiency check.
        transition_to: The GenerationNode to transition to when met.
        require_sufficient: If ``True``, ``is_met`` when fresh count >=
            threshold. If ``False``, ``is_met`` when fresh count <
            threshold. Defaults to ``True``.
        only_in_statuses: Only count trials with these statuses.
        not_in_statuses: Exclude trials with these statuses.
        use_all_trials_in_exp: Count all experiment trials, not just
            those from the current node.
        continue_trial_generation: Continue generating arms for the
            same trial after transition.
        count_only_trials_with_data: Only count trials that have data.
    """

    def __init__(
        self,
        threshold: int,
        transition_to: str,
        require_sufficient: bool = True,
        only_in_statuses: list[TrialStatus] | None = None,
        not_in_statuses: list[TrialStatus] | None = None,
        use_all_trials_in_exp: bool | None = False,
        continue_trial_generation: bool | None = False,
        count_only_trials_with_data: bool = False,
    ) -> None:
        self.require_sufficient = require_sufficient
        super().__init__(
            threshold=threshold,
            transition_to=transition_to,
            only_in_statuses=only_in_statuses,
            not_in_statuses=not_in_statuses,
            use_all_trials_in_exp=use_all_trials_in_exp,
            continue_trial_generation=continue_trial_generation,
            count_only_trials_with_data=count_only_trials_with_data,
        )

    def num_contributing_to_threshold(
        self,
        experiment: Experiment,
        trials_from_node: set[int],
    ) -> int:
        """Count trials toward threshold, excluding those with stale hashes.

        First applies the standard status-based filtering from the base class
        (plus the optional data and node filters), then keeps only trials
        whose LILO input hash matches the current experiment state.

        If the experiment has no current LILO hash (``get_current_lilo_hash``
        returns ``None``), no freshness filter is applied and the plain
        filtered trial count is returned.

        Args:
            experiment: The experiment whose trials are inspected.
            trials_from_node: Indices of trials generated by the current
                node; ignored when ``use_all_trials_in_exp`` is truthy.

        Returns:
            The number of trials contributing to the threshold.
        """
        return self._count_contributing(
            experiment=experiment,
            trials_from_node=trials_from_node,
            current_hash=get_current_lilo_hash(experiment),
        )

    def is_met(
        self,
        experiment: Experiment,
        curr_node: GenerationNode,
    ) -> bool:
        """Check whether the freshness condition is satisfied.

        For non-LILO experiments (no pairwise ``DerivedMetric``), this
        short-circuits: ``require_sufficient=True`` → always met,
        ``require_sufficient=False`` → never met.

        Args:
            experiment: The experiment whose trials are inspected.
            curr_node: The node being evaluated; only its
                ``trials_from_node`` attribute is read.

        Returns:
            ``True`` if the transition should fire, ``False`` otherwise.
        """
        # Compute the hash once and reuse it for both the non-LILO
        # short-circuit and the freshness count below.
        current_hash = get_current_lilo_hash(experiment)
        if current_hash is None:
            return self.require_sufficient

        count = self._count_contributing(
            experiment=experiment,
            trials_from_node=curr_node.trials_from_node,
            current_hash=current_hash,
        )
        if self.require_sufficient:
            return count >= self.threshold
        return count < self.threshold

    def _count_contributing(
        self,
        experiment: Experiment,
        trials_from_node: set[int],
        current_hash: str | None,
    ) -> int:
        """Count candidate trials, filtered by freshness against a given hash.

        Shared by ``num_contributing_to_threshold`` and ``is_met`` so the
        current hash is computed by the caller exactly once.

        Args:
            experiment: The experiment whose trials are inspected.
            trials_from_node: Indices of trials generated by the current node.
            current_hash: The experiment's current LILO input hash, or
                ``None`` when the experiment has none.

        Returns:
            The plain candidate count when ``current_hash`` is ``None``,
            otherwise the number of candidates stamped with ``current_hash``.
        """
        # Candidate trial indices after the base-class status filtering.
        candidates = self.all_trials_to_check(experiment)
        if self.count_only_trials_with_data:
            data_trial_indices = get_trial_indices_with_required_metrics(
                experiment=experiment,
                df=experiment.lookup_data().df,
                require_data_for_all_metrics=False,
            )
            candidates = candidates.intersection(data_trial_indices)

        if not bool(self.use_all_trials_in_exp):
            candidates = trials_from_node.intersection(candidates)

        if current_hash is None:
            # No pairwise DerivedMetric found — fall back to plain count.
            return len(candidates)

        # Only trials stamped with a LILO_INPUT_HASH equal to the current
        # hash count. Trials without a stamp (regular Sobol/MBG trials)
        # yield ``None`` here, which never equals the non-None current hash,
        # so they don't inflate the fresh-label count.
        trials = experiment.trials
        return sum(
            1
            for idx in candidates
            if trials[idx]._properties.get(Keys.LILO_INPUT_HASH) == current_hash
        )
776+
777+
647778
class AuxiliaryExperimentCheck(TransitionCriterion):
648779
"""A class to transition from one GenerationNode to another by checking if certain
649780
types of Auxiliary Experiment purposes exists.

0 commit comments

Comments
 (0)