Skip to content

Commit 3e6a0de

Browse files
mpolson64facebook-github-bot
authored andcommitted
New parallel coordinates plot (facebook#2590)
Summary: Pull Request resolved: facebook#2590 An improved version of the parallel coordinates analysis I implemented earlier this year, now refactored with AnalysisCards. Main improvements include: * Ability to infer what metric to use based on the OptimizationConfig if one is not provided * Compatibility with ChoiceParameters and FixedParameters * Truncation of long parameter and metric names where appropriate NOTE: This analysis introduces a number of helper functions in parallel_coordinates.py -- as we add more analyses these should be refactored out into analysis/plotly/utils.py or analysis/utils.py as appropriate. Reviewed By: Cesar-Cardoso Differential Revision: D59927703 fbshipit-source-id: 57e0c01438c43f8464912e83d2b68b2521321108
1 parent c0fa168 commit 3e6a0de

File tree

3 files changed

+269
-1
lines changed

3 files changed

+269
-1
lines changed

ax/analysis/plotly/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
# pyre-strict
77

8+
from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
9+
ParallelCoordinatesPlot,
10+
)
811
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
912

10-
__all__ = ["PlotlyAnalysis", "PlotlyAnalysisCard"]
13+
__all__ = ["PlotlyAnalysis", "PlotlyAnalysisCard", "ParallelCoordinatesPlot"]
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, Dict, Optional
7+
8+
import numpy as np
9+
import pandas as pd
10+
from ax.analysis.analysis import AnalysisCardLevel
11+
12+
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
13+
from ax.core.experiment import Experiment
14+
from ax.core.objective import MultiObjective, ScalarizedObjective
15+
from ax.exceptions.core import UnsupportedError, UserInputError
16+
from ax.modelbridge.generation_strategy import GenerationStrategy
17+
from plotly import graph_objects as go
18+
19+
20+
class ParallelCoordinatesPlot(PlotlyAnalysis):
    """
    Plotly Parcoords plot for a single metric, drawn with one line per arm and one
    dimension per parameter, followed by a final dimension for the metric itself.
    This plot is useful for understanding how thoroughly the search space is
    explored as well as for identifying whether there is any clustering of either
    good or bad parameterizations.

    The DataFrame computed will contain one row per arm and the following columns:
        - arm_name: The name of the arm
        - METRIC_NAME: The observed mean of the metric specified
        - **PARAMETER_NAME: The value of said parameter for the arm, for each parameter
    """

    def __init__(self, metric_name: Optional[str] = None) -> None:
        """
        Args:
            metric_name: The name of the metric to plot. If not specified the
                objective will be used. Note that the metric cannot be inferred
                for multi-objective or scalarized-objective experiments.
        """

        self.metric_name = metric_name

    def compute(
        self,
        experiment: Optional[Experiment] = None,
        generation_strategy: Optional[GenerationStrategy] = None,
    ) -> PlotlyAnalysisCard:
        """
        Build the parallel coordinates card for the given experiment.

        Args:
            experiment: The Experiment whose arms and data are plotted. Required.
            generation_strategy: Accepted as part of the signature but not read by
                this analysis.

        Returns:
            A PlotlyAnalysisCard carrying the per-arm DataFrame and the Parcoords
            figure.

        Raises:
            UserInputError: If no Experiment is provided.
        """
        if experiment is None:
            raise UserInputError("ParallelCoordinatesPlot requires an Experiment")

        # Fall back to inferring the metric from the experiment when none was
        # supplied at construction time.
        metric_name = self.metric_name
        if not metric_name:
            metric_name = _select_metric(experiment=experiment)

        arm_df = _prepare_data(experiment=experiment, metric=metric_name)
        figure = _prepare_plot(df=arm_df, metric_name=metric_name)

        return PlotlyAnalysisCard(
            name=type(self).__name__,
            title=f"Parallel Coordinates for {metric_name}",
            subtitle="View arm parameterizations with their respective metric values",
            level=AnalysisCardLevel.HIGH,
            df=arm_df,
            blob=figure,
        )
65+
66+
def _prepare_data(experiment: Experiment, metric: str) -> pd.DataFrame:
67+
data_df = experiment.lookup_data().df
68+
filtered_df = data_df.loc[data_df["metric_name"] == metric]
69+
70+
if filtered_df.empty:
71+
raise ValueError(f"No data found for metric {metric}")
72+
73+
records = [
74+
{
75+
"arm_name": arm.name,
76+
**arm.parameters,
77+
metric: _find_mean_by_arm_name(df=filtered_df, arm_name=arm.name),
78+
}
79+
for trial in experiment.trials.values()
80+
for arm in trial.arms
81+
]
82+
83+
return pd.DataFrame.from_records(records)
84+
85+
86+
def _prepare_plot(df: pd.DataFrame, metric_name: str) -> go.Figure:
    """
    Build the Parcoords figure from the per-arm DataFrame: one dimension per
    parameter column followed by a final dimension for the metric, with lines
    colored by the metric value.
    """
    # ParCoords requires that the dimensions are specified on continuous scales,
    # so ChoiceParameters and FixedParameters must be preprocessed to allow for
    # appropriate plotting.
    non_parameter_columns = ("arm_name", metric_name)
    dimensions = [
        _get_parameter_dimension(series=df[column])
        for column in df.columns
        if column not in non_parameter_columns
    ]

    metric_values = df[metric_name]
    dimensions.append(
        {
            "label": _truncate_label(label=metric_name),
            "values": metric_values.tolist(),
        }
    )

    trace = go.Parcoords(
        line={"color": metric_values, "showscale": True},
        dimensions=dimensions,
        # Rotate the labels to allow them to be longer without overlapping
        labelangle=-45,
    )
    return go.Figure(trace)
114+
115+
116+
def _select_metric(experiment: Experiment) -> str:
    """
    Infer the name of the metric to plot from the experiment's objective.

    Raises:
        ValueError: If the experiment has no OptimizationConfig.
        UnsupportedError: If the objective is a MultiObjective or a
            ScalarizedObjective, in which case the caller must name a metric.
    """
    optimization_config = experiment.optimization_config
    if optimization_config is None:
        raise ValueError(
            "Cannot infer metric to plot from Experiment without OptimizationConfig"
        )

    objective = optimization_config.objective
    if isinstance(objective, MultiObjective):
        raise UnsupportedError(
            "Cannot infer metric to plot from MultiObjective, please "
            "specify a metric"
        )
    elif isinstance(objective, ScalarizedObjective):
        raise UnsupportedError(
            "Cannot infer metric to plot from ScalarizedObjective, please "
            "specify a metric"
        )

    return objective.metric.name
133+
134+
135+
def _find_mean_by_arm_name(
    df: pd.DataFrame,
    arm_name: str,
) -> float:
    """
    Look up the observed mean for ``arm_name`` in a DataFrame that has
    ``arm_name`` and ``mean`` columns.

    Args:
        df: DataFrame with (at least) ``arm_name`` and ``mean`` columns.
        arm_name: Name of the arm to look up.

    Returns:
        NaN if the arm has no rows (as can happen if the arm is still running or
        has failed), the single value if there is exactly one row, and the
        average of the rows' means if there are several.
    """
    matches = df.loc[df["arm_name"] == arm_name, "mean"]

    if matches.empty:
        return np.nan

    if len(matches) == 1:
        return matches.item()

    # Several rows can share an arm name (for example when the same arm appears
    # in more than one trial). ``Series.item()`` raises ValueError for anything
    # but a single element, so collapse the observations to their average.
    return float(matches.mean())
148+
149+
150+
def _get_parameter_dimension(series: pd.Series) -> Dict[str, Any]:
    """
    Translate one parameter column into a Parcoords dimension dict with the keys
    ``tickvals``, ``ticktext``, ``label`` and ``values``.
    """
    dimension_label = _truncate_label(label=str(series.name))

    # For numeric parameters allow Plotly to infer tick attributes. Note: booleans
    # are considered numeric, but in this case we want to treat them as
    # categorical.
    is_numeric = pd.api.types.is_numeric_dtype(series)
    if is_numeric and not pd.api.types.is_bool_dtype(series):
        return {
            "tickvals": None,
            "ticktext": None,
            "label": dimension_label,
            "values": series.tolist(),
        }

    # For non-numeric parameters, sort the distinct values, map them onto an
    # integer scale, and provide the corresponding tick attributes.
    categories = sorted(series.unique())
    codes = {category: index for index, category in enumerate(categories)}

    tick_values = [_truncate_label(label=str(code)) for code in codes.values()]
    tick_text = [_truncate_label(label=str(category)) for category in codes]

    return {
        "tickvals": tick_values,
        "ticktext": tick_text,
        "label": dimension_label,
        "values": series.map(codes).tolist(),
    }
171+
172+
173+
def _truncate_label(label: str, n: int = 18) -> str:
    """
    Shorten ``label`` to its first ``n`` characters followed by ``"..."``.

    The label is returned unchanged whenever truncating would not make it
    strictly shorter: the ellipsis itself costs three characters, so cutting a
    label that is only one to three characters over ``n`` would drop
    information without saving any space.

    Args:
        label: The text to shorten.
        n: Number of leading characters kept when the label is truncated.

    Returns:
        Either ``label`` itself or ``label[:n] + "..."``.
    """
    ellipsis = "..."
    if len(label) <= n + len(ellipsis):
        return label
    return label[:n] + ellipsis
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import pandas as pd
7+
from ax.analysis.analysis import AnalysisCardLevel
8+
from ax.analysis.plotly.parallel_coordinates.parallel_coordinates import (
9+
_get_parameter_dimension,
10+
_select_metric,
11+
ParallelCoordinatesPlot,
12+
)
13+
from ax.exceptions.core import UnsupportedError, UserInputError
14+
from ax.utils.common.testutils import TestCase
15+
from ax.utils.testing.core_stubs import (
16+
get_branin_experiment,
17+
get_experiment_with_multi_objective,
18+
get_experiment_with_scalarized_objective_and_outcome_constraint,
19+
)
20+
21+
22+
class TestParallelCoordinatesPlot(TestCase):
    """Tests for ParallelCoordinatesPlot and the helpers of its module."""

    def test_compute(self) -> None:
        # An experiment is mandatory; calling compute without one must fail.
        branin_experiment = get_branin_experiment(with_completed_trial=True)
        plot_analysis = ParallelCoordinatesPlot("branin")

        with self.assertRaisesRegex(UserInputError, "requires an Experiment"):
            plot_analysis.compute()

        result = plot_analysis.compute(experiment=branin_experiment)
        self.assertEqual(result.name, "ParallelCoordinatesPlot")
        self.assertEqual(result.title, "Parallel Coordinates for branin")
        self.assertEqual(
            result.subtitle,
            "View arm parameterizations with their respective metric values",
        )
        self.assertEqual(result.level, AnalysisCardLevel.HIGH)
        self.assertEqual(
            set(result.df.columns), {"arm_name", "branin", "x1", "x2"}
        )
        self.assertIsNotNone(result.blob)
        self.assertEqual(result.blob_annotation, "plotly")

        # Without an explicit metric the analysis must infer one and still run.
        inferred_metric_analysis = ParallelCoordinatesPlot()
        _ = inferred_metric_analysis.compute(experiment=branin_experiment)

    def test_select_metric(self) -> None:
        single_objective = get_branin_experiment()
        without_optimization_config = get_branin_experiment(
            has_optimization_config=False
        )
        multi_objective = get_experiment_with_multi_objective()
        scalarized_objective = (
            get_experiment_with_scalarized_objective_and_outcome_constraint()
        )

        self.assertEqual(_select_metric(experiment=single_objective), "branin")

        with self.assertRaisesRegex(ValueError, "OptimizationConfig"):
            _select_metric(experiment=without_optimization_config)

        with self.assertRaisesRegex(UnsupportedError, "MultiObjective"):
            _select_metric(experiment=multi_objective)

        with self.assertRaisesRegex(UnsupportedError, "ScalarizedObjective"):
            _select_metric(experiment=scalarized_objective)

    def test_get_parameter_dimension(self) -> None:
        # Numeric columns pass through with tick attributes left to Plotly.
        numeric = pd.Series([0, 1, 2, 3], name="range")
        expected_numeric = {
            "tickvals": None,
            "ticktext": None,
            "label": "range",
            "values": numeric.tolist(),
        }
        self.assertEqual(_get_parameter_dimension(series=numeric), expected_numeric)

        # Categorical columns are sorted and mapped onto an integer scale.
        categorical = pd.Series(["foo", "bar", "baz"], name="choice")
        expected_categorical = {
            "tickvals": ["0", "1", "2"],
            "ticktext": ["bar", "baz", "foo"],
            "label": "choice",
            "values": [2, 0, 1],
        }
        self.assertEqual(
            _get_parameter_dimension(series=categorical), expected_categorical
        )

0 commit comments

Comments
 (0)