diff --git a/captum/attr/_core/rex.py b/captum/attr/_core/rex.py
new file mode 100644
index 0000000000..e67ba9a2b2
--- /dev/null
+++ b/captum/attr/_core/rex.py
@@ -0,0 +1,357 @@
+#!/usr/bin/env python3
+
+# pyre-strict
+import itertools
+import math
+import random
+from collections import deque
+from dataclasses import dataclass
+from typing import cast, List, Sized, Tuple
+
+import torch
+from captum._utils.typing import BaselineType, TensorOrTupleOfTensorsGeneric
+
+from captum.attr._utils.attribution import PerturbationAttribution
+from captum.attr._utils.common import (
+    _format_input_baseline,
+    _format_output,
+    _validate_input,
+)
+from captum.log.dummy_log import log_usage
+
+
+class Partition(Sized):
+    def __init__(
+        self,
+        borders: None | List[Tuple[int, int]] = None,
+        elements: None | torch.Tensor = None,
+        size: int = -1,
+    ):
+        self.borders = borders
+        self.elements = elements
+        self.size = size
+
+        self._mask = None
+
+    def generate_mask(self, shape, device):
+        # generates a boolean mask for this partition (True indicates membership)
+        if self._mask is not None:
+            return self._mask
+        self._mask = torch.zeros(shape, dtype=torch.bool, device=device)
+
+        # non-contiguous case
+        if self.elements is not None:
+            self._mask[tuple(self.elements.T)] = True
+
+        # contiguous case
+        elif self.borders is not None:
+            slices = tuple(slice(lo, hi) for (lo, hi) in self.borders)
+            self._mask[slices] = True
+
+        return self._mask
+
+    def __len__(self):
+        return self.size
+
+
+@dataclass(eq=False)
+class Mutant:
+    partitions: List[Partition]
+    data: torch.Tensor
+
+    # eagerly create the underlying mutant data
+    def __init__(self, partitions: List[Partition], data: torch.Tensor, neutral, shape):
+
+        # A bitmap in the shape of the input indicating membership to
+        # a partition in this mutant
+        mask = torch.zeros(shape, dtype=torch.bool, device=data.device)
+        for part in partitions:
+            mask |= part.generate_mask(mask.shape, data.device)
+
+        self.partitions = partitions
+        self.data = torch.where(mask, data, neutral)
+
+    def __len__(self):
+        return len(self.partitions)
+
+
+def _powerset(s):
+    return (
+        list(combo) for r in range(len(s) + 1) for combo in itertools.combinations(s, r)
+    )
+
+
+def _apply_responsibility(fi, part, responsibility):
+    distributed = responsibility / len(part)
+    mask = part.generate_mask(fi.shape, fi.device)
+
+    return torch.where(mask, distributed, fi)
+
+
+def _calculate_responsibility(
+    subject_partition: Partition,
+    mutants: List[Mutant],
+    consistent_mutants: List[Mutant],
+) -> float:
+    recovery_set = {frozenset(m.partitions) for m in consistent_mutants}
+
+    valid_witnesses = []
+    for m in mutants:
+        if subject_partition in m.partitions:
+            continue
+        W = m.partitions
+
+        W_set = frozenset(W)
+        W_plus_P_set = frozenset([subject_partition] + W)
+
+        # W alone does NOT recover, but W ∪ {P} DOES recover.
+        if (W_set not in recovery_set) and (W_plus_P_set in recovery_set):
+            valid_witnesses.append(W)
+
+    if not valid_witnesses:
+        return 0.0
+
+    k = min(len(w) for w in valid_witnesses)
+    return 1.0 / (1.0 + float(k))
+
+
+def _generate_indices(ts):
+    # return a tensor containing all indices in the input shape
+    return torch.tensor(
+        tuple(itertools.product(*(range(s) for s in ts.shape))),
+        dtype=torch.long,
+        device=ts.device,
+    )
+
+
+class ReX(PerturbationAttribution):
+    """
+    A perturbation-based approach to computing attribution, derived from the
+    Halpern-Pearl definition of Actual Causality[1].
+
+    ReX conducts a recursive search on the input to find the areas that are
+    most responsible[3] for a model's prediction. ReX splits an input into
+    "partitions", and masks combinations of these partitions with baseline
+    (neutral) values to form "mutants".
+
+    Intuitively, where masking a partition never changes a model's
+    prediction, that partition is not responsible for the output. Conversely,
+    where some combination of masked partitions changes the prediction, each
+    partition has some responsibility. Specifically, its responsibility is 1/(1+k),
+    where k is the minimal number of *other* masked partitions required to create
+    a dependence on that partition.
+
+    Responsible partitions are recursively searched to refine responsibility estimates,
+    and results are (optionally) merged to produce the final attribution map[2].
+
+
+    [1] - Actual Causality (modified Halpern-Pearl definition):
+          https://www.cs.cornell.edu/home/halpern/papers/modified-HPdef.pdf
+    [2] - ReX paper: https://arxiv.org/pdf/2411.08875
+    [3] - Responsibility and Blame: https://arxiv.org/pdf/cs/0312038
+    """
+
+    def __init__(self, forward_func):
+        r"""
+        Args:
+            forward_func (Callable): The function to be explained. *Must* return
+                a scalar for which the equality operator is defined.
+        """
+        PerturbationAttribution.__init__(self, forward_func)
+
+    @log_usage(part_of_slo=True)
+    def attribute(
+        self,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: BaselineType = 0,
+        search_depth: int = 10,
+        n_partitions: int = 4,
+        n_searches: int = 5,
+        assume_locality: bool = False,
+    ) -> TensorOrTupleOfTensorsGeneric:
+        r"""
+        Args:
+            inputs (Tensor or tuple[Tensor, ...]): An input or tuple of inputs
+                to be explained. Each input must be of the shape expected by
+                the forward_func. Where multiple examples are provided, they
+                must be listed in a tuple.
+
+            baselines (scalar, Tensor, or tuple, optional): Neutral values used
+                to occlude masked partitions. Where a scalar or tensor is
+                provided, it is broadcast to the input shape. Where a tuple is
+                provided, its elements are paired element-wise with the inputs
+                and must match the structure of the input.
+
+            search_depth (int, optional): The maximum depth to which ReX will refine
+                responsibility estimates for causes.
+
+            n_partitions (int, optional): The maximum number of partitions made out of
+                the input at each search step. At least 1, and no larger than the
+                partition size. Where ``assume_locality`` is False, partitions are
+                created using previous attribution maps as heuristics.
+
+            n_searches (int, optional): The number of times the search is to be run.
+
+            assume_locality (bool, optional): Where True, partitioning is contiguous and
+                attribution maps are merged (averaged) after each search. Otherwise,
+                partitioning is initially random, then uses the previous attribution
+                map as a heuristic for further searches, returning the result of the
+                final search.
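+
+        Returns:
+            *Tensor* or *tuple[Tensor, ...]* of **attributions**:
+                Responsibility estimates with the same shape as the
+                corresponding inputs.
+
+        Examples::
+
+            A minimal usage sketch. ``ImageClassifier`` below is a hypothetical
+            model used purely for illustration; the only requirement ReX places
+            on ``forward_func`` is that it returns a single value for which
+            equality is defined, so the prediction is reduced to a class index
+            here.
+
+            >>> net = ImageClassifier()
+            >>> def predicted_class(x):
+            ...     return net(x.unsqueeze(0)).argmax().item()
+            >>> input = torch.randn(3, 32, 32)
+            >>> rex = ReX(predicted_class)
+            >>> # Occlude with a zero baseline and compute responsibilities.
+            >>> attributions = rex.attribute(input, baselines=0)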
+ """ + + inputs, baselines = _format_input_baseline(inputs, baselines) + _validate_input(inputs, baselines) + + self._n_partitions: int = n_partitions + self._max_depth: int = search_depth + self._n_searches: int = n_searches + self._assume_locality: bool = assume_locality + + is_input_tuple = isinstance(inputs, tuple) + is_baseline_tuple = isinstance(baselines, tuple) + + attributions = [] + + # broadcast baselines, explain + if is_input_tuple and is_baseline_tuple: + for input, baseline in zip(inputs, baselines): + attributions.append(self._explain(input, baseline)) + elif is_input_tuple and not is_baseline_tuple: + for input in inputs: + attributions.append(self._explain(input, baselines)) + else: + attributions.append(self._explain(inputs, baselines)) + + return cast( + TensorOrTupleOfTensorsGeneric, + _format_output(is_input_tuple, tuple(attributions)), + ) + + @torch.no_grad() + def _explain(self, input, baseline) -> torch.Tensor: + self._device = input.device + self._shape = input.shape + self._size = input.numel() + + prediction = self.forward_func(input) + + prev_attribution = torch.full_like( + input, 0.0, dtype=torch.float32, device=self._device + ) + attribution = torch.full_like( + input, 1.0 / self._size, dtype=torch.float32, device=self._device + ) + + initial_partition = Partition( + borders=list((0, top) for top in self._shape), + elements=_generate_indices(input), + size=self._size, + ) + + for i in range(1, self._n_searches + 1): + Q: deque = deque() + Q.append((initial_partition, 0)) + + while Q: + prev_part, depth = Q.popleft() + partitions = ( + self._contiguous_partition(prev_part, depth) + if self._assume_locality + else self._partition(prev_part, attribution) + ) + + mutants = [ + Mutant(part, input, baseline, self._shape) + for part in _powerset(partitions) + ] + consistent_mutants = [ + mut for mut in mutants if self.forward_func(mut.data) == prediction + ] + + for part in partitions: + resp = _calculate_responsibility(part, mutants, consistent_mutants) + attribution = _apply_responsibility(attribution, part, resp) + + if resp == 1 and len(part) > 1 and self._max_depth > depth: + Q.append((part, depth + 1)) + + # take average of responsibility maps + if self._assume_locality: + prev_attribution += (1 / i) * (attribution - prev_attribution) + attribution = prev_attribution + + return attribution.clone().detach() + + def _partition( + self, part: Partition, responsibility: torch.Tensor + ) -> List[Partition]: + # shuffle candidate indices (randomize tiebreakers) + perm = torch.randperm(len(part), device=self._device) + + assert ( + part is not None + ), "Partitioning strategy changed mid search. 
+        population = part.elements[perm]  # type: ignore
+
+        weights = responsibility[tuple(population.T)]
+
+        if torch.sum(weights, dim=None) == 0:
+            weights = torch.ones_like(weights, device=self._device) / len(weights)
+        target_weight = torch.sum(weights) / self._n_partitions
+
+        # sort for greedy selection
+        idx = torch.argsort(weights, descending=False)
+        weight_sorted, pop_sorted = weights[idx], population[idx]
+
+        # cumulative sum of weights / weight per bucket rounded down gives us bucket ids
+        eps = torch.finfo(weight_sorted.dtype).eps
+        c = weight_sorted.cumsum(0) + eps
+        bin_id = torch.div(c, target_weight, rounding_mode="floor").clamp_min(0).long()
+
+        # count elements in each bucket, and split input accordingly
+        _, counts = torch.unique_consecutive(bin_id, return_counts=True)
+        groups = torch.split(pop_sorted, counts.tolist())
+
+        partitions = [Partition(elements=g, size=len(g)) for g in groups]
+        return partitions
+
+    def _contiguous_partition(self, part, depth):
+        ndim = len(self._shape)
+        split_dim = -1
+
+        # find a dimension we can split
+        dmin, dmax = max(self._shape), 0
+        for i in range(ndim):
+            candidate_dim = (i + depth) % ndim
+            dmin, dmax = tuple(part.borders[candidate_dim])
+
+            if dmax - dmin > 1:
+                split_dim = candidate_dim
+                break
+
+        if split_dim == -1:
+            return [part]
+        n_splits = min((dmax - dmin), self._n_partitions) - 1
+
+        # drop splits randomly
+        split_points = random.sample(range(dmin + 1, dmax), n_splits)
+        split_borders = sorted(set([dmin, *split_points, dmax]))
+
+        bins = []
+        for i in range(len(split_borders) - 1):
+            new_borders = list(part.borders)
+            new_borders[split_dim] = (split_borders[i], split_borders[i + 1])
+
+            bins.append(
+                Partition(
+                    borders=tuple(new_borders),
+                    size=math.prod(hi - lo for (lo, hi) in new_borders),
+                )
+            )
+
+        return bins
+
+    def multiplies_by_inputs(self):
+        return False
+
+    def has_convergence_delta(self):
+        return True
diff --git a/docs/algorithms_comparison_matrix.md b/docs/algorithms_comparison_matrix.md
index e74128ee96..21d27514c4 100644
--- a/docs/algorithms_comparison_matrix.md
+++ b/docs/algorithms_comparison_matrix.md
@@ -207,7 +207,16 @@ Please, scroll to the right for more details.
 Depends on the choice of above mentioned attribution algorithm.
 Depends on the choice of above mentioned attribution algorithm.
 | Adds gaussian noise to each input example #samples times, calls any above mentioned attribution algorithm for all #samples per example and aggregates / smoothens them based on different techniques for each input example. Supported smoothing techniques include: smoothgrad, vargrad, smoothgrad_sq.
-
+
+ ReX
+ Perturbation
+ Any function returning a single value
+ O(#partitions ^ #max_depth) - user-defined values
+ Any function returning a single value
+ O(#iterations * (#partitions ^ #max_depth))
+ Yes (strong assumption regarding neutral baseline)
+ Perturbation-based approach built on a recursive search over the input. By recursively occluding partitions of an input, ReX searches for partitions whose values have predictive value w.r.t. the output.
+
 **^ Including Layer Variant**
diff --git a/docs/attribution_algorithms.md b/docs/attribution_algorithms.md
index f1d00a8f53..f2b701b359 100644
--- a/docs/attribution_algorithms.md
+++ b/docs/attribution_algorithms.md
@@ -134,6 +134,15 @@ Kernel SHAP is a method that uses the LIME framework to compute Shapley Values.
 To learn more about KernelSHAP, visit the following resources:
 - [Original paper](https://arxiv.org/abs/1705.07874)
 
+### ReX
+ReX is a perturbation-based explainability approach, grounded in the theory of Actual Causality[1]. It works by partitioning the input and occluding all combinations of partitions using a neutral masking value. Where masking some combination of partitions changes the output of the model, those partitions are recursively re-partitioned to search for ever-smaller parts of the input that are responsible for the final output.
+
+To learn more about actual causality, responsibility and ReX:
+- [Actual Causality](https://www.cs.cornell.edu/home/halpern/papers/causalitybook-ch1-3.html)
+- [Responsibility and Blame](https://arxiv.org/pdf/cs/0312038)
+- [ReX Original Paper (called DC-Causal here)](https://www.hanachockler.com/iccv2021/)
+
+
 ## Layer Attribution
 ### Layer Conductance
 Conductance combines the neuron activation with the partial derivatives of both the neuron with respect to the input and the output with respect to the neuron to build a more complete picture of neuron importance.
diff --git a/sphinx/source/attribution.rst b/sphinx/source/attribution.rst
index ace52dd9a8..922dbaa285 100644
--- a/sphinx/source/attribution.rst
+++ b/sphinx/source/attribution.rst
@@ -18,3 +18,4 @@ Attribution
     lime
     kernel_shap
     lrp
+    rex
\ No newline at end of file
diff --git a/sphinx/source/rex.rst b/sphinx/source/rex.rst
new file mode 100644
index 0000000000..daa69933f5
--- /dev/null
+++ b/sphinx/source/rex.rst
@@ -0,0 +1,6 @@
+ReX
+===
+
+.. autoclass:: captum.attr._core.rex.ReX
+    :members:
+    :show-inheritance:
diff --git a/tests/attr/test_rex.py b/tests/attr/test_rex.py
new file mode 100644
index 0000000000..b6fc82f9c6
--- /dev/null
+++ b/tests/attr/test_rex.py
@@ -0,0 +1,222 @@
+import itertools
+import math
+import random
+import statistics
+
+import matplotlib.pyplot as plt
+import torch
+import torch.nn.functional as F
+from captum.attr._core.rex import ReX
+
+from captum.testing.helpers.basic import BaseTest
+from parameterized import parameterized
+
+
+def visualize_tensor(tensor, cmap="viridis"):
+    arr = tensor.detach().cpu().numpy()
+    plt.imshow(arr, cmap=cmap)
+    plt.colorbar()
+    plt.show()
+
+
+class Test(BaseTest):
+    # alias for convenience
+    ts = torch.tensor
+
+    depth_opts = range(4, 10)
+    n_partition_opts = range(4, 7)
+    n_search_opts = range(10, 15)
+    assume_locality_opts = [True, False]
+
+    all_options = list(
+        itertools.product(
+            depth_opts, n_partition_opts, n_search_opts, assume_locality_opts
+        )
+    )
+
+    def _generate_gaussian_pdf(self, shape, mean):
+        k = len(shape)
+
+        cov = 0.1 * torch.eye(k) * statistics.mean(shape)
+        dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
+
+        grids = torch.meshgrid(
+            *[torch.arange(n, dtype=torch.float64) for n in shape], indexing="ij"
+        )
+        coords = torch.stack(grids, dim=-1).reshape(-1, k)
+
+        pdf_vals = torch.exp(dist.log_prob(coords))
+        return pdf_vals.reshape(*shape)
+
+    @parameterized.expand(
+        [
+            # inputs: baselines:
+            (ts([1, 2, 3]), ts([[2, 3], [3, 4]])),
+            ((ts([1]), ts([2]), ts([3])), (ts([1]), ts([2]))),
+            ((ts([1])), ()),
+            ((), ts([1])),
+        ]
+    )
+    def test_input_baseline_mismatch_throws(self, input, baseline):
+        rex = ReX(lambda x: 1 / 0)  # dummy forward, should be unreachable
+        with self.assertRaises(AssertionError):
+            rex.attribute(input, baseline)
+
+    @parameterized.expand(
+        [
+            (ts([1, 2, 3]), 0),
+            (ts([[1, 2, 3], [4, 5, 6]]), 0),
+            (ts([1, 2, 3, 4]), ts([0, 0, 0, 0])),
+            (ts([[1, 2], [1, 2]]), ts([[0, 0], [0, 0]])),
+            (ts([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), 0),
+            ((ts([1, 2]), ts([3, 4]), ts([5, 6])), (0, 0, 0)),
+            (
+                (ts([1, 2]), ts([3, 4]), ts([5, 6])),
+                (ts([0, 0]), ts([0, 0]), ts([0, 0])),
+            ),
+            ((ts([1, 2]), ts([3, 4])), (ts([0, 0]), ts([0, 0]))),
+        ]
+    )
+    def test_valid_input_baseline(self, input, baseline):
+        for o in self.all_options:
+            rex = ReX(lambda x: True)
+
+            attributions = rex.attribute(input, baseline, *o)[0]
+
+            inp_unwrapped = input
+            if isinstance(input, tuple):
+                inp_unwrapped = input[0]
+
+            # Forward_func returns a constant, no responsibility in input
+            self.assertEqual(torch.sum(attributions), 0)
+            self.assertEqual(attributions.size(), inp_unwrapped.size())
+
+    @parameterized.expand(
+        [
+            # input                                      # selected_idx
+            (ts([1, 2, 3]), 0),
+            (ts([[1, 2], [3, 4]]), (0, 1)),
+            (ts([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), (0, 1, 0)),
+        ]
+    )
+    def test_selector_function(self, input, idx):
+        for o in self.all_options:
+            rex = ReX(lambda x: x[idx])
+
+            attributions = rex.attribute(input, 0, *o)[0]
+            self.assertEqual(
+                attributions[idx], 1, f"expected 1 at {idx} but found {attributions}"
+            )
+
+            attributions[idx] = 0
+            self.assertEqual(torch.sum(attributions), 0)
+
+    @parameterized.expand(
+        [
+            # input shape           # important idx
+            ((4, 4), (0, 0)),
+            ((12, 12, 12), (1, 2, 1)),
+            ((12, 12, 12, 6), (1, 1, 4, 1)),
+            ((1920, 1080), (1, 1)),  # image-like
+        ]
+    )
+    def test_selector_function_large_input(self, input_shape, idx):
+        rex = ReX(lambda x: x[idx])
+
+        input = torch.ones(*input_shape)
+        attributions = rex.attribute(
+            input, 0, n_partitions=2, search_depth=10, n_searches=3
+        )[0]
+        self.assertGreater(attributions[idx], 0)
+        attributions[idx] = 0
+        self.assertLess(torch.sum(attributions), 1)
+
+    @parameterized.expand(
+        [
+            # input shape   # lhs_idx   # rhs_idx
+            ((2, 4), (0, 2), (1, 3))
+        ]
+    )
+    def test_boolean_or(self, input_shape, lhs_idx, rhs_idx):
+        for o in self.all_options:
+            rex = ReX(lambda x: max(x[lhs_idx], x[rhs_idx]))
+            input = torch.ones(input_shape)
+
+            attributions = rex.attribute(input, 0, *o)[0]
+
+            self.assertEqual(attributions[lhs_idx], 1.0, f"{attributions}")
+            self.assertEqual(attributions[rhs_idx], 1.0, f"{attributions}")
+
+            attributions[lhs_idx] = 0
+            attributions[rhs_idx] = 0
+            self.assertLess(torch.sum(attributions), 1, f"{attributions}")
+
+    @parameterized.expand(
+        [
+            # input shape   # lhs_idx   # rhs_idx
+            ((2, 4), (0, 2), (0, 3))
+        ]
+    )
+    def test_boolean_and(self, input_shape, lhs_idx, rhs_idx):
+        for i, o in enumerate(self.all_options):
+            rex = ReX(lambda x: min(x[lhs_idx], x[rhs_idx]))
+            input = torch.ones(input_shape)
+
+            attributions = rex.attribute(input, 0, *o)[0]
+
+            self.assertEqual(attributions[lhs_idx], 0.5, f"{attributions}, {i}, {o}")
+            self.assertEqual(attributions[rhs_idx], 0.5, f"{attributions}, {i}, {o}")
+
+            attributions[lhs_idx] = 0
+            attributions[rhs_idx] = 0
+            self.assertLess(torch.sum(attributions), 1, f"{attributions}")
+
+    @parameterized.expand(
+        [
+            # shape
+            ((30, 30),),
+            ((50, 50),),
+            ((100, 100),),
+        ]
+    )
+    def test_gaussian_recovery(self, shape):
+        random.seed()
+        eps = 1e-12
+
+        p = torch.zeros(shape)
+        for _ in range(3):
+            center = self.ts([int(random.random() * dim) for dim in shape])
+            p += self._generate_gaussian_pdf(shape, center)
+
+        p += eps
+        p = p / torch.sum(p)
+
+        thresh = math.sqrt(torch.mean(p))
+
+        def _forward(inp):
+            return 1 if torch.sum(inp) > thresh else 0
+
+        rex = ReX(_forward)
+        for b in self.n_partition_opts:
+            attributions = rex.attribute(
+                p,
+                0,
+                n_partitions=b,
+                search_depth=10,
+                n_searches=25,
+                assume_locality=True,
+            )[0]
+
+            attributions += eps
+            attrib_norm = attributions / torch.sum(attributions)
+
+            # visualize_tensor(p)
+            # visualize_tensor(attrib_norm)
+            # visualize_tensor(p - attrib_norm)
+
+            # Jensen-Shannon divergence between the target density and the
+            # normalised attribution map. F.kl_div expects the log-probabilities
+            # of the approximating distribution as its first argument, so the
+            # midpoint goes first in each term.
+            mid = 0.5 * (p + attrib_norm)
+            jsd = 0.5 * F.kl_div(mid.log(), p, reduction="sum") + 0.5 * F.kl_div(
+                mid.log(), attrib_norm, reduction="sum"
+            )
+
+            self.assertLess(jsd, 0.1)
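+
+    # Illustrative usage sketch: a minimal end-to-end run showing the typical
+    # call pattern. The thresholding forward function below is chosen purely
+    # for illustration; the single informative entry should receive positive
+    # responsibility.
+    def test_usage_sketch(self):
+        input = torch.zeros(4, 4)
+        input[1, 2] = 1.0
+
+        def _forward(x):
+            return 1 if torch.sum(x) > 0.5 else 0
+
+        rex = ReX(_forward)
+        attributions = rex.attribute(
+            input, 0, n_partitions=2, search_depth=6, n_searches=2
+        )[0]
+
+        self.assertGreater(attributions[1, 2], 0)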