|
3 | 3 | # pyre-strict
|
4 | 4 | from collections import defaultdict
|
5 | 5 | from copy import copy
|
6 |
| -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union |
| 6 | +from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Tuple, Union |
7 | 7 |
|
8 | 8 | import torch
|
9 | 9 | from captum._utils.common import (
|
@@ -193,8 +193,7 @@ def _forward_with_dataloader(
|
193 | 193 | feature_mask: Tuple[Tensor, ...],
|
194 | 194 | # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
|
195 | 195 | reduce: Callable,
|
196 |
| - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. |
197 |
| - to_metric: Optional[Callable], |
| 196 | + to_metric: Optional[Callable[[Tensor], Tensor]], |
198 | 197 | show_progress: bool,
|
199 | 198 | feature_idx_to_mask_idx: Dict[int, List[int]],
|
200 | 199 | ) -> Tensor:
|
@@ -243,7 +242,8 @@ def _forward_with_dataloader(
|
243 | 242 |
|
244 | 243 | accum_states[i] = reduce(accum_states[i], output, perturbed_inputs)
|
245 | 244 |
|
246 |
| - accum_results = [ |
| 245 | + accum_states = cast(List[Tensor], accum_states) |
| 246 | + accum_results: List[Tensor] = [ |
247 | 247 | to_metric(accum) if to_metric else accum for accum in accum_states
|
248 | 248 | ]
|
249 | 249 |
|
@@ -276,7 +276,7 @@ def attribute(
|
276 | 276 | Args:
|
277 | 277 |
|
278 | 278 | dataloader (torch.Dataloader): the dataloader to attribute, which should
|
279 |
| - return a tuple of consistant size for every iteration |
| 279 | + return a tuple of consistent size for every iteration |
280 | 280 | input_roles (tuple[int, ...], optional): a tuple of integers to define the
|
281 | 281 | role of each element returned from the dataloader. It should
|
282 | 282 | have the same size as the return of the dataloader.
|
@@ -326,7 +326,7 @@ def attribute(
|
326 | 326 | traverses needed is
|
327 | 327 | ceil(n_perturbations / perturbations_per_pass).
|
328 | 328 |
|
329 |
| - This arguement offers control of the trade-off between memory |
| 329 | + This argument offers control of the trade-off between memory |
330 | 330 | and efficiency. If the dataloader involves slow operations like
|
331 | 331 | remote request or file I/O, multiple traversals can be
|
332 | 332 | inefficient. On the other hand, each perturbation needs to
|
|
0 commit comments