74 changes: 5 additions & 69 deletions captum/attr/_core/feature_permutation.py
@@ -55,9 +55,9 @@ class FeaturePermutation(FeatureAblation):
of examples to compute attributions and cannot be performed on a single example.

By default, each scalar value within
each input tensor is taken as a feature and shuffled independently, *unless*
attribute() is called with enable_cross_tensor_attribution=True. Passing
a feature mask, allows grouping features to be shuffled together.
each input tensor is taken as a feature and shuffled independently. Passing
a feature mask allows grouping features to be shuffled together (including
features defined across different input tensors).
Each input scalar in the group will be given the same attribution value
equal to the change in target as a result of shuffling the entire feature
group.
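A minimal usage sketch of the default behavior described above (the toy `forward_func` and shapes are illustrative assumptions, not taken from the patch):

```python
import torch
from captum.attr import FeaturePermutation

def forward_func(x):
    # Toy model: weighted sum over the three features of each example.
    return (x * torch.tensor([1.0, 2.0, 3.0])).sum(dim=1)

perm = FeaturePermutation(forward_func)
inputs = torch.randn(8, 3)     # a batch is needed: values are shuffled across examples
attr = perm.attribute(inputs)  # default mask: every scalar is its own feature
print(attr.shape)              # torch.Size([8, 3])
```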
@@ -92,12 +92,6 @@ def __init__(
"""
FeatureAblation.__init__(self, forward_func=forward_func)
self.perm_func = perm_func
# Minimum number of elements needed in each input tensor, when
# `enable_cross_tensor_attribution` is False, otherwise the
# attribution for the tensor will be skipped. Set to 1 to throw if any
# input tensors only have one example
self._min_examples_per_batch = 2
# Similar to above, when `enable_cross_tensor_attribution` is True.
# Considering the case when we permute multiple input tensors at once
# through `feature_mask`, we disregard the feature group if the 0th
# dim of *any* input tensor in the group is less than
@@ -115,7 +109,6 @@ def attribute( # type: ignore
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
enable_cross_tensor_attribution: bool = True,
**kwargs: Any,
) -> TensorOrTupleOfTensorsGeneric:
r"""
@@ -187,18 +180,12 @@ def attribute( # type: ignore
input tensor. Each tensor should contain integers in
the range 0 to num_features - 1, and indices
corresponding to the same feature should have the
same value. Note that features within each input
tensor are ablated independently (not across
tensors), unless enable_cross_tensor_attribution is
True.

same value.
The first dimension of each mask must be 1, as we require
the same grouping of features for each input sample.

If None, then a feature mask is constructed which assigns
each scalar within a tensor as a separate feature, which
is permuted independently, unless
enable_cross_tensor_attribution is True.
each scalar within a tensor as a separate feature.
Default: None
perturbations_per_eval (int, optional): Allows permutations
of multiple features to be processed simultaneously
Expand All @@ -217,10 +204,6 @@ def attribute( # type: ignore
(e.g. time estimation). Otherwise, it will fallback to
a simple output of progress.
Default: False
enable_cross_tensor_attribution (bool, optional): If True, then
features can be grouped across input tensors depending on
the values in the feature mask.
Default: False
**kwargs (Any, optional): Any additional arguments used by child
classes of :class:`.FeatureAblation` (such as
:class:`.Occlusion`) to construct ablations. These
@@ -292,7 +275,6 @@ def attribute( # type: ignore
feature_mask=feature_mask,
perturbations_per_eval=perturbations_per_eval,
show_progress=show_progress,
enable_cross_tensor_attribution=enable_cross_tensor_attribution,
**kwargs,
)
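Since cross-tensor grouping is the behavior described in the updated docstring, here is a hedged sketch of a feature_mask that groups columns across two input tensors (the model, shapes, and group ids are illustrative assumptions):

```python
import torch
from captum.attr import FeaturePermutation

def forward_func(x1, x2):
    # Toy model over two input tensors.
    return x1.sum(dim=1) + x2.sum(dim=1)

perm = FeaturePermutation(forward_func)
x1, x2 = torch.randn(8, 4), torch.randn(8, 2)

# One mask per input, each with a leading dimension of 1 so every example
# shares the same grouping. Group id 0 appears in both masks, so those
# columns of x1 and x2 are shuffled together as a single feature group.
mask1 = torch.tensor([[0, 0, 1, 1]])
mask2 = torch.tensor([[0, 2]])

attr1, attr2 = perm.attribute((x1, x2), feature_mask=(mask1, mask2))
```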

@@ -304,7 +286,6 @@ def attribute_future(
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
perturbations_per_eval: int = 1,
show_progress: bool = False,
enable_cross_tensor_attribution: bool = True,
**kwargs: Any,
) -> Future[TensorOrTupleOfTensorsGeneric]:
"""
@@ -321,54 +302,9 @@ def attribute_future(
feature_mask=feature_mask,
perturbations_per_eval=perturbations_per_eval,
show_progress=show_progress,
enable_cross_tensor_attribution=enable_cross_tensor_attribution,
**kwargs,
)

def _construct_ablated_input(
self,
expanded_input: Tensor,
input_mask: Union[None, Tensor, Tuple[Tensor, ...]],
baseline: Union[None, float, Tensor],
start_feature: int,
end_feature: int,
**kwargs: Any,
) -> Tuple[Tensor, Tensor]:
r"""
This function permutes the features of `expanded_input` with a given
feature mask and feature range. Permutation occurs via calling
`self.perm_func` across each batch within `expanded_input`. As with
`FeatureAblation._construct_ablated_input`:
- `expanded_input.shape = (num_features, num_examples, ...)`
- `num_features = end_feature - start_feature` (i.e. `start_feature` and
`end_feature` define a half-closed interval)
- `input_mask` is a tensor of the same shape as one input, which
describes the locations of each feature via their "index"

Since `baselines` is set to None for `FeatureAblation.attribute`, this
will be the zero tensor; however, it is not used.
"""
assert (
input_mask is not None
and not isinstance(input_mask, tuple)
and input_mask.shape[0] == 1
), (
"input_mask.shape[0] != 1: pass in one mask in order to permute "
"the same features for each input"
)
current_mask = torch.stack(
[input_mask == j for j in range(start_feature, end_feature)], dim=0
).bool()
current_mask = current_mask.to(expanded_input.device)

output = torch.stack(
[
self.perm_func(x, mask.squeeze(0))
for x, mask in zip(expanded_input, current_mask)
]
)
return output, current_mask
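For intuition about the removed helper, a toy stand-in for a permutation function applied under a boolean feature mask (an illustrative sketch, not captum's actual `perm_func`):

```python
import torch

def permute_masked(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Shuffle masked positions across the batch (dim 0); leave the rest unchanged.
    shuffled = x[torch.randperm(x.shape[0])]
    return torch.where(mask, shuffled, x)

x = torch.arange(12.0).reshape(4, 3)       # 4 examples, 3 features
mask = torch.tensor([False, True, False])  # permute only feature index 1
print(permute_masked(x, mask))
```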

def _construct_ablated_input_across_tensors(
self,
inputs: Tuple[Tensor, ...],
66 changes: 0 additions & 66 deletions captum/attr/_core/occlusion.py
@@ -275,57 +275,6 @@ def attribute_future(self) -> None:
"""
raise NotImplementedError("attribute_future is not implemented for Occlusion")

def _construct_ablated_input(
self,
expanded_input: Tensor,
input_mask: Union[None, Tensor, Tuple[Tensor, ...]],
baseline: Union[None, float, Tensor],
start_feature: int,
end_feature: int,
**kwargs: Any,
) -> Tuple[Tensor, Tensor]:
r"""
Ablates given expanded_input tensor with given feature mask, feature range,
and baselines, and any additional arguments.
expanded_input shape is (num_features, num_examples, ...)
with remaining dimensions corresponding to remaining original tensor
dimensions and num_features = end_feature - start_feature.

input_mask is None for occlusion, and the mask is constructed
using sliding_window_tensors, strides, and shift counts, which are provided in
kwargs. baseline is expected to
be broadcastable to match expanded_input.

This method returns the ablated input tensor, which has the same
dimensionality as expanded_input as well as the corresponding mask with
either the same dimensionality as expanded_input or second dimension
being 1. This mask contains 1s in locations which have been ablated (and
thus counted towards ablations for that feature) and 0s otherwise.
"""
input_mask = torch.stack(
[
self._occlusion_mask(
expanded_input,
j,
kwargs["sliding_window_tensors"],
kwargs["strides"],
kwargs["shift_counts"],
is_expanded_input=True,
)
for j in range(start_feature, end_feature)
],
dim=0,
).long()
assert baseline is not None, "baseline should not be None"
ablated_tensor = (
expanded_input
* (
torch.ones(1, dtype=torch.long, device=expanded_input.device)
- input_mask
).to(expanded_input.dtype)
) + (baseline * input_mask.to(expanded_input.dtype))
return ablated_tensor, input_mask
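The ablation rule the removed helper applied is `ablated = input * (1 - mask) + baseline * mask`; a tiny numeric sketch with made-up values:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[0, 1, 1, 0]])  # 1 marks positions covered by the occlusion window
baseline = 0.0

ablated = x * (1 - mask).to(x.dtype) + baseline * mask.to(x.dtype)
print(ablated)  # tensor([[1., 0., 0., 4.]])
```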

def _occlusion_mask(
self,
input: Tensor,
@@ -380,21 +329,6 @@ def _occlusion_mask(
)
return padded_tensor.reshape((1,) + tuple(padded_tensor.shape))

def _get_feature_range_and_mask(
self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any
) -> Tuple[int, int, Union[None, Tensor, Tuple[Tensor, ...]]]:
feature_max = int(np.prod(kwargs["shift_counts"]))
return 0, feature_max, None

def _get_feature_counts(
self,
inputs: TensorOrTupleOfTensorsGeneric,
feature_mask: Tuple[Tensor, ...],
**kwargs: Any,
) -> Tuple[int, ...]:
"""return the numbers of possible input features"""
return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"])
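The count logic in the removed `_get_feature_counts` reduces to a product of per-dimension shift counts; a tiny sketch with made-up values:

```python
import numpy as np

# Per-input tuples of per-dimension shift counts (illustrative values):
# e.g. 3 window positions along height and 4 along width for a single input.
shift_counts = [(3, 4)]
feature_counts = tuple(int(np.prod(counts)) for counts in shift_counts)
print(feature_counts)  # (12,)
```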

def _get_feature_idx_to_tensor_idx(
self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
) -> Dict[int, List[int]]: