diff --git a/captum/attr/_core/feature_permutation.py b/captum/attr/_core/feature_permutation.py index 238f6958e..1b6eb08fa 100644 --- a/captum/attr/_core/feature_permutation.py +++ b/captum/attr/_core/feature_permutation.py @@ -55,9 +55,9 @@ class FeaturePermutation(FeatureAblation): of examples to compute attributions and cannot be performed on a single example. By default, each scalar value within - each input tensor is taken as a feature and shuffled independently, *unless* - attribute() is called with enable_cross_tensor_attribution=True. Passing - a feature mask, allows grouping features to be shuffled together. + each input tensor is taken as a feature and shuffled independently. Passing + a feature mask allows grouping features to be shuffled together (including + features defined across different input tensors). Each input scalar in the group will be given the same attribution value equal to the change in target as a result of shuffling the entire feature group. @@ -92,12 +92,6 @@ def __init__( """ FeatureAblation.__init__(self, forward_func=forward_func) self.perm_func = perm_func - # Minimum number of elements needed in each input tensor, when - # `enable_cross_tensor_attribution` is False, otherwise the - # attribution for the tensor will be skipped. Set to 1 to throw if any - # input tensors only have one example - self._min_examples_per_batch = 2 - # Similar to above, when `enable_cross_tensor_attribution` is True. # Considering the case when we permute multiple input tensors at once # through `feature_mask`, we disregard the feature group if the 0th # dim of *any* input tensor in the group is less than @@ -115,7 +109,6 @@ def attribute( # type: ignore feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> TensorOrTupleOfTensorsGeneric: r""" @@ -187,18 +180,12 @@ def attribute( # type: ignore input tensor. Each tensor should contain integers in the range 0 to num_features - 1, and indices corresponding to the same feature should have the - same value. Note that features within each input - tensor are ablated independently (not across - tensors), unless enable_cross_tensor_attribution is - True. - + same value. The first dimension of each mask must be 1, as we require to have the same group of features for each input sample. If None, then a feature mask is constructed which assigns - each scalar within a tensor as a separate feature, which - is permuted independently, unless - enable_cross_tensor_attribution is True. + each scalar within a tensor as a separate feature. Default: None perturbations_per_eval (int, optional): Allows permutations of multiple features to be processed simultaneously @@ -217,10 +204,6 @@ def attribute( # type: ignore (e.g. time estimation). Otherwise, it will fallback to a simple output of progress. Default: False - enable_cross_tensor_attribution (bool, optional): If True, then - features can be grouped across input tensors depending on - the values in the feature mask. - Default: False **kwargs (Any, optional): Any additional arguments used by child classes of :class:`.FeatureAblation` (such as :class:`.Occlusion`) to construct ablations. 
These @@ -292,7 +275,6 @@ def attribute( # type: ignore feature_mask=feature_mask, perturbations_per_eval=perturbations_per_eval, show_progress=show_progress, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, **kwargs, ) @@ -304,7 +286,6 @@ def attribute_future( feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None, perturbations_per_eval: int = 1, show_progress: bool = False, - enable_cross_tensor_attribution: bool = True, **kwargs: Any, ) -> Future[TensorOrTupleOfTensorsGeneric]: """ @@ -321,54 +302,9 @@ def attribute_future( feature_mask=feature_mask, perturbations_per_eval=perturbations_per_eval, show_progress=show_progress, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, **kwargs, ) - def _construct_ablated_input( - self, - expanded_input: Tensor, - input_mask: Union[None, Tensor, Tuple[Tensor, ...]], - baseline: Union[None, float, Tensor], - start_feature: int, - end_feature: int, - **kwargs: Any, - ) -> Tuple[Tensor, Tensor]: - r""" - This function permutes the features of `expanded_input` with a given - feature mask and feature range. Permutation occurs via calling - `self.perm_func` across each batch within `expanded_input`. As with - `FeatureAblation._construct_ablated_input`: - - `expanded_input.shape = (num_features, num_examples, ...)` - - `num_features = end_feature - start_feature` (i.e. start and end is a - half-closed interval) - - `input_mask` is a tensor of the same shape as one input, which - describes the locations of each feature via their "index" - - Since `baselines` is set to None for `FeatureAblation.attribute, this - will be the zero tensor, however, it is not used. - """ - assert ( - input_mask is not None - and not isinstance(input_mask, tuple) - and input_mask.shape[0] == 1 - ), ( - "input_mask.shape[0] != 1: pass in one mask in order to permute" - "the same features for each input" - ) - current_mask = torch.stack( - [input_mask == j for j in range(start_feature, end_feature)], dim=0 - ).bool() - current_mask = current_mask.to(expanded_input.device) - - output = torch.stack( - [ - self.perm_func(x, mask.squeeze(0)) - for x, mask in zip(expanded_input, current_mask) - ] - ) - return output, current_mask - def _construct_ablated_input_across_tensors( self, inputs: Tuple[Tensor, ...], diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index 3e5b9c2b6..9279a0b48 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -275,57 +275,6 @@ def attribute_future(self) -> None: """ raise NotImplementedError("attribute_future is not implemented for Occlusion") - def _construct_ablated_input( - self, - expanded_input: Tensor, - input_mask: Union[None, Tensor, Tuple[Tensor, ...]], - baseline: Union[None, float, Tensor], - start_feature: int, - end_feature: int, - **kwargs: Any, - ) -> Tuple[Tensor, Tensor]: - r""" - Ablates given expanded_input tensor with given feature mask, feature range, - and baselines, and any additional arguments. - expanded_input shape is (num_features, num_examples, ...) - with remaining dimensions corresponding to remaining original tensor - dimensions and num_features = end_feature - start_feature. - - input_mask is None for occlusion, and the mask is constructed - using sliding_window_tensors, strides, and shift counts, which are provided in - kwargs. baseline is expected to - be broadcastable to match expanded_input. 
- - This method returns the ablated input tensor, which has the same - dimensionality as expanded_input as well as the corresponding mask with - either the same dimensionality as expanded_input or second dimension - being 1. This mask contains 1s in locations which have been ablated (and - thus counted towards ablations for that feature) and 0s otherwise. - """ - input_mask = torch.stack( - [ - self._occlusion_mask( - expanded_input, - j, - kwargs["sliding_window_tensors"], - kwargs["strides"], - kwargs["shift_counts"], - is_expanded_input=True, - ) - for j in range(start_feature, end_feature) - ], - dim=0, - ).long() - assert baseline is not None, "baseline should not be None" - ablated_tensor = ( - expanded_input - * ( - torch.ones(1, dtype=torch.long, device=expanded_input.device) - - input_mask - ).to(expanded_input.dtype) - ) + (baseline * input_mask.to(expanded_input.dtype)) - return ablated_tensor, input_mask - def _occlusion_mask( self, input: Tensor, @@ -380,21 +329,6 @@ def _occlusion_mask( ) return padded_tensor.reshape((1,) + tuple(padded_tensor.shape)) - def _get_feature_range_and_mask( - self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any - ) -> Tuple[int, int, Union[None, Tensor, Tuple[Tensor, ...]]]: - feature_max = int(np.prod(kwargs["shift_counts"])) - return 0, feature_max, None - - def _get_feature_counts( - self, - inputs: TensorOrTupleOfTensorsGeneric, - feature_mask: Tuple[Tensor, ...], - **kwargs: Any, - ) -> Tuple[int, ...]: - """return the numbers of possible input features""" - return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"]) - def _get_feature_idx_to_tensor_idx( self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any ) -> Dict[int, List[int]]: diff --git a/tests/attr/test_feature_permutation.py b/tests/attr/test_feature_permutation.py index 329fea812..69683a402 100644 --- a/tests/attr/test_feature_permutation.py +++ b/tests/attr/test_feature_permutation.py @@ -103,31 +103,12 @@ def forward_func(x: Tensor) -> Tensor: inp[:, 0] = constant_value zeros = torch.zeros_like(inp[:, 0]) - for enable_cross_tensor_attribution in (True, False): - attribs = feature_importance.attribute( - inp, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) - assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") - self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) - - def test_simple_input_with_min_examples(self) -> None: - def forward_func(x: Tensor) -> Tensor: - return x.sum(dim=-1) - - feature_importance = FeaturePermutation(forward_func=forward_func) - inp = torch.tensor([[1.0, 2.0]]) - assertTensorAlmostEqual( - self, - feature_importance.attribute(inp, enable_cross_tensor_attribution=False), - torch.tensor([[0.0, 0.0]]), - delta=0.0, + attribs = feature_importance.attribute( + inp, ) - - feature_importance._min_examples_per_batch = 1 - with self.assertRaises(AssertionError): - feature_importance.attribute(inp, enable_cross_tensor_attribution=False) + self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) + assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") + self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) def test_simple_input_with_min_examples_in_group(self) -> None: def forward_func(x: Tensor) -> Tensor: @@ -137,22 +118,20 @@ def forward_func(x: Tensor) -> Tensor: inp = torch.tensor([[1.0, 2.0]]) assertTensorAlmostEqual( self, - 
feature_importance.attribute(inp, enable_cross_tensor_attribution=True), + feature_importance.attribute(inp), torch.tensor([[0.0, 0.0]]), delta=0.0, ) assertTensorAlmostEqual( self, - feature_importance.attribute( - torch.tensor([]), enable_cross_tensor_attribution=True - ), + feature_importance.attribute(torch.tensor([])), torch.tensor([0.0]), delta=0.0, ) feature_importance._min_examples_per_batch_grouped = 1 with self.assertRaises(AssertionError): - feature_importance.attribute(inp, enable_cross_tensor_attribution=True) + feature_importance.attribute(inp) def test_simple_input_custom_mask_with_min_examples_in_group(self) -> None: def forward_func(x1: Tensor, x2: Tensor) -> Tensor: @@ -169,18 +148,14 @@ def forward_func(x1: Tensor, x2: Tensor) -> Tensor: ) assertTensorAlmostEqual( self, - feature_importance.attribute( - inp, feature_mask=mask, enable_cross_tensor_attribution=True - )[0], + feature_importance.attribute(inp, feature_mask=mask)[0], torch.tensor([[0.0, 0.0]]), delta=0.0, ) feature_importance._min_examples_per_batch_grouped = 1 with self.assertRaises(AssertionError): - feature_importance.attribute( - inp, feature_mask=mask, enable_cross_tensor_attribution=True - ) + feature_importance.attribute(inp, feature_mask=mask) def test_single_input_with_future( self, @@ -200,18 +175,16 @@ def forward_func(x: Tensor) -> Tensor: inp[:, 0] = constant_value zeros = torch.zeros_like(inp[:, 0]) - for enable_cross_tensor_attribution in [True, False]: - attribs = feature_importance.attribute_future( - inp, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) + attribs = feature_importance.attribute_future( + inp, + ) - self.assertTrue(type(attribs) is torch.Future) - attribs = attribs.wait() + self.assertTrue(type(attribs) is torch.Future) + attribs = attribs.wait() - self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) - assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") - self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) + self.assertTrue(attribs.squeeze(0).size() == (batch_size,) + input_size) + assertTensorAlmostEqual(self, attribs[:, 0], zeros, delta=0.05, mode="max") + self.assertTrue((attribs[:, 1 : input_size[0]].abs() > 0).all()) def test_multi_input( self, @@ -247,24 +220,22 @@ def forward_func(*x: Tensor) -> Tensor: ) inp[1][:, :, 1] = 4 - for enable_cross_tensor_attribution in (True, False): - attribs = feature_importance.attribute( - inp, - feature_mask=feature_mask, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) + attribs = feature_importance.attribute( + inp, + feature_mask=feature_mask, + ) - self.assertTrue(isinstance(attribs, tuple)) - self.assertTrue(len(attribs) == 2) + self.assertTrue(isinstance(attribs, tuple)) + self.assertTrue(len(attribs) == 2) - self.assertTrue(attribs[0].squeeze(0).size() == inp1_size) - self.assertTrue(attribs[1].squeeze(0).size() == inp2_size) + self.assertTrue(attribs[0].squeeze(0).size() == inp1_size) + self.assertTrue(attribs[1].squeeze(0).size() == inp2_size) - self.assertTrue((attribs[1][:, :, 1] == 0).all()) - self.assertTrue((attribs[1][:, :, 2] == 0).all()) + self.assertTrue((attribs[1][:, :, 1] == 0).all()) + self.assertTrue((attribs[1][:, :, 2] == 0).all()) - self.assertTrue((attribs[0] != 0).all()) - self.assertTrue((attribs[1][:, :, 0] != 0).all()) + self.assertTrue((attribs[0] != 0).all()) + self.assertTrue((attribs[1][:, :, 0] != 0).all()) def test_multi_input_group_across_input_tensors( self, @@ -295,9 +266,7 @@ def forward_func(*x: 
Tensor) -> Tensor: feature_mask = tuple( torch.zeros_like(inp_tensor[0]).unsqueeze(0) for inp_tensor in inp ) - attribs = feature_importance.attribute( - inp, feature_mask=feature_mask, enable_cross_tensor_attribution=True - ) + attribs = feature_importance.attribute(inp, feature_mask=feature_mask) self.assertTrue(isinstance(attribs, tuple)) self.assertTrue(len(attribs) == 2) @@ -348,26 +317,24 @@ def forward_func(*x: Tensor) -> Tensor: inp[1][:, :, 1] = 4 - for enable_cross_tensor_attribution in [True, False]: - attribs = feature_importance.attribute_future( - inp, - feature_mask=feature_mask, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - self.assertTrue(type(attribs) is torch.Future) - attribs = attribs.wait() + attribs = feature_importance.attribute_future( + inp, + feature_mask=feature_mask, + ) + self.assertTrue(type(attribs) is torch.Future) + attribs = attribs.wait() - self.assertTrue(isinstance(attribs, tuple)) - self.assertTrue(len(attribs) == 2) + self.assertTrue(isinstance(attribs, tuple)) + self.assertTrue(len(attribs) == 2) - self.assertTrue(attribs[0].squeeze(0).size() == inp1_size) - self.assertTrue(attribs[1].squeeze(0).size() == inp2_size) + self.assertTrue(attribs[0].squeeze(0).size() == inp1_size) + self.assertTrue(attribs[1].squeeze(0).size() == inp2_size) - self.assertTrue((attribs[1][:, :, 1] == 0).all()) - self.assertTrue((attribs[1][:, :, 2] == 0).all()) + self.assertTrue((attribs[1][:, :, 1] == 0).all()) + self.assertTrue((attribs[1][:, :, 2] == 0).all()) - self.assertTrue((attribs[0] != 0).all()) - self.assertTrue((attribs[1][:, :, 0] != 0).all()) + self.assertTrue((attribs[0] != 0).all()) + self.assertTrue((attribs[1][:, :, 0] != 0).all()) def test_multiple_perturbations_per_eval( self, @@ -420,28 +387,26 @@ def forward_func(x: Tensor) -> Tensor: forward_func=self.construct_future_forward(forward_func) ) - for enable_cross_tensor_attribution in [True, False]: - attribs = feature_importance.attribute_future( - inp, - perturbations_per_eval=perturbations_per_eval, - target=target, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - self.assertTrue(type(attribs) is torch.Future) - attribs = attribs.wait() + attribs = feature_importance.attribute_future( + inp, + perturbations_per_eval=perturbations_per_eval, + target=target, + ) + self.assertTrue(type(attribs) is torch.Future) + attribs = attribs.wait() - self.assertTrue(attribs.size() == (batch_size,) + input_size) + self.assertTrue(attribs.size() == (batch_size,) + input_size) - for i in range(inp.size(1)): - if i == target: - continue - assertTensorAlmostEqual( - self, attribs[:, i], torch.zeros_like(attribs[:, i]) - ) + for i in range(inp.size(1)): + if i == target: + continue + assertTensorAlmostEqual( + self, attribs[:, i], torch.zeros_like(attribs[:, i]) + ) - y = forward_func(inp) - actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]]) - assertTensorAlmostEqual(self, attribs[:, target], actual_diff) + y = forward_func(inp) + actual_diff = torch.stack([(y[0] - y[1])[target], (y[1] - y[0])[target]]) + assertTensorAlmostEqual(self, attribs[:, target], actual_diff) def test_broadcastable_masks( self, @@ -461,30 +426,28 @@ def forward_func(x: Tensor) -> Tensor: torch.tensor([[0, 1, 2, 3]]), torch.tensor([[[0, 1, 2, 3], [3, 3, 4, 5], [6, 6, 4, 6], [7, 8, 9, 10]]]), ] - for enable_cross_tensor_attribution in (True, False): - for mask in masks: + for mask in masks: + + attribs = feature_importance.attribute( + inp, + feature_mask=mask, + ) + 
self.assertTrue(attribs is not None) + self.assertTrue(attribs.shape == inp.shape) - attribs = feature_importance.attribute( - inp, - feature_mask=mask, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, + fm = mask.expand_as(inp[0]) + + features = set(mask.flatten()) + for feature in features: + m = (fm == feature).bool() + attribs_for_feature = attribs[:, m] + assertTensorAlmostEqual( + self, + attribs_for_feature[0], + -attribs_for_feature[1], + delta=0.05, + mode="max", ) - self.assertTrue(attribs is not None) - self.assertTrue(attribs.shape == inp.shape) - - fm = mask.expand_as(inp[0]) - - features = set(mask.flatten()) - for feature in features: - m = (fm == feature).bool() - attribs_for_feature = attribs[:, m] - assertTensorAlmostEqual( - self, - attribs_for_feature[0], - -attribs_for_feature[1], - delta=0.05, - mode="max", - ) def test_broadcastable_masks_with_future( self, @@ -507,35 +470,33 @@ def forward_func(x: Tensor) -> Tensor: torch.tensor([[[0, 1, 2, 3], [3, 3, 4, 5], [6, 6, 4, 6], [7, 8, 9, 10]]]), ] - for enable_cross_tensor_attribution in [True, False]: - results = [] - for mask in masks: - attribs_future = feature_importance.attribute_future( - inp, - feature_mask=mask, - enable_cross_tensor_attribution=enable_cross_tensor_attribution, + results = [] + for mask in masks: + attribs_future = feature_importance.attribute_future( + inp, + feature_mask=mask, + ) + results.append(attribs_future) + self.assertTrue(attribs_future is not None) + + for idx in range(len(results)): + attribs = results[idx].wait() + self.assertTrue(attribs is not None) + self.assertTrue(attribs.shape == inp.shape) + + fm = masks[idx].expand_as(inp[0]) + + features = set(masks[idx].flatten()) + for feature in features: + m = (fm == feature).bool() + attribs_for_feature = attribs[:, m] + assertTensorAlmostEqual( + self, + attribs_for_feature[0], + -attribs_for_feature[1], + delta=0.05, + mode="max", ) - results.append(attribs_future) - self.assertTrue(attribs_future is not None) - - for idx in range(len(results)): - attribs = results[idx].wait() - self.assertTrue(attribs is not None) - self.assertTrue(attribs.shape == inp.shape) - - fm = masks[idx].expand_as(inp[0]) - - features = set(masks[idx].flatten()) - for feature in features: - m = (fm == feature).bool() - attribs_for_feature = attribs[:, m] - assertTensorAlmostEqual( - self, - attribs_for_feature[0], - -attribs_for_feature[1], - delta=0.05, - mode="max", - ) def test_empty_sparse_features(self) -> None: model = BasicModelWithSparseInputs() @@ -544,13 +505,11 @@ def test_empty_sparse_features(self) -> None: # test empty sparse tensor feature_importance = FeaturePermutation(model) - for enable_cross_tensor_attribution in (True, False): - attr1, attr2 = feature_importance.attribute( - (inp1, inp2), - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - self.assertEqual(attr1.shape, (1, 3)) - self.assertEqual(attr2.shape, (1,)) + attr1, attr2 = feature_importance.attribute( + (inp1, inp2), + ) + self.assertEqual(attr1.shape, (1, 3)) + self.assertEqual(attr2.shape, (1,)) def test_sparse_features(self) -> None: model = BasicModelWithSparseInputs() @@ -560,21 +519,18 @@ def test_sparse_features(self) -> None: feature_importance = FeaturePermutation(model) - for enable_cross_tensor_attribution in [True, False]: - set_all_random_seeds(1234) - total_attr1, total_attr2 = feature_importance.attribute( + set_all_random_seeds(1234) + total_attr1, total_attr2 = feature_importance.attribute( + (inp1, inp2), + ) + for _ in 
range(50): + attr1, attr2 = feature_importance.attribute( (inp1, inp2), - enable_cross_tensor_attribution=enable_cross_tensor_attribution, ) - for _ in range(50): - attr1, attr2 = feature_importance.attribute( - (inp1, inp2), - enable_cross_tensor_attribution=enable_cross_tensor_attribution, - ) - total_attr1 += attr1 - total_attr2 += attr2 - total_attr1 /= 50 - total_attr2 /= 50 - self.assertEqual(total_attr2.shape, (1,)) - assertTensorAlmostEqual(self, total_attr1, torch.zeros_like(total_attr1)) - assertTensorAlmostEqual(self, total_attr2, [-6.0], delta=0.2) + total_attr1 += attr1 + total_attr2 += attr2 + total_attr1 /= 50 + total_attr2 /= 50 + self.assertEqual(total_attr2.shape, (1,)) + assertTensorAlmostEqual(self, total_attr1, torch.zeros_like(total_attr1)) + assertTensorAlmostEqual(self, total_attr2, [-6.0], delta=0.2)
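
Notes on usage after this change: with `enable_cross_tensor_attribution` removed, feature grouping across input tensors is driven purely by the feature mask. Below is a minimal sketch of the resulting call pattern; the toy forward function, shapes, and mask values are illustrative rather than taken from the tests.

```python
import torch
from captum.attr import FeaturePermutation


# Toy forward function over two input tensors, returning one score per example.
def forward_func(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    return (x1 + x2).sum(dim=-1)


batch_size = 4
inputs = (torch.randn(batch_size, 3), torch.randn(batch_size, 3))

# One mask per input tensor, each with a leading dimension of 1 so the same
# grouping applies to every example. Feature id 0 appears in both tensors,
# so those columns are shuffled together as a single group.
feature_mask = (
    torch.tensor([[0, 1, 1]]),
    torch.tensor([[0, 2, 2]]),
)

fp = FeaturePermutation(forward_func)
attr1, attr2 = fp.attribute(inputs, feature_mask=feature_mask)
print(attr1.shape, attr2.shape)  # each matches the corresponding input shape
```

Every scalar in a group receives the same attribution: the change in the target caused by shuffling the whole group across the batch.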
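
For readers comparing against the deleted `FeaturePermutation._construct_ablated_input`: the default permutation still shuffles the masked positions across the batch dimension. The helper below is a stand-alone approximation of that behavior, not the library's `perm_func` (for instance, it does not re-draw a permutation that happens to be the identity):

```python
import torch


def permute_masked_across_batch(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Shuffle the batch dimension for the positions selected by ``mask``.

    x:    tensor of shape (batch, *feature_dims)
    mask: boolean mask broadcastable over the feature dims
    Positions outside the mask keep their original values; positions inside
    the mask take the values at the same positions from a randomly permuted
    batch.
    """
    perm = torch.randperm(x.size(0))
    return torch.where(mask.to(torch.bool), x[perm], x)


# Permute only column 0 across a batch of three examples.
x = torch.arange(12.0).reshape(3, 4)
mask = torch.tensor([True, False, False, False])
print(permute_masked_across_batch(x, mask))
```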
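
Likewise, the deleted `Occlusion._construct_ablated_input` reduced to blending the input with the baseline wherever the occlusion mask is set, i.e. `ablated = input * (1 - mask) + baseline * mask`. A self-contained illustration of that blend:

```python
import torch


def occlude(x: torch.Tensor, mask: torch.Tensor, baseline: float = 0.0) -> torch.Tensor:
    """Replace masked positions with the baseline and keep everything else.

    ``mask`` is 1 where the sliding window currently covers the input and
    0 elsewhere, mirroring the formula used by the removed method.
    """
    mask = mask.to(x.dtype)
    return x * (1 - mask) + baseline * mask


x = torch.arange(16.0).reshape(1, 4, 4)
mask = torch.zeros_like(x)
mask[:, :2, :2] = 1  # a 2x2 window at the top-left corner
print(occlude(x, mask, baseline=0.0))
```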
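
The public `Occlusion` API is unchanged by this refactor; only the per-tensor internals are removed. For reference, a minimal end-to-end call — the tiny model, window shape, and strides here are made up for illustration:

```python
import torch
from captum.attr import Occlusion


class TinyConvNet(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool(self.conv(x)).flatten(1)  # (batch, 2)


model = TinyConvNet().eval()
inp = torch.randn(2, 1, 8, 8)

occlusion = Occlusion(model)
attr = occlusion.attribute(
    inp,
    sliding_window_shapes=(1, 4, 4),  # window over (channel, height, width)
    strides=(1, 2, 2),
    baselines=0.0,
    target=0,
)
print(attr.shape)  # same shape as inp
```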
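
Finally, the small-batch behavior pinned down by the updated tests: a feature group whose tensors have fewer than `_min_examples_per_batch_grouped` (2 by default) examples is skipped and attributed zero, and lowering that guard to 1 surfaces an `AssertionError` from the permutation function. Restated outside the test harness, using the same toy forward function as the test:

```python
import torch
from captum.attr import FeaturePermutation

fp = FeaturePermutation(lambda x: x.sum(dim=-1))
single_example = torch.tensor([[1.0, 2.0]])

# A batch of one: the group is skipped and the attribution stays zero.
assert fp.attribute(single_example).abs().sum().item() == 0.0

# Lowering the guard makes the permutation run on a single example,
# which the default permutation function rejects.
fp._min_examples_per_batch_grouped = 1
try:
    fp.attribute(single_example)
except AssertionError:
    print("cannot permute a batch with a single example")
```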