From 45e2051cc7be1481d5d73694faf1c80b9ceadfeb Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Tue, 23 Sep 2025 20:03:55 -0700 Subject: [PATCH 1/3] Clean up enable_cross_tensor_attribution from FeaturePermutation (#1649) Summary: Defaulted to true everywhere in D81948483 stack Differential Revision: D83107514 --- captum/attr/_core/feature_permutation.py | 74 +----- tests/attr/test_feature_permutation.py | 296 ++++++++++------------- 2 files changed, 131 insertions(+), 239 deletions(-) diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index 238f6958e..1b6eb08fa 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -55,9 +55,9 @@ class FeaturePermutation(FeatureAblation): of examples to compute attributions and cannot be performed on a single example. By default, each scalar value within - each input tensor is taken as a feature and shuffled independently, *unless* - attribute() is called with enable_cross_tensor_attribution=True. Passing - a feature mask, allows grouping features to be shuffled together. + each input tensor is taken as a feature and shuffled independently. Passing + a feature mask allows grouping features to be shuffled together (including + features defined across different input tensors). Each input scalar in the group will be given the same attribution value equal to the change in target as a result of shuffling the entire feature group. @@ -92,12 +92,6 @@ def __init__( """ FeatureAblation.__init__(self, forward_func=forward_func) self.perm_func = perm_func - # Minimum number of elements needed in each input tensor, when - # `enable_cross_tensor_attribution` is False, otherwise the - # attribution for the tensor will be skipped. Set to 1 to throw if any - # input tensors only have one example - self._min_examples_per_batch = 2 - # Similar to above, when `enable_cross_tensor_attribution` is True. # Considering the case when we permute multiple input tensors at once # through `feature_mask`, we disregard the feature group if the 0th # dim of *any* input tensor in the group is less than @@ -115,7 +109,6 @@ def attribute( # type: ignore feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -187,18 +180,12 @@ def attribute( # type: ignore input tensor. Each tensor should contain integers in the range 0 to num_features - 1, and indices corresponding to the same feature should have the - same value. Note that features within each input - tensor are ablated independently (not across - tensors), unless enable_cross_tensor_attribution is - True. - + same value. The first dimension of each mask must be 1, as we require to have the same group of features for each input sample. If None, then a feature mask is constructed which assigns - each scalar within a tensor as a separate feature, which - is permuted independently, unless - enable_cross_tensor_attribution is True. + each scalar within a tensor as a separate feature. Default: None perturbations_per_eval (int, optional): Allows permutations of multiple features to be processed simultaneously @@ -217,10 +204,6 @@ def attribute( # type: ignore (e.g. time estimation). Otherwise, it will fallback to a simple output of progress. 
Default: False - enable_cross_tensor_attribution (bool, optional): If True, then - features can be grouped across input tensors depending on - the values in the feature mask. - Default: False **kwargs (Any, optional): Any additional arguments used by child classes of :class:`.FeatureAblation` (such as :class:`.Occlusion`) to construct ablations. These @@ -292,7 +275,6 @@ def attribute( # type: ignore feature_mask=feature_mask, perturbations_per_eval=perturbations_per_eval, show_progress=show_progress, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, **kwargs, ) @@ -304,7 +286,6 @@ def attribute_future( feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> Future[TensorOrTupleOfTensorsGeneric]: """ @@ -321,54 +302,9 @@ def attribute_future( feature_mask=feature_mask, perturbations_per_eval=perturbations_per_eval, show_progress=show_progress, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, **kwargs, ) - def _construct_ablated_input( - self, - expanded_input: Tensor, - input_mask: Union[None, Tensor, Tuple[Tensor, ...]], - baseline: Union[None, float, Tensor], - start_feature: int, - end_feature: int, - **kwargs: Any, - ) -> Tuple[Tensor, Tensor]: - r""" - This function permutes the features of `expanded_input` with a given - feature mask and feature range. Permutation occurs via calling - `self.perm_func` across each batch within `expanded_input`. As with - `FeatureAblation._construct_ablated_input`: - - `expanded_input.shape = (num_features, num_examples, ...)` - - `num_features = end_feature - start_feature` (i.e. start and end is a - half-closed interval) - - `input_mask` is a tensor of the same shape as one input, which - describes the locations of each feature via their "index" - - Since `baselines` is set to None for `FeatureAblation.attribute, this - will be the zero tensor, however, it is not used. 
- """ - assert ( - input_mask is not None - and not isinstance(input_mask, tuple) - and input_mask.shape[0] == 1 - ), ( - "input_mask.shape[0] != 1: pass in one mask in order to permute" - "the same features for each input" - ) - current_mask = torch.stack( - [input_mask == j for j in range(start_feature, end_feature)], dim=0 - ).bool() - current_mask = current_mask.to(expanded_input.device) - - output = torch.stack( - [ - self.perm_func(x, mask.squeeze(0)) - for x, mask in zip(expanded_input, current_mask) - ] - ) - return output, current_mask - def _construct_ablated_input_across_tensors( self, inputs: Tuple[Tensor, ...], diff --git a/tests/attr/test_feature_permutation.py b/tests/attr/test_feature_permutation.py index 329fea812..69683a402 100644 --- a/tests/attr/test_feature_permutation.py +++ b/tests/attr/test_feature_permutation.py @@ -103,31 +103,12 @@ def forward_func(x: Tensor) -> Tensor: inp[:, 0] = constant_value zeros = torch.zeros_like(inp[:, 0]) - for enable_cross_tensor_attribution in (True, False): - attribs = feature_importance.attribute( - inp, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) - assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") - self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) - - def test_simple_input_with_min_examples(self) -> None: - def forward_func(x: Tensor) -> Tensor: - return x.sum(dim=-1) - - feature_importance = FeaturePermutation(forward_func=forward_func) - inp = torch.tensor([[1.0, 2.0]]) - assertTensorAlmostEqual( - self, - feature_importance.attribute(inp, enable_cross_tensor_attribution=False), - torch.tensor([[0.0, 0.0]]), - delta=0.0, + attribs = feature_importance.attribute( + inp, ) - - feature_importance._min_examples_per_batch = 1 - with self.assertRaises(AssertionError): - feature_importance.attribute(inp, enable_cross_tensor_attribution=False) + self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) + assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") + self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) def test_simple_input_with_min_examples_in_group(self) -> None: def forward_func(x: Tensor) -> Tensor: @@ -137,22 +118,20 @@ def forward_func(x: Tensor) -> Tensor: inp = torch.tensor([[1.0, 2.0]]) assertTensorAlmostEqual( self, - feature_importance.attribute(inp, enable_cross_tensor_attribution=True), + feature_importance.attribute(inp), torch.tensor([[0.0, 0.0]]), delta=0.0, ) assertTensorAlmostEqual( self, - feature_importance.attribute( - torch.tensor([]), enable_cross_tensor_attribution=True - ), + feature_importance.attribute(torch.tensor([])), torch.tensor([0.0]), delta=0.0, ) feature_importance._min_examples_per_batch_grouped = 1 with self.assertRaises(AssertionError): - feature_importance.attribute(inp, enable_cross_tensor_attribution=True) + feature_importance.attribute(inp) def test_simple_input_custom_mask_with_min_examples_in_group(self) -> None: def forward_func(x1: Tensor, x2: Tensor) -> Tensor: @@ -169,18 +148,14 @@ def forward_func(x1: Tensor, x2: Tensor) -> Tensor: ) assertTensorAlmostEqual( self, - feature_importance.attribute( - inp, feature_mask=mask, enable_cross_tensor_attribution=True - )[0], + feature_importance.attribute(inp, feature_mask=mask)[0], torch.tensor([[0.0, 0.0]]), delta=0.0, ) feature_importance._min_examples_per_batch_grouped = 1 with self.assertRaises(AssertionError): - feature_importance.attribute( - 
inp, feature_mask=mask, enable_cross_tensor_attribution=True - ) + feature_importance.attribute(inp, feature_mask=mask) def test_single_input_with_future( self, @@ -200,18 +175,16 @@ def forward_func(x: Tensor) -> Tensor: inp[:, 0] = constant_value zeros = torch.zeros_like(inp[:, 0]) - for enable_cross_tensor_attribution in [True, False]: - attribs = feature_importance.attribute_future( - inp, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) + attribs = feature_importance.attribute_future( + inp, + ) - self.assertTrue(type(attribs) is torch.Future) - attribs = attribs.wait() + self.assertTrue(type(attribs) is torch.Future) + attribs = attribs.wait() - self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) - assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") - self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) + self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) + assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") + self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) def test_multi_input( self, @@ -247,24 +220,22 @@ def forward_func(*x: Tensor) -> Tensor: ) inp[1][:, :, 1] = 4 - for enable_cross_tensor_attribution in (True, False): - attribs = feature_importance.attribute( - inp, - feature_mask=feature_mask, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) + attribs = feature_importance.attribute( + inp, + feature_mask=feature_mask, + ) - self.assertTrue(isinstance(attribs, tuple)) - self.assertTrue(len(attribs) == 2) + self.assertTrue(isinstance(attribs, tuple)) + self.assertTrue(len(attribs) == 2) - self.assertTrue(attribs[0].squeeze(0).size() == inp1_size) - self.assertTrue(attribs[1].squeeze(0).size() == inp2_size) + self.assertTrue(attribs[0].squeeze(0).size() == inp1_size) + self.assertTrue(attribs[1].squeeze(0).size() == inp2_size) - self.assertTrue((attribs[1][:, :, 1] == 0).all()) - self.assertTrue((attribs[1][:, :, 2] == 0).all()) + self.assertTrue((attribs[1][:, :, 1] == 0).all()) + self.assertTrue((attribs[1][:, :, 2] == 0).all()) - self.assertTrue((attribs[0] != 0).all()) - self.assertTrue((attribs[1][:, :, 0] != 0).all()) + self.assertTrue((attribs[0] != 0).all()) + self.assertTrue((attribs[1][:, :, 0] != 0).all()) def test_multi_input_group_across_input_tensors( self, @@ -295,9 +266,7 @@ def forward_func(*x: Tensor) -> Tensor: feature_mask = tuple( torch.zeros_like(inp_tensor[0]).unsqueeze(0) for inp_tensor in inp ) - attribs = feature_importance.attribute( - inp, feature_mask=feature_mask, enable_cross_tensor_attribution=True - ) + attribs = feature_importance.attribute(inp, feature_mask=feature_mask) self.assertTrue(isinstance(attribs, tuple)) self.assertTrue(len(attribs) == 2) @@ -348,26 +317,24 @@ def forward_func(*x: Tensor) -> Tensor: inp[1][:, :, 1] = 4 - for enable_cross_tensor_attribution in [True, False]: - attribs = feature_importance.attribute_future( - inp, - feature_mask=feature_mask, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - self.assertTrue(type(attribs) is torch.Future) - attribs = attribs.wait() + attribs = feature_importance.attribute_future( + inp, + feature_mask=feature_mask, + ) + self.assertTrue(type(attribs) is torch.Future) + attribs = attribs.wait() - self.assertTrue(isinstance(attribs, tuple)) - self.assertTrue(len(attribs) == 2) + self.assertTrue(isinstance(attribs, tuple)) + self.assertTrue(len(attribs) == 2) - self.assertTrue(attribs[0].squeeze(0).size() == 
inp1_size) - self.assertTrue(attribs[1].squeeze(0).size() == inp2_size) + self.assertTrue(attribs[0].squeeze(0).size() == inp1_size) + self.assertTrue(attribs[1].squeeze(0).size() == inp2_size) - self.assertTrue((attribs[1][:, :, 1] == 0).all()) - self.assertTrue((attribs[1][:, :, 2] == 0).all()) + self.assertTrue((attribs[1][:, :, 1] == 0).all()) + self.assertTrue((attribs[1][:, :, 2] == 0).all()) - self.assertTrue((attribs[0] != 0).all()) - self.assertTrue((attribs[1][:, :, 0] != 0).all()) + self.assertTrue((attribs[0] != 0).all()) + self.assertTrue((attribs[1][:, :, 0] != 0).all()) def test_multiple_perturbations_per_eval( self, @@ -420,28 +387,26 @@ def forward_func(x: Tensor) -> Tensor: forward_func=self.construct_future_forward(forward_func) ) - for enable_cross_tensor_attribution in [True, False]: - attribs = feature_importance.attribute_future( - inp, - perturbations_per_eval=perturbations_per_eval, - target=target, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - self.assertTrue(type(attribs) is torch.Future) - attribs = attribs.wait() + attribs = feature_importance.attribute_future( + inp, + perturbations_per_eval=perturbations_per_eval, + target=target, + ) + self.assertTrue(type(attribs) is torch.Future) + attribs = attribs.wait() - self.assertTrue(attribs.size() == (batch_size,) + input_size) + self.assertTrue(attribs.size() == (batch_size,) + input_size) - for i in range(inp.size(1)): - if i == target: - continue - assertTensorAlmostEqual( - self, attribs[:, i], torch.zeros_like(attribs[:, i]) - ) + for i in range(inp.size(1)): + if i == target: + continue + assertTensorAlmostEqual( + self, attribs[:, i], torch.zeros_like(attribs[:, i]) + ) - y = forward_func(inp) - actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]]) - assertTensorAlmostEqual(self, attribs[:, target], actual_diff) + y = forward_func(inp) + actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]]) + assertTensorAlmostEqual(self, attribs[:, target], actual_diff) def test_broadcastable_masks( self, @@ -461,30 +426,28 @@ def forward_func(x: Tensor) -> Tensor: torch.tensor([[0, 1, 2, 3]]), torch.tensor([[[0, 1, 2, 3], [3, 3, 4, 5], [6, 6, 4, 6], [7, 8, 9, 10]]]), ] - for enable_cross_tensor_attribution in (True, False): - for mask in masks: + for mask in masks: + + attribs = feature_importance.attribute( + inp, + feature_mask=mask, + ) + self.assertTrue(attribs is not None) + self.assertTrue(attribs.shape == inp.shape) - attribs = feature_importance.attribute( - inp, - feature_mask=mask, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, + fm = mask.expand_as(inp[0]) + + features = set(mask.flatten()) + for feature in features: + m = (fm == feature).bool() + attribs_for_feature = attribs[:, m] + assertTensorAlmostEqual( + self, + attribs_for_feature[0], + -attribs_for_feature[1], + delta=0.05, + mode="max", ) - self.assertTrue(attribs is not None) - self.assertTrue(attribs.shape == inp.shape) - - fm = mask.expand_as(inp[0]) - - features = set(mask.flatten()) - for feature in features: - m = (fm == feature).bool() - attribs_for_feature = attribs[:, m] - assertTensorAlmostEqual( - self, - attribs_for_feature[0], - -attribs_for_feature[1], - delta=0.05, - mode="max", - ) def test_broadcastable_masks_with_future( self, @@ -507,35 +470,33 @@ def forward_func(x: Tensor) -> Tensor: torch.tensor([[[0, 1, 2, 3], [3, 3, 4, 5], [6, 6, 4, 6], [7, 8, 9, 10]]]), ] - for enable_cross_tensor_attribution in [True, False]: - results = [] - for mask in 
masks: - attribs_future = feature_importance.attribute_future( - inp, - feature_mask=mask, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, + results = [] + for mask in masks: + attribs_future = feature_importance.attribute_future( + inp, + feature_mask=mask, + ) + results.append(attribs_future) + self.assertTrue(attribs_future is not None) + + for idx in range(len(results)): + attribs = results[idx].wait() + self.assertTrue(attribs is not None) + self.assertTrue(attribs.shape == inp.shape) + + fm = masks[idx].expand_as(inp[0]) + + features = set(masks[idx].flatten()) + for feature in features: + m = (fm == feature).bool() + attribs_for_feature = attribs[:, m] + assertTensorAlmostEqual( + self, + attribs_for_feature[0], + -attribs_for_feature[1], + delta=0.05, + mode="max", ) - results.append(attribs_future) - self.assertTrue(attribs_future is not None) - - for idx in range(len(results)): - attribs = results[idx].wait() - self.assertTrue(attribs is not None) - self.assertTrue(attribs.shape == inp.shape) - - fm = masks[idx].expand_as(inp[0]) - - features = set(masks[idx].flatten()) - for feature in features: - m = (fm == feature).bool() - attribs_for_feature = attribs[:, m] - assertTensorAlmostEqual( - self, - attribs_for_feature[0], - -attribs_for_feature[1], - delta=0.05, - mode="max", - ) def test_empty_sparse_features(self) -> None: model = BasicModelWithSparseInputs() @@ -544,13 +505,11 @@ def test_empty_sparse_features(self) -> None: # test empty sparse tensor feature_importance = FeaturePermutation(model) - for enable_cross_tensor_attribution in (True, False): - attr1, attr2 = feature_importance.attribute( - (inp1, inp2), - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - self.assertEqual(attr1.shape, (1, 3)) - self.assertEqual(attr2.shape, (1,)) + attr1, attr2 = feature_importance.attribute( + (inp1, inp2), + ) + self.assertEqual(attr1.shape, (1, 3)) + self.assertEqual(attr2.shape, (1,)) def test_sparse_features(self) -> None: model = BasicModelWithSparseInputs() @@ -560,21 +519,18 @@ def test_sparse_features(self) -> None: feature_importance = FeaturePermutation(model) - for enable_cross_tensor_attribution in [True, False]: - set_all_random_seeds(1234) - total_attr1, total_attr2 = feature_importance.attribute( + set_all_random_seeds(1234) + total_attr1, total_attr2 = feature_importance.attribute( + (inp1, inp2), + ) + for _ in range(50): + attr1, attr2 = feature_importance.attribute( (inp1, inp2), - enable_cross_tensor_attribution=enable_cross_tensor_attribution, ) - for _ in range(50): - attr1, attr2 = feature_importance.attribute( - (inp1, inp2), - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - total_attr1 += attr1 - total_attr2 += attr2 - total_attr1 /= 50 - total_attr2 /= 50 - self.assertEqual(total_attr2.shape, (1,)) - assertTensorAlmostEqual(self, total_attr1, torch.zeros_like(total_attr1)) - assertTensorAlmostEqual(self, total_attr2, [-6.0], delta=0.2) + total_attr1 += attr1 + total_attr2 += attr2 + total_attr1 /= 50 + total_attr2 /= 50 + self.assertEqual(total_attr2.shape, (1,)) + assertTensorAlmostEqual(self, total_attr1, torch.zeros_like(total_attr1)) + assertTensorAlmostEqual(self, total_attr2, [-6.0], delta=0.2) From d5eac9e7e011ae6c61704db9164a6ae63c0949bf Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Tue, 23 Sep 2025 20:03:55 -0700 Subject: [PATCH 2/3] Clean up unused function in Occlusion (#1650) Summary: No longer needed when enable_cross_tensor_attribution=False is no longer supported in the 
FeatureAblation parent class Differential Revision: D83109001 --- captum/attr/_core/occlusion.py | 66 ---------------------------------- 1 file changed, 66 deletions(-) diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index 3e5b9c2b6..9279a0b48 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -275,57 +275,6 @@ def attribute_future(self) -> None: """ raise NotImplementedError("attribute_future is not implemented for Occlusion") - def _construct_ablated_input( - self, - expanded_input: Tensor, - input_mask: Union[None, Tensor, Tuple[Tensor, ...]], - baseline: Union[None, float, Tensor], - start_feature: int, - end_feature: int, - **kwargs: Any, - ) -> Tuple[Tensor, Tensor]: - r""" - Ablates given expanded_input tensor with given feature mask, feature range, - and baselines, and any additional arguments. - expanded_input shape is (num_features, num_examples, ...) - with remaining dimensions corresponding to remaining original tensor - dimensions and num_features = end_feature - start_feature. - - input_mask is None for occlusion, and the mask is constructed - using sliding_window_tensors, strides, and shift counts, which are provided in - kwargs. baseline is expected to - be broadcastable to match expanded_input. - - This method returns the ablated input tensor, which has the same - dimensionality as expanded_input as well as the corresponding mask with - either the same dimensionality as expanded_input or second dimension - being 1. This mask contains 1s in locations which have been ablated (and - thus counted towards ablations for that feature) and 0s otherwise. - """ - input_mask = torch.stack( - [ - self._occlusion_mask( - expanded_input, - j, - kwargs["sliding_window_tensors"], - kwargs["strides"], - kwargs["shift_counts"], - is_expanded_input=True, - ) - for j in range(start_feature, end_feature) - ], - dim=0, - ).long() - assert baseline is not None, "baseline should not be None" - ablated_tensor = ( - expanded_input - * ( - torch.ones(1, dtype=torch.long, device=expanded_input.device) - - input_mask - ).to(expanded_input.dtype) - ) + (baseline * input_mask.to(expanded_input.dtype)) - return ablated_tensor, input_mask - def _occlusion_mask( self, input: Tensor, @@ -380,21 +329,6 @@ def _occlusion_mask( ) return padded_tensor.reshape((1,) + tuple(padded_tensor.shape)) - def _get_feature_range_and_mask( - self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any - ) -> Tuple[int, int, Union[None, Tensor, Tuple[Tensor, ...]]]: - feature_max = int(np.prod(kwargs["shift_counts"])) - return 0, feature_max, None - - def _get_feature_counts( - self, - inputs: TensorOrTupleOfTensorsGeneric, - feature_mask: Tuple[Tensor, ...], - **kwargs: Any, - ) -> Tuple[int, ...]: - """return the numbers of possible input features""" - return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"]) - def _get_feature_idx_to_tensor_idx( self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any ) -> Dict[int, List[int]]: From a7f0eeb134451e7d1482f75234ee1f0a316a343b Mon Sep 17 00:00:00 2001 From: Sarah Tran Date: Tue, 23 Sep 2025 20:03:55 -0700 Subject: [PATCH 3/3] Clean up independent feature masking in FeatureAblation (#1651) Summary: Finally. 
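
With this diff, FeatureAblation keeps a single masking path: feature IDs
in feature_mask are global across input tensors, and scalars that share
an ID are ablated together as one group. The former
enable_cross_tensor_attribution=False path, which ablated each tensor's
features independently, no longer exists. A minimal sketch of the
resulting call shape (the toy forward function, shapes, and mask values
below are illustrative only, not taken from this diff):

    import torch
    from captum.attr import FeatureAblation

    # Toy forward returning one scalar per example.
    def forward(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return x1.sum(dim=1) + x2.sum(dim=1)

    x1 = torch.randn(4, 3)
    x2 = torch.randn(4, 2)

    # Feature IDs are global across tensors: ID 0 groups x1[:, 0] with
    # x2[:, 0], so both columns are ablated together as a single feature.
    mask1 = torch.tensor([[0, 1, 2]])
    mask2 = torch.tensor([[0, 3]])

    attr1, attr2 = FeatureAblation(forward).attribute(
        (x1, x2), feature_mask=(mask1, mask2)
    )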
Differential Revision: D83110168 --- captum/attr/_core/feature_ablation.py | 723 ++------------------------ tests/attr/test_feature_ablation.py | 138 ++--- 2 files changed, 87 insertions(+), 774 deletions(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 99873050a..8df1e53d3 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -4,18 +4,7 @@ import logging import math -from typing import ( - Any, - Callable, - cast, - Dict, - Generator, - List, - Optional, - Tuple, - TypeVar, - Union, -) +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar, Union import torch from captum._utils.common import ( @@ -92,11 +81,6 @@ def __init__( # behavior stays consistent and no longer check again self._is_output_shape_valid = False - # Minimum number of elements needed in each input tensor when - # `enable_cross_tensor_attribution`, is False, otherwise the attribution - # for the tensor will be skipped - self._min_examples_per_batch = 1 - # Similar to above, when `enable_cross_tensor_attribution` is True. # Considering the case when we permute multiple input tensors at once # through `feature_mask`, we disregard the feature group if the 0th # dim of *any* input tensor in the group is less than @@ -115,7 +99,6 @@ def attribute( feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -207,9 +190,6 @@ def attribute( should contain integers in the range 0 to num_features - 1, and indices corresponding to the same feature should have the same value. - Note that features within each input tensor are ablated - independently (not across tensors), unless - enable_cross_tensor_attribution is True. If the forward function returns a single scalar per batch, we enforce that the first dimension of each mask must be 1, since attributions are returned batch-wise rather than per @@ -240,10 +220,6 @@ def attribute( (e.g. time estimation). Otherwise, it will fallback to a simple output of progress. Default: False - enable_cross_tensor_attribution (bool, optional): If True, features - IDs in feature_mask are global IDs across input tensors, - and are ablated together. - Default: False **kwargs (Any, optional): Any additional arguments used by child classes of FeatureAblation (such as Occlusion) to construct ablations. 
These arguments are ignored when using @@ -309,7 +285,6 @@ def attribute( formatted_additional_forward_args = _format_additional_forward_args( additional_forward_args ) - num_examples = formatted_inputs[0].shape[0] formatted_feature_mask = _format_feature_mask(feature_mask, formatted_inputs) assert ( @@ -321,7 +296,6 @@ def attribute( attr_progress = self._attribute_progress_setup( formatted_inputs, formatted_feature_mask, - enable_cross_tensor_attribution, **kwargs, perturbations_per_eval=perturbations_per_eval, ) @@ -362,133 +336,30 @@ def attribute( formatted_inputs, ) - if enable_cross_tensor_attribution: - total_attrib, weights = self._attribute_with_cross_tensor_feature_masks( - formatted_inputs, - formatted_additional_forward_args, - target, - baselines, - formatted_feature_mask, - attr_progress, - flattened_initial_eval, - initial_eval, - n_outputs, - total_attrib, - weights, - attrib_type, - perturbations_per_eval, - **kwargs, - ) - else: - total_attrib, weights = self._attribute_with_independent_feature_masks( - formatted_inputs, - formatted_additional_forward_args, - target, - baselines, - formatted_feature_mask, - num_examples, - perturbations_per_eval, - attr_progress, - initial_eval, - flattened_initial_eval, - n_outputs, - total_attrib, - weights, - attrib_type, - **kwargs, - ) - - if attr_progress is not None: - attr_progress.close() - - return cast( - TensorOrTupleOfTensorsGeneric, - self._generate_result(total_attrib, weights, is_inputs_tuple), - ) - - def _attribute_with_independent_feature_masks( - self, - formatted_inputs: Tuple[Tensor, ...], - formatted_additional_forward_args: Optional[Tuple[object, ...]], - target: TargetType, - baselines: BaselineType, - formatted_feature_mask: Tuple[Tensor, ...], - num_examples: int, - perturbations_per_eval: int, - attr_progress: Optional[tqdm], - initial_eval: Tensor, - flattened_initial_eval: Tensor, - n_outputs: int, - total_attrib: List[Tensor], - weights: List[Tensor], - attrib_type: dtype, - **kwargs: Any, - ) -> Tuple[List[Tensor], List[Tensor]]: - # Iterate through each feature tensor for ablation - for i in range(len(formatted_inputs)): - if torch.numel(formatted_inputs[i]) == 0: - logger.info( - f"Skipping input tensor at index {i} since it contains no elements" - ) - continue - if formatted_inputs[i].shape[0] < self._min_examples_per_batch: - logger.warning( - f"Skipping input tensor at index {i} since its 0th dim " - f"({formatted_inputs[i].shape[0]}) " - f"is less than {self._min_examples_per_batch}" - ) - continue - - for ( - current_inputs, - current_add_args, - current_target, - current_mask, - ) in self._ith_input_ablation_generator( - i, + total_attrib, weights = self._attribute_with_cross_tensor_feature_masks( formatted_inputs, formatted_additional_forward_args, target, baselines, formatted_feature_mask, + attr_progress, + flattened_initial_eval, + initial_eval, + n_outputs, + total_attrib, + weights, + attrib_type, perturbations_per_eval, **kwargs, - ): - # modified_eval has (n_feature_perturbed * n_outputs) elements - # shape: - # agg mode: (*initial_eval.shape) - # non-agg mode: - # (feature_perturbed * batch_size, *initial_eval.shape[1:]) - modified_eval: Union[Tensor, Future[Tensor]] = _run_forward( - self.forward_func, - current_inputs, - current_target, - current_add_args, - ) + ) - if attr_progress is not None: - attr_progress.update() + if attr_progress is not None: + attr_progress.close() - assert not isinstance(modified_eval, torch.Future), ( - "when use_futures is True, modified_eval should have " - 
f"non-Future type rather than {type(modified_eval)}" - ) - total_attrib, weights = self._process_ablated_out( - modified_eval, - current_inputs, - current_mask, - perturbations_per_eval, - num_examples, - initial_eval, - flattened_initial_eval, - formatted_inputs, - n_outputs, - total_attrib, - weights, - i, - attrib_type, - ) - return total_attrib, weights + return cast( + TensorOrTupleOfTensorsGeneric, + self._generate_result(total_attrib, weights, is_inputs_tuple), + ) def _attribute_with_cross_tensor_feature_masks( self, @@ -735,7 +606,6 @@ def attribute_future( feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> Future[TensorOrTupleOfTensorsGeneric]: r""" @@ -761,7 +631,6 @@ def attribute_future( attr_progress = self._attribute_progress_setup( formatted_inputs, formatted_feature_mask, - enable_cross_tensor_attribution, **kwargs, perturbations_per_eval=perturbations_per_eval, ) @@ -796,149 +665,20 @@ def attribute_future( ) ) - if enable_cross_tensor_attribution: - return cast( - Future[TensorOrTupleOfTensorsGeneric], - self._attribute_with_cross_tensor_feature_masks_future( - formatted_inputs=formatted_inputs, - formatted_additional_forward_args=formatted_additional_forward_args, # noqa: E501 line too long - target=target, - baselines=baselines, - formatted_feature_mask=formatted_feature_mask, - attr_progress=attr_progress, - processed_initial_eval_fut=processed_initial_eval_fut, - is_inputs_tuple=is_inputs_tuple, - perturbations_per_eval=perturbations_per_eval, - ), - ) - else: - return cast( - Future[TensorOrTupleOfTensorsGeneric], - self._attribute_with_independent_feature_masks_future( # type: ignore # noqa: E501 line too long - formatted_inputs, - formatted_additional_forward_args, - target, - baselines, - formatted_feature_mask, - perturbations_per_eval, - attr_progress, - processed_initial_eval_fut, - is_inputs_tuple, - **kwargs, - ), - ) - - def _attribute_with_independent_feature_masks_future( - self, - formatted_inputs: Tuple[Tensor, ...], - formatted_additional_forward_args: Optional[Tuple[object, ...]], - target: TargetType, - baselines: BaselineType, - formatted_feature_mask: Tuple[Tensor, ...], - perturbations_per_eval: int, - attr_progress: Optional[tqdm], - processed_initial_eval_fut: Future[ - Tuple[List[Tensor], List[Tensor], Tensor, Tensor, int, dtype] - ], - is_inputs_tuple: bool, - **kwargs: Any, - ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]: - num_examples = formatted_inputs[0].shape[0] - # The will be the same amount futures as modified_eval down there, - # since we cannot add up the evaluation result adhoc under async mode. 
- all_modified_eval_futures: List[ - List[Future[Tuple[List[Tensor], List[Tensor]]]] - ] = [[] for _ in range(len(formatted_inputs))] - # Iterate through each feature tensor for ablation - for i in range(len(formatted_inputs)): - # Skip any empty input tensors - if torch.numel(formatted_inputs[i]) == 0: - continue - - for ( - current_inputs, - current_add_args, - current_target, - current_mask, - ) in self._ith_input_ablation_generator( - i, - formatted_inputs, - formatted_additional_forward_args, - target, - baselines, - formatted_feature_mask, - perturbations_per_eval, - **kwargs, - ): - # modified_eval has (n_feature_perturbed * n_outputs) elements - # shape: - # agg mode: (*initial_eval.shape) - # non-agg mode: - # (feature_perturbed * batch_size, *initial_eval.shape[1:]) - modified_eval: Union[Tensor, Future[Tensor]] = _run_forward( - self.forward_func, - current_inputs, - current_target, - current_add_args, - ) - - if attr_progress is not None: - attr_progress.update() - - if not isinstance(modified_eval, torch.Future): - raise AssertionError( - "when using attribute_future, modified_eval should have " - f"Future type rather than {type(modified_eval)}" - ) - if processed_initial_eval_fut is None: - raise AssertionError( - "processed_initial_eval_fut should not be None" - ) - - # Need to collect both initial eval and modified_eval - eval_futs: Future[ - List[ - Future[ - Union[ - Tuple[ - List[Tensor], - List[Tensor], - Tensor, - Tensor, - int, - dtype, - ], - Tensor, - ] - ] - ] - ] = collect_all( - [ - processed_initial_eval_fut, - modified_eval, - ] - ) - - ablated_out_fut: Future[Tuple[List[Tensor], List[Tensor]]] = ( - eval_futs.then( - lambda eval_futs, current_inputs=current_inputs, current_mask=current_mask, i=i: self._eval_fut_to_ablated_out_fut( # type: ignore # noqa: E501 line too long - eval_futs=eval_futs, - current_inputs=current_inputs, - current_mask=current_mask, - i=i, - perturbations_per_eval=perturbations_per_eval, - num_examples=num_examples, - formatted_inputs=formatted_inputs, - ) - ) - ) - - all_modified_eval_futures[i].append(ablated_out_fut) - - if attr_progress is not None: - attr_progress.close() - - return self._generate_async_result(all_modified_eval_futures, is_inputs_tuple) # type: ignore # noqa: E501 line too long + return cast( + Future[TensorOrTupleOfTensorsGeneric], + self._attribute_with_cross_tensor_feature_masks_future( + formatted_inputs=formatted_inputs, + formatted_additional_forward_args=formatted_additional_forward_args, # noqa: E501 line too long + target=target, + baselines=baselines, + formatted_feature_mask=formatted_feature_mask, + attr_progress=attr_progress, + processed_initial_eval_fut=processed_initial_eval_fut, + is_inputs_tuple=is_inputs_tuple, + perturbations_per_eval=perturbations_per_eval, + ), + ) def _attribute_with_cross_tensor_feature_masks_future( self, @@ -1115,21 +855,11 @@ def _attribute_progress_setup( self, formatted_inputs: Tuple[Tensor, ...], feature_mask: Tuple[Tensor, ...], - enable_cross_tensor_attribution: bool, perturbations_per_eval: int, **kwargs: Any, ) -> tqdm: - feature_counts = self._get_feature_counts( - formatted_inputs, feature_mask, **kwargs - ) - total_forwards = ( - math.ceil( - get_total_features_from_mask(feature_mask) / perturbations_per_eval - ) - if enable_cross_tensor_attribution - else sum( - math.ceil(count / perturbations_per_eval) for count in feature_counts - ) + total_forwards = math.ceil( + get_total_features_from_mask(feature_mask) / perturbations_per_eval ) total_forwards += 1 # add 1 
for the initial eval attr_progress = progress( @@ -1137,69 +867,6 @@ def _attribute_progress_setup( ) return attr_progress - def _eval_fut_to_ablated_out_fut( - self, - eval_futs: Future[List[Future[List[object]]]], - current_inputs: Tuple[Tensor, ...], - current_mask: Tensor, - i: int, - perturbations_per_eval: int, - num_examples: int, - formatted_inputs: Tuple[Tensor, ...], - ) -> Tuple[List[Tensor], List[Tensor]]: - try: - modified_eval = cast(Tensor, eval_futs.value()[1].value()) - initial_eval_tuple = cast( - Tuple[ - List[Tensor], - List[Tensor], - Tensor, - Tensor, - int, - dtype, - ], - eval_futs.value()[0].value(), - ) - if len(initial_eval_tuple) != 6: - raise AssertionError( - "eval_fut_to_ablated_out_fut: " - "initial_eval_tuple should have 6 elements: " - "total_attrib, weights, initial_eval, " - "flattened_initial_eval, n_outputs, attrib_type " - ) - if not isinstance(modified_eval, Tensor): - raise AssertionError( - "eval_fut_to_ablated_out_fut: " "modified eval should be a Tensor" - ) - ( - total_attrib, - weights, - initial_eval, - flattened_initial_eval, - n_outputs, - attrib_type, - ) = initial_eval_tuple - result = self._process_ablated_out( # type: ignore # noqa: E501 line too long - modified_eval=modified_eval, - current_inputs=current_inputs, - current_mask=current_mask, - perturbations_per_eval=perturbations_per_eval, - num_examples=num_examples, - initial_eval=initial_eval, - flattened_initial_eval=flattened_initial_eval, - inputs=formatted_inputs, - n_outputs=n_outputs, - total_attrib=total_attrib, - weights=weights, - i=i, - attrib_type=attrib_type, - ) - except FeatureAblationFutureError as e: - raise FeatureAblationFutureError( - "eval_fut_to_ablated_out_fut func failed)" - ) from e - return result - def _generate_async_result_cross_tensor( self, futs: List[Future[Tuple[List[Tensor], List[Tensor]]]], @@ -1288,205 +955,6 @@ def _eval_fut_to_ablated_out_fut_cross_tensor( ) from e return total_attrib, weights - def _ith_input_ablation_generator( - self, - i: int, - inputs: TensorOrTupleOfTensorsGeneric, - additional_args: Optional[Tuple[object, ...]], - target: TargetType, - baselines: BaselineType, - input_mask: Union[None, Tensor, Tuple[Tensor, ...]], - perturbations_per_eval: int, - **kwargs: Any, - ) -> Generator[ - Tuple[ - Tuple[Tensor, ...], - object, - TargetType, - Tensor, - ], - None, - None, - ]: - """ - This method returns a generator of ablation perturbations of the i-th input - - Returns: - ablation_iter (Generator): yields each perturbation to be evaluated - as a tuple (inputs, additional_forward_args, targets, mask). - """ - extra_args = {} - for key, value in kwargs.items(): - # For any tuple argument in kwargs, we choose index i of the tuple. - if isinstance(value, tuple): - extra_args[key] = value[i] - else: - extra_args[key] = value - - cur_input_mask = input_mask[i] if input_mask is not None else None - min_feature, num_features, cur_input_mask = self._get_feature_range_and_mask( - inputs[i], cur_input_mask, **extra_args - ) - num_examples = inputs[0].shape[0] - perturbations_per_eval = min(perturbations_per_eval, num_features) - baseline = baselines[i] if isinstance(baselines, tuple) else baselines - if isinstance(baseline, torch.Tensor): - baseline = baseline.reshape((1,) + tuple(baseline.shape)) - - additional_args_repeated: object - if perturbations_per_eval > 1: - # Repeat features and additional args for batch size. 
- all_features_repeated = [ - torch.cat([inputs[j]] * perturbations_per_eval, dim=0) - for j in range(len(inputs)) - ] - additional_args_repeated = ( - _expand_additional_forward_args(additional_args, perturbations_per_eval) - if additional_args is not None - else None - ) - target_repeated = _expand_target(target, perturbations_per_eval) - else: - all_features_repeated = list(inputs) - additional_args_repeated = additional_args - target_repeated = target - - num_features_processed = min_feature - current_additional_args: object - while num_features_processed < num_features: - current_num_ablated_features = min( - perturbations_per_eval, num_features - num_features_processed - ) - - # Store appropriate inputs and additional args based on batch size. - if current_num_ablated_features != perturbations_per_eval: - current_features = [ - feature_repeated[0 : current_num_ablated_features * num_examples] - for feature_repeated in all_features_repeated - ] - current_additional_args = ( - _expand_additional_forward_args( - additional_args, current_num_ablated_features - ) - if additional_args is not None - else None - ) - current_target = _expand_target(target, current_num_ablated_features) - else: - current_features = all_features_repeated - current_additional_args = additional_args_repeated - current_target = target_repeated - - # Store existing tensor before modifying - original_tensor = current_features[i] - # Construct ablated batch for features in range num_features_processed - # to num_features_processed + current_num_ablated_features and return - # mask with same size as ablated batch. ablated_features has dimension - # (current_num_ablated_features, num_examples, inputs[i].shape[1:]) - # Note that in the case of sparse tensors, the second dimension - # may not necessarilly be num_examples and will match the first - # dimension of this tensor. - current_reshaped = current_features[i].reshape( - (current_num_ablated_features, -1) - + tuple(current_features[i].shape[1:]) - ) - - ablated_features, current_mask = self._construct_ablated_input( - current_reshaped, - cur_input_mask, - baseline, - num_features_processed, - num_features_processed + current_num_ablated_features, - **extra_args, - ) - - # current_features[i] has dimension - # (current_num_ablated_features * num_examples, inputs[i].shape[1:]), - # which can be provided to the model as input. - current_features[i] = ablated_features.reshape( - (-1,) + tuple(ablated_features.shape[2:]) - ) - yield tuple( - current_features - ), current_additional_args, current_target, current_mask - # Replace existing tensor at index i. - current_features[i] = original_tensor - num_features_processed += current_num_ablated_features - - def _construct_ablated_input( - self, - expanded_input: Tensor, - input_mask: Union[None, Tensor, Tuple[Tensor, ...]], - baseline: Union[None, float, Tensor], - start_feature: int, - end_feature: int, - **kwargs: Any, - ) -> Tuple[Tensor, Tensor]: - r""" - Ablates given expanded_input tensor with given feature mask, feature range, - and baselines. expanded_input shape is (`num_features`, `num_examples`, ...) - with remaining dimensions corresponding to remaining original tensor - dimensions and `num_features` = `end_feature` - `start_feature`. - input_mask has same number of dimensions as original input tensor (one less - than `expanded_input`), and can have first dimension either 1, applying same - feature mask to all examples, or `num_examples`. baseline is expected to - be broadcastable to match `expanded_input`. 
- - This method returns the ablated input tensor, which has the same - dimensionality as `expanded_input` as well as the corresponding mask with - either the same dimensionality as `expanded_input` or second dimension - being 1. This mask contains 1s in locations which have been ablated (and - thus counted towards ablations for that feature) and 0s otherwise. - """ - current_mask = torch.stack( - cast(List[Tensor], [input_mask == j for j in range(start_feature, end_feature)]), # type: ignore # noqa: E501 line too long - dim=0, - ).long() - current_mask = current_mask.to(expanded_input.device) - assert baseline is not None, "baseline must be provided" - ablated_tensor = ( - expanded_input * (1 - current_mask).to(expanded_input.dtype) - ) + (baseline * current_mask.to(expanded_input.dtype)) - return ablated_tensor, current_mask - - def _get_feature_range_and_mask( - self, - input: Tensor, - input_mask: Optional[Tensor], - **kwargs: Any, - ) -> Tuple[int, int, Union[None, Tensor, Tuple[Tensor, ...]]]: - if input_mask is None: - # Obtain feature mask for selected input tensor, matches size of - # 1 input example, (1 x inputs[i].shape[1:]) - input_mask = torch.reshape( - torch.arange(torch.numel(input[0]), device=input.device), - input[0:1].shape, - ).long() - return ( - int(torch.min(input_mask).item()), - int(torch.max(input_mask).item() + 1), - input_mask, - ) - - def _get_feature_counts( - self, - inputs: TensorOrTupleOfTensorsGeneric, - feature_mask: Tuple[Tensor, ...], - **kwargs: Any, - ) -> Tuple[float, ...]: - """return the numbers of input features""" - if not feature_mask: - return tuple(inp[0].numel() if inp.numel() else 0 for inp in inputs) - - return tuple( - ( - (mask.max() - mask.min()).item() + 1 - if mask is not None - else (inp[0].numel() if inp.numel() else 0) - ) - for inp, mask in zip(inputs, feature_mask) - ) - def _parse_forward_out(self, forward_output: Tensor) -> Tensor: """ A temp wrapper for global _run_forward util to force forward output @@ -1551,72 +1019,6 @@ def _process_initial_eval( attrib_type, ) - def _process_ablated_out( - self, - modified_eval: Tensor, - current_inputs: Tuple[Tensor, ...], - current_mask: Tensor, - perturbations_per_eval: int, - num_examples: int, - initial_eval: Tensor, - flattened_initial_eval: Tensor, - inputs: TensorOrTupleOfTensorsGeneric, - n_outputs: int, - total_attrib: List[Tensor], - weights: List[Tensor], - i: int, - attrib_type: dtype, - ) -> Tuple[List[Tensor], List[Tensor]]: - modified_eval = self._parse_forward_out(modified_eval) - - # if perturbations_per_eval > 1, the output shape must grow with - # input and not be aggregated - if perturbations_per_eval > 1 and not self._is_output_shape_valid: - current_batch_size = current_inputs[0].shape[0] - - # number of perturbation, which is not the same as - # perturbations_per_eval when not enough features to perturb - n_perturb = current_batch_size / num_examples - - current_output_shape = modified_eval.shape - - # use initial_eval as the forward of perturbations_per_eval = 1 - initial_output_shape = initial_eval.shape - - assert ( - # check if the output is not a scalar - current_output_shape - and initial_output_shape - # check if the output grow in same ratio, i.e., not agg - and current_output_shape[0] == n_perturb * initial_output_shape[0] - ), ( - "When perturbations_per_eval > 1, forward_func's output " - "should be a tensor whose 1st dim grow with the input " - f"batch size: when input batch size is {num_examples}, " - f"the output shape is {initial_output_shape}; " - 
f"when input batch size is {current_batch_size}, " - f"the output shape is {current_output_shape}" - ) - - self._is_output_shape_valid = True - - # reshape the leading dim for n_feature_perturbed - # flatten each feature's eval outputs into 1D of (n_outputs) - modified_eval = modified_eval.reshape(-1, n_outputs) - # eval_diff in shape (n_feature_perturbed, n_outputs) - eval_diff = flattened_initial_eval - modified_eval - - # append the shape of one input example - # to make it broadcastable to mask - eval_diff = eval_diff.reshape(eval_diff.shape + (inputs[i].dim() - 1) * (1,)) - eval_diff = eval_diff.to(total_attrib[i].device) - - if self.use_weights: - weights[i] += current_mask.float().sum(dim=0) - - total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(dim=0) - return total_attrib, weights - def _process_ablated_out_full( self, modified_eval: Tensor, @@ -1685,67 +1087,6 @@ def _process_ablated_out_full( return total_attrib, weights - def _fut_tuple_to_accumulate_fut_list( - self, - total_attrib: List[Tensor], - weights: List[Tensor], - i: int, - fut_tuple: Future[Tuple[List[Tensor], List[Tensor]]], - ) -> None: - try: - attrib, weight = fut_tuple.value() - self._accumulate_for_single_input(total_attrib, weights, i, attrib, weight) - except FeatureAblationFutureError as e: - raise FeatureAblationFutureError( - "fut_tuple_to_accumulate_fut_list failed" - ) from e - - def _generate_async_result( - self, - futs: List[List[Future[Tuple[List[Tensor], List[Tensor]]]]], - is_inputs_tuple: bool, - ) -> Future[Union[Tensor, Tuple[Tensor, ...]]]: - # Each element of the 2d list contains evalutaion results for a feature - # Need to add up all the results for each input - accumulate_fut_list: List[Future[None]] = [] - total_attrib: List[Tensor] = [] - weights: List[Tensor] = [] - - for i, fut_tuples in enumerate(futs): - for fut_tuple in fut_tuples: - - accumulate_fut_list.append( - fut_tuple.then( - lambda fut_tuple, i=i: self._fut_tuple_to_accumulate_fut_list( # type: ignore # noqa: E501 line too long - total_attrib, weights, i, fut_tuple - ) - ) - ) - - result_fut = collect_all(accumulate_fut_list).then( - lambda x: self._generate_result(total_attrib, weights, is_inputs_tuple) - ) - - return result_fut - - def _accumulate_for_single_input( - self, - total_attrib: List[Tensor], - weights: List[Tensor], - idx: int, - attrib: List[Tensor], - weight: List[Tensor], - ) -> None: - if total_attrib: - total_attrib[idx] = attrib[idx] - else: - total_attrib.extend(attrib) - if self.use_weights: - if weights: - weights[idx] = weight[idx] - else: - weights.extend(weight) - def _generate_result( self, total_attrib: List[Tensor], diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index b47af6264..e18b1fda1 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -253,7 +253,6 @@ def sum_forward(*inps: torch.Tensor) -> torch.Tensor: expected, target=None, baselines=(0.2, 3.1, 0.4), - test_enable_cross_tensor_attribution=[False, True], ) def test_multi_input_ablation_with_mask_weighted(self) -> None: @@ -308,57 +307,33 @@ def test_multi_input_ablation_with_mask_dupe_feature_idx(self) -> None: mask1 = torch.tensor([[1, 1, 1], [0, 1, 0]]) mask2 = torch.tensor([[0, 1, 2]]) mask3 = torch.tensor([[0, 1, 2], [0, 0, 0]]) - expected = ( - [[492.0, 492.0, 492.0], [200.0, 200.0, 200.0]], - [[80.0, 200.0, 120.0], [0.0, 400.0, 0.0]], - [[0.0, 400.0, 40.0], [60.0, 60.0, 60.0]], - ) - expected_cross_tensor = ( + expected_out = ( [[1092.0, 
1092.0, 1092.0], [260.0, 600.0, 260.0]], [[80.0, 1092.0, 160.0], [260.0, 600.0, 0.0]], [[80.0, 1092.0, 160.0], [260.0, 260.0, 260.0]], ) - for test_enable_cross_tensor_attribution, expected_out in [ - (True, expected_cross_tensor), - (False, expected), - ]: - self._ablation_test_assert( - ablation_algo, - (inp1, inp2, inp3), - expected_out, - additional_input=(1,), - feature_mask=(mask1, mask2, mask3), - test_enable_cross_tensor_attribution=[ - test_enable_cross_tensor_attribution - ], - ) - - expected_with_baseline = ( - [[468.0, 468.0, 468.0], [184.0, 192.0, 184.0]], - [[68.0, 188.0, 108.0], [-12.0, 388.0, -12.0]], - [[-16.0, 384.0, 24.0], [12.0, 12.0, 12.0]], + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected_out, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), ) + expected_cross_tensor_with_baseline = ( [[1040.0, 1040.0, 1040.0], [184.0, 580.0, 184.0]], [[52.0, 1040.0, 132.0], [184.0, 580.0, -12.0]], [[52.0, 1040.0, 132.0], [184.0, 184.0, 184.0]], ) - for test_enable_cross_tensor_attribution, expected_out in [ - (True, expected_cross_tensor_with_baseline), - (False, expected_with_baseline), - ]: - self._ablation_test_assert( - ablation_algo, - (inp1, inp2, inp3), - expected_out, - additional_input=(1,), - feature_mask=(mask1, mask2, mask3), - baselines=(2, 3.0, 4), - perturbations_per_eval=(1, 2, 3), - test_enable_cross_tensor_attribution=[ - test_enable_cross_tensor_attribution - ], - ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected_cross_tensor_with_baseline, + additional_input=(1,), + feature_mask=(mask1, mask2, mask3), + baselines=(2, 3.0, 4), + perturbations_per_eval=(1, 2, 3), + ) def test_multi_input_ablation_with_mask_nt(self) -> None: ablation_algo = NoiseTunnel(FeatureAblation(BasicModel_MultiLayer_MultiInput())) @@ -882,50 +857,47 @@ def _ablation_test_assert( perturbations_per_eval: Tuple[int, ...] 
= (1,), baselines: BaselineType = None, target: TargetType = 0, - test_enable_cross_tensor_attribution: List[bool] = [True, False], test_future: bool = False, **kwargs: Any, ) -> None: - for enable_cross_tensor_attribution in test_enable_cross_tensor_attribution: - for batch_size in perturbations_per_eval: - self.assertTrue(ablation_algo.multiplies_by_inputs) - if isinstance(ablation_algo, FeatureAblation) and test_future: - attributions = ablation_algo.attribute_future( - test_input, - target=target, - feature_mask=feature_mask, - additional_forward_args=additional_input, - baselines=baselines, - perturbations_per_eval=batch_size, - **kwargs, - ).wait() - else: - attributions = ablation_algo.attribute( - test_input, - target=target, - feature_mask=feature_mask, - additional_forward_args=additional_input, - baselines=baselines, - perturbations_per_eval=batch_size, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - **kwargs, - ) - if isinstance(expected_ablation, tuple): - for i in range(len(expected_ablation)): - expected = expected_ablation[i] - if not isinstance(expected, torch.Tensor): - expected = torch.tensor(expected) - - self.assertEqual(attributions[i].shape, expected.shape) - self.assertEqual(attributions[i].dtype, expected.dtype) - assertTensorAlmostEqual(self, attributions[i], expected) - else: - if not isinstance(expected_ablation, torch.Tensor): - expected_ablation = torch.tensor(expected_ablation) - - self.assertEqual(attributions.shape, expected_ablation.shape) - self.assertEqual(attributions.dtype, expected_ablation.dtype) - assertTensorAlmostEqual(self, attributions, expected_ablation) + for batch_size in perturbations_per_eval: + self.assertTrue(ablation_algo.multiplies_by_inputs) + if isinstance(ablation_algo, FeatureAblation) and test_future: + attributions = ablation_algo.attribute_future( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + **kwargs, + ).wait() + else: + attributions = ablation_algo.attribute( + test_input, + target=target, + feature_mask=feature_mask, + additional_forward_args=additional_input, + baselines=baselines, + perturbations_per_eval=batch_size, + **kwargs, + ) + if isinstance(expected_ablation, tuple): + for i in range(len(expected_ablation)): + expected = expected_ablation[i] + if not isinstance(expected, torch.Tensor): + expected = torch.tensor(expected) + + self.assertEqual(attributions[i].shape, expected.shape) + self.assertEqual(attributions[i].dtype, expected.dtype) + assertTensorAlmostEqual(self, attributions[i], expected) + else: + if not isinstance(expected_ablation, torch.Tensor): + expected_ablation = torch.tensor(expected_ablation) + + self.assertEqual(attributions.shape, expected_ablation.shape) + self.assertEqual(attributions.dtype, expected_ablation.dtype) + assertTensorAlmostEqual(self, attributions, expected_ablation) if __name__ == "__main__":
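
The permutation counterpart follows the same global-mask semantics, with
the extra constraint kept in patch 1 that every input tensor in a feature
group must have at least _min_examples_per_batch_grouped examples, since
permutation shuffles values across the batch. A short sketch, again with
illustrative names and shapes rather than code from this series:

    import torch
    from captum.attr import FeaturePermutation

    def forward(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return x1.sum(dim=1) + x2.sum(dim=1)

    # Permutation needs more than one example per batch to shuffle.
    x1 = torch.randn(8, 3)
    x2 = torch.randn(8, 2)

    # One global feature ID (0) spanning both tensors: all grouped scalars
    # are permuted together and receive the same attribution value. Each
    # mask's first dimension must be 1, so the same groups apply to every
    # example.
    mask1 = torch.zeros(1, 3, dtype=torch.long)
    mask2 = torch.zeros(1, 2, dtype=torch.long)

    attr1, attr2 = FeaturePermutation(forward).attribute(
        (x1, x2), feature_mask=(mask1, mask2)
    )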