Skip to content

Commit 00960e9

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Remove pyre-fixme/pyre-ignore from ax/ small directories (#4981)
Summary: Pull Request resolved: #4981 Remove pyre-fixme/pyre-ignore suppression comments from 17 files across several directories: plot/, analysis/, generation_strategy/, benchmark/, metrics/, global_stopping/ (both source and test files). Key fixes: - Widen `AxPlotConfig.__new__` data param to `dict[str, Any] | Figure`, eliminating 14 pyre-fixme[6] suppressions across ax/plot/ - Use `none_throws()` for Optional unwrapping - Use `assert_is_instance()` for type narrowing - Add proper type annotations (`AnalysisCard`, `SklearnDataset`, `TParamValue`) - Use `float()` wrapping for numpy scalar arithmetic - Use `npt.NDArray` instead of bare `np.ndarray` - Use `int()` for cardinality() return values - Fix `tuple[int]` -> `tuple[int, int, int]` for RGB color tuples - Annotate `DISCRETE_COLOR_SCALE` global Reviewed By: dme65 Differential Revision: D95264987 fbshipit-source-id: 3eb128e5552913ac0bb7127f110079ccbc8fb08c
1 parent 3df1600 commit 00960e9

File tree

17 files changed

+54
-78
lines changed

17 files changed

+54
-78
lines changed

ax/analysis/healthcheck/tests/test_early_stopping_healthcheck.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pandas as pd
1212
from ax.analysis.healthcheck.early_stopping_healthcheck import EarlyStoppingAnalysis
1313
from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus
14+
from ax.core.analysis_card import AnalysisCard
1415
from ax.core.data import Data
1516
from ax.core.experiment import Experiment
1617
from ax.core.objective import MultiObjective, Objective
@@ -39,9 +40,8 @@ def setUp(self) -> None:
3940
early_stopping_strategy=self.early_stopping_strategy
4041
)
4142

42-
def _get_df_dict(self, card: object) -> dict[str, str]:
43+
def _get_df_dict(self, card: AnalysisCard) -> dict[str, str]:
4344
"""Extract Property -> Value dict from card dataframe."""
44-
# pyre-ignore[16]: card has df attribute
4545
return {row["Property"]: row["Value"] for _, row in card.df.iterrows()}
4646

4747
def _create_map_data(

ax/benchmark/benchmark_result.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,7 @@ def from_benchmark_results(
171171
"""
172172
# Extract average wall times and standard errors thereof
173173
fit_time, gen_time = (
174-
# pyre-fixme [16]: Item `float` of `typing.Union[numpy.ndarray[typing.Any,
175-
# typing.Any], float]` has no attribute `item`.
176-
[nanmean(Ts).item(), sem(Ts, ddof=1, nan_policy="propagate").item()]
174+
[float(nanmean(Ts)), float(sem(Ts, ddof=1, nan_policy="propagate"))]
177175
for Ts in zip(*((res.fit_time, res.gen_time) for res in results))
178176
)
179177

ax/generation_strategy/best_model_selector.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from ax.utils.common.func_enum import FuncEnum
2121
from pyre_extensions import none_throws
2222

23-
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
24-
ARRAYLIKE = np.ndarray | list[float] | list[np.ndarray]
23+
ARRAYLIKE = npt.NDArray | list[float] | list[npt.NDArray]
2524

2625

2726
class BestModelSelector(ABC, Base):

ax/generation_strategy/dispatch_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,7 @@ def _suggest_gp_model(
186186
all_range_parameters_are_discrete = False
187187
else:
188188
num_param_discrete_values = parameter.cardinality()
189-
# pyre-fixme[58]: `*` is not supported for operand types `int` and
190-
# `Union[float, int]`.
191-
num_possible_points *= num_param_discrete_values
189+
num_possible_points *= int(num_param_discrete_values)
192190

193191
if should_enumerate_param:
194192
num_enumerated_combinations *= none_throws(num_param_discrete_values)

ax/global_stopping/tests/test_strategies.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,9 @@ def test_multi_objective(self) -> None:
297297
gss = ImprovementGlobalStoppingStrategy(
298298
min_trials=3, window_size=3, improvement_bar=0.1
299299
)
300-
objectives = exp.optimization_config.objective.objectives # pyre-ignore
300+
objectives = assert_is_instance(
301+
none_throws(exp.optimization_config).objective, MultiObjective
302+
).objectives
301303
custom_objective_thresholds = [
302304
ObjectiveThreshold(
303305
metric=objectives[0].metric,

ax/metrics/sklearn.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ class SklearnDataset(StrEnum):
4040

4141

4242
@lru_cache(maxsize=8)
43-
# pyre-fixme[2]: Parameter must be annotated.
44-
def _get_data(dataset) -> dict[str, npt.NDArray]:
43+
def _get_data(dataset: SklearnDataset) -> dict[str, npt.NDArray]:
4544
"""Return sklearn dataset, loading and caching if necessary."""
4645
if dataset is SklearnDataset.DIGITS:
4746
return datasets.load_digits()
@@ -105,8 +104,7 @@ def __init__(
105104
)
106105
if model_type is SklearnModelType.NN:
107106
if regression:
108-
# pyre-fixme[4]: Attribute must be annotated.
109-
self._model_cls = MLPRegressor
107+
self._model_cls: Any = MLPRegressor
110108
else:
111109
self._model_cls = MLPClassifier
112110
elif model_type is SklearnModelType.RF:

ax/metrics/tests/test_chemistry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919

2020

2121
class DummyEnum(Enum):
22-
# pyre-fixme[35]: Target cannot be annotated.
23-
DUMMY: str = "dummy"
22+
DUMMY = "dummy"
2423

2524

2625
class ChemistryMetricTest(TestCase):

ax/metrics/tests/test_noisy_function.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88

99
import math
1010

11+
from ax.core.types import TParamValue
1112
from ax.metrics.noisy_function import GenericNoisyFunctionMetric
1213
from ax.utils.common.testutils import TestCase
1314
from ax.utils.testing.core_stubs import get_trial
15+
from pyre_extensions import none_throws
1416

1517

1618
class GenericNoisyFunctionMetricTest(TestCase):
1719
def test_GenericNoisyFunctionMetric(self) -> None:
18-
# pyre-fixme[3]: Return type must be annotated.
19-
# pyre-fixme[2]: Parameter must be annotated.
20-
def f(params):
21-
return params["x"] + 1.0
20+
def f(params: dict[str, TParamValue]) -> float:
21+
return float(params["x"]) + 1.0
2222

2323
# noiseless
2424
metric = GenericNoisyFunctionMetric(
@@ -29,8 +29,10 @@ def f(params):
2929
df = metric.fetch_trial_data(trial).unwrap().df
3030
self.assertEqual(df["arm_name"].tolist(), ["0_0"])
3131
self.assertEqual(df["metric_name"].tolist(), ["test_metric"])
32-
# pyre-fixme[16]: Optional type has no attribute `parameters`.
33-
self.assertEqual(df["mean"].tolist(), [trial.arm.parameters["x"] + 1.0])
32+
self.assertEqual(
33+
df["mean"].tolist(),
34+
[float(none_throws(trial.arm).parameters["x"]) + 1.0],
35+
)
3436
self.assertEqual(df["sem"].tolist(), [0.0])
3537

3638
# noisy
@@ -43,12 +45,16 @@ def f(params):
4345
df = metric.fetch_trial_data(trial).unwrap().df
4446
self.assertEqual(df["arm_name"].tolist(), ["0_0"])
4547
self.assertEqual(df["metric_name"].tolist(), ["test_metric"])
46-
self.assertNotEqual(df["mean"].tolist(), [trial.arm.parameters["x"] + 1.0])
48+
self.assertNotEqual(
49+
df["mean"].tolist(),
50+
[float(none_throws(trial.arm).parameters["x"]) + 1.0],
51+
)
4752
self.assertEqual(df["sem"].tolist(), [1.0])
4853
df = metric.fetch_trial_data(trial, noisy=False).unwrap().df
4954
self.assertEqual(df["arm_name"].tolist(), ["0_0"])
5055
self.assertEqual(df["metric_name"].tolist(), ["test_metric"])
51-
self.assertEqual(df["mean"].tolist(), [trial.arm.parameters["x"] + 1.0])
56+
arm = none_throws(trial.arm)
57+
self.assertEqual(df["mean"].tolist(), [float(arm.parameters["x"]) + 1.0])
5258
self.assertEqual(df["sem"].tolist(), [0.0])
5359

5460
# unknown noise level
@@ -61,6 +67,7 @@ def f(params):
6167
df = metric.fetch_trial_data(trial).unwrap().df
6268
self.assertEqual(df["arm_name"].tolist(), ["0_0"])
6369
self.assertEqual(df["metric_name"].tolist(), ["test_metric"])
64-
self.assertEqual(df["mean"].tolist(), [trial.arm.parameters["x"] + 1.0])
65-
self.assertEqual(df["mean"].tolist(), [trial.arm.parameters["x"] + 1.0])
70+
arm = none_throws(trial.arm)
71+
self.assertEqual(df["mean"].tolist(), [float(arm.parameters["x"]) + 1.0])
72+
self.assertEqual(df["mean"].tolist(), [float(arm.parameters["x"]) + 1.0])
6673
self.assertTrue(math.isnan(df["sem"].tolist()[0]))

ax/plot/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ax.core.types import TParameterization
1414
from ax.utils.common.serialization import named_tuple_to_dict
1515
from plotly import utils
16+
from plotly.graph_objs import Figure
1617

1718

1819
# Constants used for numerous plots
@@ -43,7 +44,9 @@ class _AxPlotConfigBase(NamedTuple):
4344
class AxPlotConfig(_AxPlotConfigBase):
4445
"""Config for plots"""
4546

46-
def __new__(cls, data: dict[str, Any], plot_type: enum.Enum) -> "AxPlotConfig":
47+
def __new__(
48+
cls, data: dict[str, Any] | Figure, plot_type: enum.Enum
49+
) -> "AxPlotConfig":
4750
# Convert data to json-encodable form (strips out NamedTuple and numpy
4851
# array). This is a lossy conversion.
4952
dict_data = json.loads(

ax/plot/color.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ class COLORS(enum.Enum):
2323

2424

2525
# colors to be used for plotting discrete series
26-
# pyre-fixme[5]: Global expression must be annotated.
27-
DISCRETE_COLOR_SCALE = [
26+
DISCRETE_COLOR_SCALE: list[tuple[int, int, int]] = [
2827
COLORS.STEELBLUE.value,
2928
COLORS.CORAL.value,
3029
COLORS.PINK.value,

0 commit comments

Comments
 (0)