Skip to content

Commit 3df1600

Browse files
eonofreymeta-codesync[bot]
authored andcommitted
TransferLearningAnalysis (#4918)
Summary: Pull Request resolved: #4918 Analysis card to show transferrable experiments with a default of 25% parameter overlap. Reviewed By: mpolson64 Differential Revision: D92926519 fbshipit-source-id: dec8f1ef24ddec87848f965e0740313c5200c7c5
1 parent 3ceab73 commit 3df1600

File tree

6 files changed

+358
-21
lines changed

6 files changed

+358
-21
lines changed

ax/analysis/healthcheck/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ax.analysis.healthcheck.regression_analysis import RegressionAnalysis
2525
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
2626
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
27+
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
2728

2829
__all__ = [
2930
"create_healthcheck_analysis_card",
@@ -39,4 +40,5 @@
3940
"ComplexityRatingAnalysis",
4041
"PredictableMetricsAnalysis",
4142
"BaselineImprovementAnalysis",
43+
"TransferLearningAnalysis",
4244
]
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from unittest.mock import patch
9+
10+
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus
11+
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
12+
from ax.core.auxiliary import TransferLearningMetadata
13+
from ax.core.experiment import Experiment
14+
from ax.core.parameter import ParameterType, RangeParameter
15+
from ax.core.search_space import SearchSpace
16+
from ax.exceptions.core import UserInputError
17+
from ax.utils.common.testutils import TestCase
18+
19+
20+
def _make_experiment(
21+
param_names: list[str],
22+
experiment_type: str | None = None,
23+
) -> Experiment:
24+
"""Create a simple experiment with the given parameter names."""
25+
return Experiment(
26+
search_space=SearchSpace(
27+
parameters=[
28+
RangeParameter(
29+
name=name,
30+
parameter_type=ParameterType.FLOAT,
31+
lower=0.0,
32+
upper=1.0,
33+
)
34+
for name in param_names
35+
]
36+
),
37+
name="test_experiment",
38+
experiment_type=experiment_type,
39+
)
40+
41+
42+
_MOCK_TARGET = "ax.storage.sqa_store.load.identify_transferable_experiments"
43+
44+
45+
class TestTransferLearningAnalysis(TestCase):
46+
def test_no_experiment_type_returns_pass(self) -> None:
47+
"""When no experiment_type is set and no experiment_types provided,
48+
return PASS."""
49+
experiment = _make_experiment(["x1", "x2"], experiment_type=None)
50+
analysis = TransferLearningAnalysis()
51+
card = analysis.compute(experiment=experiment)
52+
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
53+
self.assertTrue(card.is_passing())
54+
self.assertIn("No experiment type set", card.subtitle)
55+
56+
@patch(_MOCK_TARGET, return_value={})
57+
def test_no_candidates_returns_pass(self, mock_identify: object) -> None:
58+
experiment = _make_experiment(["x1", "x2"], experiment_type="my_type")
59+
analysis = TransferLearningAnalysis()
60+
card = analysis.compute(experiment=experiment)
61+
self.assertEqual(card.get_status(), HealthcheckStatus.PASS)
62+
self.assertTrue(card.is_passing())
63+
self.assertTrue(card.df.empty)
64+
65+
@patch(_MOCK_TARGET)
66+
def test_single_candidate_returns_warning(self, mock_identify: object) -> None:
67+
experiment = _make_experiment(
68+
["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type"
69+
)
70+
mock_identify.return_value = { # pyre-ignore[16]
71+
"source_exp": TransferLearningMetadata(
72+
overlap_parameters=["x1", "x2", "x3", "x4"],
73+
),
74+
}
75+
analysis = TransferLearningAnalysis()
76+
card = analysis.compute(experiment=experiment)
77+
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
78+
self.assertFalse(card.is_passing())
79+
self.assertIn("source_exp", card.subtitle)
80+
self.assertIn("80.0%", card.subtitle)
81+
self.assertEqual(len(card.df), 1)
82+
self.assertEqual(card.df.iloc[0]["Experiment"], "source_exp")
83+
self.assertEqual(card.df.iloc[0]["Overlapping Parameters"], 4)
84+
self.assertEqual(card.df.iloc[0]["Overlap (%)"], 80.0)
85+
86+
@patch(_MOCK_TARGET)
87+
def test_multiple_candidates_preserves_order(self, mock_identify: object) -> None:
88+
"""Results should preserve the order from identify_transferable_experiments
89+
(sorted by overlap then recency)."""
90+
experiment = _make_experiment(
91+
["x1", "x2", "x3", "x4", "x5"], experiment_type="my_type"
92+
)
93+
# Mock returns already-sorted results (as identify_transferable_experiments
94+
# now handles sorting by overlap desc, then recency desc).
95+
mock_identify.return_value = { # pyre-ignore[16]
96+
"exp_high": TransferLearningMetadata(
97+
overlap_parameters=["x1", "x2", "x3", "x4"],
98+
),
99+
"exp_mid": TransferLearningMetadata(
100+
overlap_parameters=["x1", "x2", "x3"],
101+
),
102+
"exp_low": TransferLearningMetadata(
103+
overlap_parameters=["x1"],
104+
),
105+
}
106+
analysis = TransferLearningAnalysis()
107+
card = analysis.compute(experiment=experiment)
108+
self.assertEqual(card.get_status(), HealthcheckStatus.WARNING)
109+
110+
# Verify order is preserved from identify_transferable_experiments
111+
self.assertEqual(card.df.iloc[0]["Experiment"], "exp_high")
112+
self.assertEqual(card.df.iloc[0]["Overlapping Parameters"], 4)
113+
self.assertEqual(card.df.iloc[1]["Experiment"], "exp_mid")
114+
self.assertEqual(card.df.iloc[1]["Overlapping Parameters"], 3)
115+
self.assertEqual(card.df.iloc[2]["Experiment"], "exp_low")
116+
self.assertEqual(card.df.iloc[2]["Overlapping Parameters"], 1)
117+
118+
# All experiments listed in subtitle
119+
self.assertIn("exp_high", card.subtitle)
120+
self.assertIn("exp_mid", card.subtitle)
121+
self.assertIn("exp_low", card.subtitle)
122+
self.assertIn("We found **3 eligible source experiment(s)**", card.subtitle)
123+
124+
@patch(_MOCK_TARGET)
125+
def test_percentage_calculation(self, mock_identify: object) -> None:
126+
experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type")
127+
mock_identify.return_value = { # pyre-ignore[16]
128+
"exp_a": TransferLearningMetadata(
129+
overlap_parameters=["x1"],
130+
),
131+
}
132+
analysis = TransferLearningAnalysis()
133+
card = analysis.compute(experiment=experiment)
134+
self.assertEqual(card.df.iloc[0]["Overlap (%)"], 33.3)
135+
136+
@patch(_MOCK_TARGET)
137+
def test_parameters_listed_alphabetically(self, mock_identify: object) -> None:
138+
experiment = _make_experiment(
139+
["alpha", "beta", "gamma", "delta"], experiment_type="my_type"
140+
)
141+
mock_identify.return_value = { # pyre-ignore[16]
142+
"exp_a": TransferLearningMetadata(
143+
overlap_parameters=["gamma", "alpha", "delta"],
144+
),
145+
}
146+
analysis = TransferLearningAnalysis()
147+
card = analysis.compute(experiment=experiment)
148+
self.assertEqual(card.df.iloc[0]["Parameters"], "alpha, delta, gamma")
149+
150+
def test_requires_experiment(self) -> None:
151+
analysis = TransferLearningAnalysis()
152+
with self.assertRaises(UserInputError):
153+
analysis.compute(experiment=None)
154+
155+
@patch(_MOCK_TARGET, return_value={})
156+
def test_experiment_name_passed_to_identify(self, mock_identify: object) -> None:
157+
"""Verify that experiment.name is forwarded to
158+
identify_transferable_experiments so it can filter the target out."""
159+
experiment = _make_experiment(["x1", "x2", "x3"], experiment_type="my_type")
160+
analysis = TransferLearningAnalysis()
161+
analysis.compute(experiment=experiment)
162+
mock_identify.assert_called_once() # pyre-ignore[16]
163+
call_kwargs = mock_identify.call_args.kwargs # pyre-ignore[16]
164+
self.assertEqual(call_kwargs["experiment_name"], "test_experiment")
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from __future__ import annotations
9+
10+
import json
11+
from typing import final, TYPE_CHECKING
12+
13+
import markdown as md
14+
import pandas as pd
15+
from ax.adapter.base import Adapter
16+
from ax.analysis.analysis import Analysis
17+
from ax.analysis.healthcheck.healthcheck_analysis import (
18+
create_healthcheck_analysis_card,
19+
HealthcheckAnalysisCard,
20+
HealthcheckStatus,
21+
)
22+
from ax.core.experiment import Experiment
23+
from ax.exceptions.core import UserInputError
24+
from ax.generation_strategy.generation_strategy import GenerationStrategy
25+
from pyre_extensions import override
26+
27+
if TYPE_CHECKING:
28+
from ax.storage.sqa_store.sqa_config import SQAConfig
29+
30+
31+
class TransferLearningAnalysisCard(HealthcheckAnalysisCard):
32+
"""HealthcheckAnalysisCard with markdown-aware rendering for notebooks."""
33+
34+
def _body_html(self, depth: int) -> str:
35+
parts = [md.markdown(self.subtitle)]
36+
if not self.df.empty:
37+
parts.append(self.df.to_html(index=False))
38+
return f"<div class='content'>{''.join(parts)}</div>"
39+
40+
41+
@final
42+
class TransferLearningAnalysis(Analysis):
43+
def __init__(
44+
self,
45+
experiment_types: list[str] | None = None,
46+
overlap_threshold: float = 0.50,
47+
max_num_exps: int = 10,
48+
config: SQAConfig | None = None,
49+
) -> None:
50+
self.experiment_types = experiment_types
51+
self.overlap_threshold = overlap_threshold
52+
self.max_num_exps = max_num_exps
53+
self.config = config
54+
55+
@override
56+
def compute(
57+
self,
58+
experiment: Experiment | None = None,
59+
generation_strategy: GenerationStrategy | None = None,
60+
adapter: Adapter | None = None,
61+
) -> HealthcheckAnalysisCard:
62+
if experiment is None:
63+
raise UserInputError(
64+
"TransferLearningAnalysis requires a non-null experiment to compute "
65+
"overlap percentages. Please provide an experiment."
66+
)
67+
68+
# Determine experiment types to query for.
69+
experiment_types = self.experiment_types
70+
if experiment_types is None:
71+
if experiment.experiment_type is None:
72+
return create_healthcheck_analysis_card(
73+
name=self.__class__.__name__,
74+
title="Transfer Learning Eligibility",
75+
subtitle=(
76+
"No experiment type set on this experiment. "
77+
"Cannot search for transferable experiments."
78+
),
79+
df=pd.DataFrame(),
80+
status=HealthcheckStatus.PASS,
81+
)
82+
experiment_types = [experiment.experiment_type]
83+
84+
# Lazy import to avoid circular dependency (sqa_store depends on
85+
# healthcheck_analysis).
86+
from ax.storage.sqa_store.load import identify_transferable_experiments
87+
88+
transferable_experiments = identify_transferable_experiments(
89+
search_space=experiment.search_space,
90+
experiment_types=experiment_types,
91+
overlap_threshold=self.overlap_threshold,
92+
max_num_exps=self.max_num_exps,
93+
config=self.config,
94+
experiment_name=experiment.name,
95+
)
96+
97+
if not transferable_experiments:
98+
return create_healthcheck_analysis_card(
99+
name=self.__class__.__name__,
100+
title="Transfer Learning Eligibility",
101+
subtitle="No eligible source experiments found for transfer learning.",
102+
df=pd.DataFrame(),
103+
status=HealthcheckStatus.PASS,
104+
)
105+
106+
total_parameters = len(experiment.search_space.parameters)
107+
108+
rows = []
109+
for exp_name, metadata in transferable_experiments.items():
110+
overlap_count = len(metadata.overlap_parameters)
111+
overlap_pct = (
112+
(overlap_count / total_parameters * 100)
113+
if total_parameters > 0
114+
else 0.0
115+
)
116+
rows.append(
117+
{
118+
"Experiment": exp_name,
119+
"Overlapping Parameters": overlap_count,
120+
"Overlap (%)": round(overlap_pct, 1),
121+
"Parameters": ", ".join(sorted(metadata.overlap_parameters)),
122+
}
123+
)
124+
125+
df = pd.DataFrame(rows)
126+
127+
n = len(rows)
128+
exp_lines = "\n".join(
129+
f"- **{r['Experiment']}** ({r['Overlap (%)']:.1f}% parameter overlap)"
130+
for r in rows
131+
)
132+
subtitle = (
133+
"Transfer learning can improve optimization by leveraging data "
134+
"from similar past experiments. We found "
135+
f"**{n} eligible source experiment(s)** "
136+
"for transfer learning:\n\n"
137+
f"{exp_lines}\n\n"
138+
"Caution: Only use source experiments that are closely related "
139+
"to your current experiment. "
140+
"Using data from unrelated experiments can lead to negative "
141+
"transfer, which may hurt "
142+
"optimization performance. Review the overlapping parameters "
143+
"before enabling transfer learning."
144+
)
145+
146+
return TransferLearningAnalysisCard(
147+
name=self.__class__.__name__,
148+
title="Transfer Learning Eligibility",
149+
subtitle=subtitle,
150+
df=df,
151+
blob=json.dumps({"status": HealthcheckStatus.WARNING}),
152+
)

ax/analysis/overview.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from ax.analysis.healthcheck.predictable_metrics import PredictableMetricsAnalysis
2525
from ax.analysis.healthcheck.search_space_analysis import SearchSpaceAnalysis
2626
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates
27+
from ax.analysis.healthcheck.transfer_learning_analysis import TransferLearningAnalysis
2728
from ax.analysis.insights import InsightsAnalysis
2829
from ax.analysis.results import ResultsAnalysis
2930
from ax.analysis.trials import AllTrialsAnalysis
@@ -114,6 +115,7 @@ def __init__(
114115
options: OrchestratorOptions | None = None,
115116
tier_metadata: dict[str, Any] | None = None,
116117
model_fit_threshold: float | None = None,
118+
sqa_config: Any = None,
117119
) -> None:
118120
super().__init__()
119121
self.can_generate = can_generate
@@ -124,6 +126,7 @@ def __init__(
124126
self.options = options
125127
self.tier_metadata = tier_metadata
126128
self.model_fit_threshold = model_fit_threshold
129+
self.sqa_config = sqa_config
127130

128131
@override
129132
def validate_applicable_state(
@@ -229,6 +232,7 @@ def compute(
229232
if not has_batch_trials
230233
else None,
231234
BaselineImprovementAnalysis() if not has_batch_trials else None,
235+
TransferLearningAnalysis(config=self.sqa_config),
232236
*[
233237
SearchSpaceAnalysis(trial_index=trial.index)
234238
for trial in candidate_trials

0 commit comments

Comments
 (0)