Skip to content

Commit 2d6bc40

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Remove pyre-fixme/pyre-ignore from ax/utils/ source files (#4983)
Summary: Remove ~125 pyre-fixme/pyre-ignore suppression comments from 21 source files in ax/utils/ by applying proper type fixes: - Use `none_throws()` for Optional unwrapping (e.g., grad access, sim_start_time) - Use `cast()` for GPyTorch model and FixedFeatureModel narrowing - Fix covariant TypeVar issue in Result.unwrap_or with new TypeVar U - Add proper annotations to StreamHandler[Any], Callable[..., Any] - Use helper functions for dynamic GPyTorch attribute access - Use `np.asarray()` for numpy return type consistency - Fix scipy stub limitations with positional args - Use `record.__dict__` for safe attribute assignment All 100 tests pass (0 failures across 6 test targets). Reviewed By: dme65 Differential Revision: D95265024
1 parent f973cb4 commit 2d6bc40

22 files changed

+266
-340
lines changed

ax/adapter/adapter_utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
1313
from copy import deepcopy
1414
from logging import Logger
15-
from typing import Any, SupportsFloat, TYPE_CHECKING
15+
from typing import Any, cast, SupportsFloat, TYPE_CHECKING
1616

1717
import numpy as np
1818
import numpy.typing as npt
@@ -135,16 +135,18 @@ def extract_search_space_digest(
135135
if isinstance(p, ChoiceParameter):
136136
if p.is_task:
137137
task_features.append(i)
138-
target_values[i] = assert_is_instance_of_tuple(
139-
p.target_value, (int, float)
138+
target_values[i] = cast(
139+
TNumeric,
140+
assert_is_instance_of_tuple(p.target_value, (int, float)),
140141
)
141142
elif p.is_ordered:
142143
ordinal_features.append(i)
143144
else:
144145
categorical_features.append(i)
145146
# at this point we can assume that values are numeric due to transforms
146147
numeric_values: list[TNumeric] = [
147-
assert_is_instance_of_tuple(v, (int, float)) for v in p.values
148+
cast(TNumeric, assert_is_instance_of_tuple(v, (int, float)))
149+
for v in p.values
148150
]
149151
discrete_choices[i] = numeric_values
150152
bounds.append((min(numeric_values), max(numeric_values)))
@@ -164,15 +166,21 @@ def extract_search_space_digest(
164166
raise ValueError(f"Unknown parameter type {type(p)}")
165167
if p.is_fidelity:
166168
fidelity_features.append(i)
167-
target_values[i] = assert_is_instance_of_tuple(p.target_value, (int, float))
169+
target_values[i] = cast(
170+
TNumeric,
171+
assert_is_instance_of_tuple(p.target_value, (int, float)),
172+
)
168173

169174
if search_space.is_hierarchical:
170175
hierarchical_dependencies = {}
171176

172177
for p_name, p in search_space.parameters.items():
173178
if p.is_hierarchical:
174179
hierarchical_dependencies[param_names.index(p_name)] = {
175-
assert_is_instance_of_tuple(parent_value, (int, float)): [
180+
cast(
181+
TNumeric,
182+
assert_is_instance_of_tuple(parent_value, (int, float)),
183+
): [
176184
param_names.index(activated_param)
177185
for activated_param in activated_params
178186
]

ax/adapter/transforms/log.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,13 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
114114
]
115115
target_value = p.target_value
116116
if target_value is not None:
117-
target_value = math.log10(
118-
assert_is_instance_of_tuple(target_value, (float, int))
119-
)
117+
assert_is_instance_of_tuple(target_value, (float, int))
118+
target_value = math.log10(float(target_value))
120119
if dependents is not None:
121120
dependents = {
122-
math.log10(assert_is_instance_of_tuple(k, (float, int))): v
121+
math.log10(
122+
float(assert_is_instance_of_tuple(k, (float, int)))
123+
): v
123124
for k, v in dependents.items()
124125
}
125126

ax/adapter/transforms/one_hot.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from __future__ import annotations
1010

1111
from copy import deepcopy
12-
from typing import TYPE_CHECKING
12+
from typing import cast, TYPE_CHECKING
1313

1414
import numpy as np
1515
import pandas as pd
@@ -32,7 +32,6 @@
3232
from ax.core.types import TParameterization, TParamValue
3333
from ax.exceptions.core import UnsupportedError
3434
from ax.generators.types import TConfig
35-
from ax.utils.common.typeutils import assert_is_instance_of_tuple
3635
from pyre_extensions import assert_is_instance
3736

3837
if TYPE_CHECKING:
@@ -176,9 +175,7 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
176175
# If the dependent is not being transformed, keep it as is.
177176
new_deps.extend(oh_param_names_for_p.get(dep, [dep]))
178177
new_dependents[val] = new_deps
179-
assert_is_instance_of_tuple(
180-
p, (ChoiceParameter, FixedParameter)
181-
).dependents = new_dependents
178+
cast(ChoiceParameter | FixedParameter, p).dependents = new_dependents
182179

183180
return construct_new_search_space(
184181
search_space=search_space,

ax/utils/common/equality.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@ def equality_typechecker(eq_func: Callable) -> Callable:
2323
"""
2424

2525
# no type annotation for now; breaks sphinx-autodoc-typehints
26-
# pyre-fixme[3]: Return type must be annotated.
27-
# pyre-fixme[2]: Parameter must be annotated.
28-
def _type_safe_equals(self, other):
26+
def _type_safe_equals(self: Any, other: Any) -> bool:
2927
if not isinstance(other, self.__class__):
3028
return False
3129
return eq_func(self, other)
@@ -60,7 +58,6 @@ def same_elements(list1: list[Any], list2: list[Any]) -> bool:
6058
return all(matched)
6159

6260

63-
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
6461
def is_ax_equal(one_val: Any, other_val: Any) -> bool:
6562
"""Check for equality of two values, handling lists, dicts, dfs, floats,
6663
dates, and numpy arrays. This method and ``same_elements`` function

ax/utils/common/executils.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,15 @@ def retry_on_exception(
8080
then there is no wait between retries.
8181
"""
8282

83-
# pyre-fixme[3]: Return type must be annotated.
84-
# pyre-fixme[2]: Parameter must be annotated.
85-
def func_wrapper(func):
83+
def func_wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
8684
# Depending on whether `func` is async or not, we use a slightly different
8785
# wrapper; if wrapping an async function, decorator will await it.
8886
# `async_actual_wrapper` and `actual_wrapper` are almost exactly the same,
8987
# except that the former is async and awaits the wrapped function.
9088
if asyncio.iscoroutinefunction(func):
9189

9290
@functools.wraps(func)
93-
# pyre-fixme[53]: Captured variable `func` is not annotated.
94-
# pyre-fixme[3]: Return type must be annotated.
95-
# pyre-fixme[2]: Parameter must be annotated.
96-
async def async_actual_wrapper(*args, **kwargs):
91+
async def async_actual_wrapper(*args: Any, **kwargs: Any) -> Any:
9792
(
9893
retry_exceptions,
9994
no_retry_exceptions,
@@ -110,8 +105,6 @@ async def async_actual_wrapper(*args, **kwargs):
110105
no_retry_exceptions=no_retry_exceptions,
111106
retry_exceptions=retry_exceptions,
112107
suppress_errors=suppress_errors,
113-
# pyre-fixme[6]: For 4th param expected `Optional[str]` but
114-
# got `Optional[List[str]]`.
115108
check_message_contains=check_message_contains,
116109
last_retry=i >= retries - 1,
117110
logger=logger,
@@ -130,10 +123,7 @@ async def async_actual_wrapper(*args, **kwargs):
130123
return async_actual_wrapper
131124

132125
@functools.wraps(func)
133-
# pyre-fixme[53]: Captured variable `func` is not annotated.
134-
# pyre-fixme[3]: Return type must be annotated.
135-
# pyre-fixme[2]: Parameter must be annotated.
136-
def actual_wrapper(*args, **kwargs):
126+
def actual_wrapper(*args: Any, **kwargs: Any) -> Any:
137127
(
138128
retry_exceptions,
139129
no_retry_exceptions,
@@ -150,8 +140,6 @@ def actual_wrapper(*args, **kwargs):
150140
no_retry_exceptions=no_retry_exceptions,
151141
retry_exceptions=retry_exceptions,
152142
suppress_errors=suppress_errors,
153-
# pyre-fixme[6]: For 4th param expected `Optional[str]` but got
154-
# `Optional[List[str]]`.
155143
check_message_contains=check_message_contains,
156144
last_retry=i >= retries - 1,
157145
logger=logger,
@@ -178,7 +166,7 @@ def handle_exceptions_in_retries(
178166
no_retry_exceptions: tuple[type[Exception], ...],
179167
retry_exceptions: tuple[type[Exception], ...],
180168
suppress_errors: bool,
181-
check_message_contains: str | None,
169+
check_message_contains: list[str] | None,
182170
last_retry: bool,
183171
logger: Logger | None,
184172
wrap_error_message_in: str | None,

ax/utils/common/func_enum.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,13 @@ class FuncEnum(Enum):
1717
"""A base class for all enums with the following structure: string values that
1818
map to names of functions, which reside in the same module as the enum."""
1919

20-
# pyre-ignore[3]: Input constructors will be used to make different inputs,
21-
# so we need to allow `Any` return type here.
2220
def __call__(self, **kwargs: Any) -> Any:
2321
"""Defines a method, by which the members of this enum can be called,
2422
e.g. ``MyFunctions.F(**kwargs)``, which will call the corresponding
2523
function registered by the name ``F`` in the enum."""
2624
return self._get_function_for_value()(**kwargs)
2725

28-
# pyre-ignore[31]: Expression `typing.Callable[([...], typing.Any)]`
29-
# is not a valid type.
30-
def _get_function_for_value(self) -> Callable[[...], Any]:
26+
def _get_function_for_value(self) -> Callable[..., Any]:
3127
"""Retrieve the function in this module, name of which corresponds to the
3228
value of the enum member."""
3329
try:

ax/utils/common/logger.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ def get_logger(
6565
return logger
6666

6767

68-
# pyre-fixme[24]: Generic type `logging.StreamHandler` expects 1 type parameter.
69-
def build_stream_handler(level: int = DEFAULT_LOG_LEVEL) -> logging.StreamHandler:
68+
def build_stream_handler(level: int = DEFAULT_LOG_LEVEL) -> logging.StreamHandler[Any]:
7069
"""Build the default stream handler used for most Ax logging. Sets
7170
default level to INFO, instead of WARNING.
7271
@@ -86,8 +85,7 @@ def build_stream_handler(level: int = DEFAULT_LOG_LEVEL) -> logging.StreamHandle
8685
def build_file_handler(
8786
filepath: str,
8887
level: int = DEFAULT_LOG_LEVEL,
89-
# pyre-fixme[24]: Generic type `logging.StreamHandler` expects 1 type parameter.
90-
) -> logging.StreamHandler:
88+
) -> logging.StreamHandler[Any]:
9189
"""Build a file handle that logs entries to the given file, using the
9290
same formatting as the stream handler.
9391
@@ -216,8 +214,7 @@ def inner(*args: Any, **kwargs: Any) -> T:
216214
# Uses a permissive level on the logger, instead make each
217215
# handler as permissive/restrictive as desired
218216
ROOT_LOGGER.setLevel(logging.DEBUG)
219-
# pyre-fixme[24]: Generic type `logging.StreamHandler` expects 1 type parameter.
220-
ROOT_STREAM_HANDLER: logging.StreamHandler = build_stream_handler()
217+
ROOT_STREAM_HANDLER: logging.StreamHandler[Any] = build_stream_handler()
221218
ROOT_LOGGER.addHandler(ROOT_STREAM_HANDLER)
222219

223220

ax/utils/common/result.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,7 @@ def unwrap_err(self) -> E:
114114
pass
115115

116116
@abstractmethod
117-
# pyre-ignore[46]: The type variable `Variable[T](covariant)` is covariant and
118-
# cannot be a parameter type.
119-
def unwrap_or(self, default: T) -> T:
117+
def unwrap_or(self, default: U) -> T | U:
120118
"""Returns the contained Ok value or a provided default."""
121119

122120
pass
@@ -183,7 +181,7 @@ def unwrap(self) -> T:
183181
def unwrap_err(self) -> NoReturn:
184182
raise RuntimeError(f"Tried to unwrap_err {self}.")
185183

186-
def unwrap_or(self, default: U) -> T:
184+
def unwrap_or(self, default: U) -> T | U:
187185
return self._value
188186

189187
def unwrap_or_else(self, op: Callable[[E], T]) -> T:
@@ -243,9 +241,7 @@ def unwrap(self) -> NoReturn:
243241
def unwrap_err(self) -> E:
244242
return self._value
245243

246-
# pyre-ignore[46]: The type variable `Variable[T](covariant)` is covariant and
247-
# cannot be a parameter type.
248-
def unwrap_or(self, default: T) -> T:
244+
def unwrap_or(self, default: U) -> T | U:
249245
return default
250246

251247
def unwrap_or_else(self, op: Callable[[E], T]) -> T:

ax/utils/common/serialization.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@ def serialize_init_args(
7272
return properties
7373

7474

75-
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
76-
# avoid runtime subscripting errors.
77-
def extract_init_args(args: dict[str, Any], class_: type) -> dict[str, Any]:
75+
def extract_init_args(args: dict[str, Any], class_: type[Any]) -> dict[str, Any]:
7876
"""Given a dictionary, extract the arguments required for the
7977
given class's constructor.
8078
"""

ax/utils/common/testutils.py

Lines changed: 23 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@
2020
import unittest
2121
import warnings
2222
from collections.abc import Callable, Generator
23-
from contextlib import AbstractContextManager
2423
from logging import Logger
2524
from pstats import Stats
2625
from types import FrameType
27-
from typing import Any, TypeVar, Union
26+
from typing import Any, cast, TypeVar, Union
2827

2928
import numpy as np
3029
import torch
@@ -60,9 +59,7 @@ def _get_tb_lines(tb: types.TracebackType) -> list[tuple[str, int, str]]:
6059
return res
6160

6261

63-
# pyre-fixme[24]: Generic type `unittest.case._AssertRaisesContext` expects 1 type
64-
# parameter.
65-
class _AssertRaisesContextOn(unittest.case._AssertRaisesContext):
62+
class _AssertRaisesContextOn(unittest.case._AssertRaisesContext[Exception]):
6663
"""
6764
Attributes:
6865
lineno: the line number on which the error occurred
@@ -89,12 +86,10 @@ def __init__(
8986
expected=expected, test_case=test_case, expected_regex=expected_regex
9087
)
9188

92-
# pyre-fixme[14]: `__exit__` overrides method defined in `_AssertRaisesContext`
93-
# inconsistently.
9489
def __exit__(
9590
self,
96-
exc_type: type[Exception] | None,
97-
exc_value: Exception | None,
91+
exc_type: type[BaseException] | None,
92+
exc_value: BaseException | None,
9893
tb: types.TracebackType | None,
9994
) -> bool:
10095
"""This is called when the context closes. If an exception was raised
@@ -110,10 +105,8 @@ def __exit__(
110105
self.filename, self.lineno, _ = frames[0]
111106
lines = [line for _, _, line in frames]
112107
if self._expected_line is not None and self._expected_line not in lines:
113-
# pyre-ignore [16]: ... has no attribute `_raiseFailure`.
114-
self._raiseFailure(
115-
f"{self._expected_line!r} was not found in the traceback: {lines!r}"
116-
)
108+
msg = f"{self._expected_line!r} was not found in the traceback: {lines!r}"
109+
raise self.test_case.failureException(msg)
117110

118111
return True
119112

@@ -420,12 +413,10 @@ def assertRaisesOn(
420413
exc: type[Exception],
421414
line: str | None = None,
422415
regex: str | None = None,
423-
# pyre-ignore[24]: Generic type `AbstractContextManager`
424-
# expects 2 type parameters, received 1.
425-
) -> AbstractContextManager[None]:
416+
) -> _AssertRaisesContextOn:
426417
"""Assert that an exception is raised on a specific line."""
427418
context = _AssertRaisesContextOn(exc, self, line, regex)
428-
return context.handle("assertRaisesOn", [], {})
419+
return cast(_AssertRaisesContextOn, context.handle("assertRaisesOn", [], {}))
429420

430421
def assertDictsAlmostEqual(
431422
self, a: dict[str, Any], b: dict[str, Any], consider_nans_equal: bool = False
@@ -532,30 +523,21 @@ def ax_long_test(cls, reason: str | None) -> Generator[None, None, None]:
532523
cls._long_test_active_reason = None
533524

534525
# This list is taken from the python standard library
535-
# pyre-fixme[4]: Attribute must be annotated.
536-
failUnlessEqual = assertEquals = _deprecate(unittest.TestCase.assertEqual)
537-
# pyre-fixme[4]: Attribute must be annotated.
538-
failIfEqual = assertNotEquals = _deprecate(unittest.TestCase.assertNotEqual)
539-
# pyre-fixme[4]: Attribute must be annotated.
540-
failUnlessAlmostEqual = assertAlmostEquals = _deprecate(
541-
unittest.TestCase.assertAlmostEqual
542-
)
543-
# pyre-fixme[4]: Attribute must be annotated.
544-
failIfAlmostEqual = assertNotAlmostEquals = _deprecate(
545-
unittest.TestCase.assertNotAlmostEqual
546-
)
547-
# pyre-fixme[4]: Attribute must be annotated.
548-
failUnless = assert_ = _deprecate(unittest.TestCase.assertTrue)
549-
# pyre-fixme[4]: Attribute must be annotated.
550-
failUnlessRaises = _deprecate(unittest.TestCase.assertRaises)
551-
# pyre-fixme[4]: Attribute must be annotated.
552-
failIf = _deprecate(unittest.TestCase.assertFalse)
553-
# pyre-fixme[4]: Attribute must be annotated.
554-
assertRaisesRegexp = _deprecate(unittest.TestCase.assertRaisesRegex)
555-
# pyre-fixme[4]: Attribute must be annotated.
556-
assertRegexpMatches = _deprecate(unittest.TestCase.assertRegex)
557-
# pyre-fixme[4]: Attribute must be annotated.
558-
assertNotRegexpMatches = _deprecate(unittest.TestCase.assertNotRegex)
526+
failUnlessEqual: Callable = _deprecate(unittest.TestCase.assertEqual)
527+
assertEquals: Callable = failUnlessEqual
528+
failIfEqual: Callable = _deprecate(unittest.TestCase.assertNotEqual)
529+
assertNotEquals: Callable = failIfEqual
530+
failUnlessAlmostEqual: Callable = _deprecate(unittest.TestCase.assertAlmostEqual)
531+
assertAlmostEquals: Callable = failUnlessAlmostEqual
532+
failIfAlmostEqual: Callable = _deprecate(unittest.TestCase.assertNotAlmostEqual)
533+
assertNotAlmostEquals: Callable = failIfAlmostEqual
534+
failUnless: Callable = _deprecate(unittest.TestCase.assertTrue)
535+
assert_: Callable = failUnless
536+
failUnlessRaises: Callable = _deprecate(unittest.TestCase.assertRaises)
537+
failIf: Callable = _deprecate(unittest.TestCase.assertFalse)
538+
assertRaisesRegexp: Callable = _deprecate(unittest.TestCase.assertRaisesRegex)
539+
assertRegexpMatches: Callable = _deprecate(unittest.TestCase.assertRegex)
540+
assertNotRegexpMatches: Callable = _deprecate(unittest.TestCase.assertNotRegex)
559541

560542
# Copied from BoTorch assertAllClose
561543
def assertAllClose(

0 commit comments

Comments
 (0)