Skip to content

Commit 85630e5

Browse files
craymichaelfacebook-github-bot
authored andcommitted
Resolve all broken OSS test cases (#1502)
Summary: Pull Request resolved: #1502 Resolve various errors regarding pyre typing, mypy typing, ufmt formatting, and flake8 formatting. Reviewed By: sarahtranfb Differential Revision: D69695280 fbshipit-source-id: d76b024fbba0cd78f78bf2b6fb7ffd0e92084b30
1 parent fc9026a commit 85630e5

File tree

13 files changed

+41
-33
lines changed

13 files changed

+41
-33
lines changed

captum/_utils/common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,12 @@ def _format_tensor_into_tuples(inputs: None) -> None: ...
232232

233233
@overload
234234
def _format_tensor_into_tuples(
235-
inputs: Union[Tensor, Tuple[Tensor, ...]]
235+
inputs: Union[Tensor, Tuple[Tensor, ...]],
236236
) -> Tuple[Tensor, ...]: ...
237237

238238

239239
def _format_tensor_into_tuples(
240-
inputs: Union[None, Tensor, Tuple[Tensor, ...]]
240+
inputs: Union[None, Tensor, Tuple[Tensor, ...]],
241241
) -> Union[None, Tuple[Tensor, ...]]:
242242
if inputs is None:
243243
return None
@@ -261,7 +261,7 @@ def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
261261

262262

263263
def _format_float_or_tensor_into_tuples(
264-
inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
264+
inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]],
265265
) -> Tuple[Union[float, Tensor], ...]:
266266
if not isinstance(inputs, tuple):
267267
assert isinstance(
@@ -276,7 +276,7 @@ def _format_float_or_tensor_into_tuples(
276276
@overload
277277
def _format_additional_forward_args(
278278
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
279-
additional_forward_args: Union[Tensor, Tuple]
279+
additional_forward_args: Union[Tensor, Tuple],
280280
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
281281
) -> Tuple: ...
282282

captum/attr/_core/dataloader_attr.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4+
45
from collections import defaultdict
56
from copy import copy
67
from typing import Callable, cast, Dict, Iterable, List, Optional, Tuple, Union
@@ -30,7 +31,6 @@ class InputRole:
3031

3132

3233
# default reducer wehn reduce is None. Simply concat the outputs by the batch dimension
33-
# pyre-fixme[2]: Parameter must be annotated.
3434
def _concat_tensors(accum: Optional[Tensor], cur_output: Tensor, _) -> Tensor:
3535
return cur_output if accum is None else torch.cat([accum, cur_output])
3636

@@ -87,9 +87,7 @@ def _perturb_inputs(
8787
else:
8888
baseline = baselines[attr_inp_count]
8989

90-
# pyre-fixme[58]: `*` is not supported for operand types `object` and
91-
# `Tensor`.
92-
perturbed_inp = inp * pert_mask + baseline * (1 - pert_mask)
90+
perturbed_inp = cast(Tensor, inp) * pert_mask + baseline * (1 - pert_mask)
9391
perturbed_inputs.append(perturbed_inp)
9492

9593
attr_inp_count += 1

captum/attr/_core/layer/layer_lrp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def _get_output_relevance(
307307

308308
@staticmethod
309309
def _convert_list_to_tuple(
310-
relevances: Union[List[T], Tuple[T, ...]]
310+
relevances: Union[List[T], Tuple[T, ...]],
311311
) -> Tuple[T, ...]:
312312
if isinstance(relevances, list):
313313
return tuple(relevances)

captum/attr/_core/llm_attr.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,12 +559,19 @@ def _forward_func(
559559
outputs.past_key_values = DynamicCache.from_legacy_cache(
560560
outputs.past_key_values
561561
)
562+
# nn.Module typing suggests non-base attributes are modules or
563+
# tensors
564+
_update_model_kwargs_for_generation = (
565+
self.model._update_model_kwargs_for_generation
566+
)
562567
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
563-
model_kwargs = self.model._update_model_kwargs_for_generation(
568+
model_kwargs = _update_model_kwargs_for_generation( # type: ignore
564569
outputs, model_kwargs
565570
)
571+
# nn.Module typing suggests non-base attributes are modules or tensors
572+
prep_inputs_for_generation = self.model.prepare_inputs_for_generation
566573
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
567-
model_inputs = self.model.prepare_inputs_for_generation(
574+
model_inputs = prep_inputs_for_generation( # type: ignore
568575
model_inp, **model_kwargs
569576
)
570577
outputs = self.model.forward(**model_inputs)

captum/attr/_utils/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def _find_output_mode_and_verify(
373373

374374

375375
def _construct_default_feature_mask(
376-
inputs: Tuple[Tensor, ...]
376+
inputs: Tuple[Tensor, ...],
377377
) -> Tuple[Tuple[Tensor, ...], int]:
378378
feature_mask = []
379379
current_num_features = 0

captum/attr/_utils/stat.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#!/usr/bin/env python3
22

33
# pyre-strict
4-
from typing import Any, Callable, List, Optional, TYPE_CHECKING
4+
5+
from typing import Any, Callable, cast, List, Optional, TYPE_CHECKING
56

67
import torch
78
from torch import Tensor
@@ -117,20 +118,18 @@ def get(self) -> Optional[Tensor]:
117118
return self.rolling_mean
118119

119120
def init(self) -> None:
120-
# pyre-fixme[8]: Attribute has type `Optional[Count]`; used as `Optional[Stat]`.
121-
self.n = self._get_stat(Count()) # type: ignore
121+
self.n = cast(Count, self._get_stat(Count()))
122122

123123
def update(self, x: Tensor) -> None:
124-
# pyre-fixme[16]: `Optional` has no attribute `get`.
125-
n = self.n.get() # type: ignore
124+
n = cast(Count, self.n).get()
126125

127126
if self.rolling_mean is None:
128127
# Ensures rolling_mean is a float tensor
129128
self.rolling_mean = x.clone() if x.is_floating_point() else x.double()
130129
else:
131130
delta = x - self.rolling_mean
132-
# pyre-fixme[16]: `Optional` has no attribute `__iadd__`.
133-
self.rolling_mean += delta / n
131+
# pyre-ignore[16]: `Optional` has no attribute `__iadd__` (false positive)
132+
self.rolling_mean += delta / cast(int, n)
134133

135134

136135
class MSE(Stat):

captum/influence/_utils/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,7 @@ def __len__(self) -> int:
338338

339339
def _format_inputs_dataset(
340340
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
341-
inputs_dataset: Union[Tuple[Any, ...], DataLoader]
341+
inputs_dataset: Union[Tuple[Any, ...], DataLoader],
342342
) -> DataLoader:
343343
# if `inputs_dataset` is not a `DataLoader`, turn it into one.
344344
# `_DatasetFromList` turns a list into a `Dataset` where `__getitem__`
@@ -604,7 +604,7 @@ def _flatten_params(_params: Tuple[Tensor, ...]) -> Tensor:
604604

605605
# pyre-fixme[3]: Return type must be annotated.
606606
def _unflatten_params_factory(
607-
param_shapes: Union[List[Tuple[int, ...]], Tuple[Tensor, ...]]
607+
param_shapes: Union[List[Tuple[int, ...]], Tuple[Tensor, ...]],
608608
):
609609
"""
610610
returns a function which is the inverse of `_flatten_params`

captum/insights/attr_vis/_utils/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
def format_transforms(
99
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
10-
transforms: Optional[Union[Callable, List[Callable]]]
10+
transforms: Optional[Union[Callable, List[Callable]]],
1111
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
1212
) -> List[Callable]:
1313
if transforms is None:

captum/metrics/_core/infidelity.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def infidelity_perturb_func_decorator(
5959
"""
6060

6161
def sub_infidelity_perturb_func_decorator(
62-
perturb_func: Callable[..., TensorOrTupleOfTensorsGeneric]
62+
perturb_func: Callable[..., TensorOrTupleOfTensorsGeneric],
6363
) -> Callable[
6464
[TensorOrTupleOfTensorsGeneric, BaselineType],
6565
Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]],
@@ -611,6 +611,11 @@ def _next_infidelity_tensors(
611611
targets_expanded,
612612
additional_forward_args_expanded,
613613
)
614+
if isinstance(inputs_perturbed_fwd, torch.futures.Future):
615+
raise NotImplementedError(
616+
f"Outputs from forward_func of type {type(inputs_perturbed_fwd)} are "
617+
"not yet supported."
618+
)
614619
inputs_fwd = _run_forward(forward_func, inputs, target, additional_forward_args)
615620
# _run_forward may return future of Tensor,
616621
# but we don't support it here now
@@ -619,8 +624,6 @@ def _next_infidelity_tensors(
619624
inputs_fwd = torch.repeat_interleave(
620625
inputs_fwd, current_n_perturb_samples, dim=0
621626
)
622-
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
623-
# `Union[Future[Tensor], Tensor]`.
624627
perturbed_fwd_diffs = inputs_fwd - inputs_perturbed_fwd
625628
attributions_expanded = tuple(
626629
torch.repeat_interleave(attribution, current_n_perturb_samples, dim=0)

captum/testing/helpers/influence/common.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _isSorted(x, key=lambda x: x, descending=True) -> bool:
3333

3434
# pyre-fixme[2]: Parameter must be annotated.
3535
def _wrap_model_in_dataparallel(net) -> Module:
36-
alt_device_ids = [0] + [x for x in range(torch.cuda.device_count() - 1, 0, -1)]
36+
alt_device_ids = [0] + list(range(torch.cuda.device_count() - 1, 0, -1))
3737
net = net.cuda()
3838
return torch.nn.DataParallel(net, device_ids=alt_device_ids)
3939

@@ -505,7 +505,7 @@ def get_random_model_and_data(
505505

506506
# pyre-fixme[3]: Return type must be annotated.
507507
def generate_symmetric_matrix_given_eigenvalues(
508-
eigenvalues: Union[Tensor, List[float]]
508+
eigenvalues: Union[Tensor, List[float]],
509509
):
510510
"""
511511
following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L123 # noqa: E501
@@ -523,7 +523,7 @@ def generate_symmetric_matrix_given_eigenvalues(
523523

524524

525525
def generate_assymetric_matrix_given_eigenvalues(
526-
eigenvalues: Union[Tensor, List[float]]
526+
eigenvalues: Union[Tensor, List[float]],
527527
) -> Tensor:
528528
"""
529529
following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L105 # noqa: E501

0 commit comments

Comments
 (0)