From f2ab1454b2495901276650e55e6d71d2331c3118 Mon Sep 17 00:00:00 2001
From: Sarah Tran
Date: Mon, 8 Sep 2025 19:55:18 -0700
Subject: [PATCH] Update FeatureAblation to handle precision loss when baseline
 is more granular than input and cross tensor attribution is enabled (#1644)

Summary:
Noticed that this test case failed when flipping the flag:
https://www.internalfb.com/code/fbsource/[faf71541b1ec0fae639f82d487b81fb18ea3e523]/fbcode/pytorch/captum/tests/attr/test_dataloader_attr.py?lines=138%2C134

The ablated tensor was `tensor([0])` instead of `tensor([0.1])`, since the
baseline was a float and the input tensors were int tensors.
https://www.internalfb.com/code/fbsource/[f2fcc926a6f3669602bac4d28c2d92e4197c96b9]/fbcode/pytorch/captum/captum/attr/_core/feature_ablation.py?lines=707-709

`ablated_input` is just a copy of the `input_tensor`, so during assignment the
ablated feature tensor incorrectly gets cast to an int tensor in this case.

Differential Revision: D81980219
---
 captum/attr/_core/feature_ablation.py |  6 +++--
 tests/attr/test_feature_ablation.py   | 36 +++++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index 06190af6e..a34537a44 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -704,9 +704,11 @@ def _construct_ablated_input_across_tensors(
                 tensor_mask.append(mask)
 
                 assert baseline is not None, "baseline must be provided"
-                ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * (
-                    1 - mask
+                ablated_feature = input_tensor[start_idx:end_idx] * (1 - mask).to(
+                    input_tensor.dtype
                 ) + (baseline * mask.to(input_tensor.dtype))
+                ablated_input = ablated_input.to(ablated_feature.dtype)
+                ablated_input[start_idx:end_idx] = ablated_feature
             current_masks.append(torch.stack(tensor_mask, dim=0))
         ablated_inputs.append(ablated_input)
 
diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py
index 7ce411f74..b47af6264 100644
--- a/tests/attr/test_feature_ablation.py
+++ b/tests/attr/test_feature_ablation.py
@@ -220,6 +220,42 @@ def test_multi_input_ablation_with_mask(self) -> None:
             perturbations_per_eval=(1, 2, 3),
         )
 
+    def test_multi_input_ablation_with_int_input_tensor_and_float_baseline(
+        self,
+    ) -> None:
+        def sum_forward(*inps: torch.Tensor) -> torch.Tensor:
+            flattened = [torch.flatten(inp, start_dim=1) for inp in inps]
+            return torch.cat(flattened, dim=1).sum(1)
+
+        ablation_algo = FeatureAblation(sum_forward)
+        inp1 = torch.tensor([[0, 1], [3, 4]])
+        inp2 = torch.tensor(
+            [
+                [[0.1, 0.2], [0.3, 0.2]],
+                [[0.4, 0.5], [0.3, 0.2]],
+            ]
+        )
+        inp3 = torch.tensor([[0], [1]])
+
+        expected = (
+            torch.tensor([[-0.2, 0.8], [2.8, 3.8]]),
+            torch.tensor(
+                [
+                    [[-3.0, -2.9], [-2.8, -2.9]],
+                    [[-2.7, -2.6], [-2.8, -2.9]],
+                ]
+            ),
+            torch.tensor([[-0.4], [0.6]]),
+        )
+        self._ablation_test_assert(
+            ablation_algo,
+            (inp1, inp2, inp3),
+            expected,
+            target=None,
+            baselines=(0.2, 3.1, 0.4),
+            test_enable_cross_tensor_attribution=[False, True],
+        )
+
     def test_multi_input_ablation_with_mask_weighted(self) -> None:
         ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput())
         ablation_algo.use_weights = True
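
For context on the root cause, here is a minimal standalone PyTorch sketch of the
casting behavior described in the summary. It is not Captum code; the tensor
values, mask, and slice bounds below are made up for illustration. Assigning float
values into an int tensor truncates them, which is why the patch computes the
ablated slice first and promotes the destination tensor to the result dtype
before the in-place assignment.

import torch

# Illustrative values only -- not the actual Captum inputs.
input_tensor = torch.tensor([[0], [1]])  # int64 input
mask = torch.tensor([[1], [0]])          # 1 marks the feature being ablated
baseline = 0.1                           # float baseline, finer-grained than the int input

# Pre-fix behavior: ablated_input is an int copy of input_tensor, so the
# in-place slice assignment silently casts the float result back to int.
ablated_input = input_tensor.clone()
ablated_input[0:1] = input_tensor[0:1] * (1 - mask[0:1]) + baseline * mask[0:1].to(
    input_tensor.dtype
)
print(ablated_input[0])  # tensor([0]) -- the 0.1 baseline is truncated

# Post-fix behavior: compute the ablated slice first, promote the destination
# tensor to the result dtype, then assign.
ablated_feature = input_tensor[0:1] * (1 - mask[0:1]).to(input_tensor.dtype) + (
    baseline * mask[0:1].to(input_tensor.dtype)
)
ablated_input = input_tensor.clone().to(ablated_feature.dtype)
ablated_input[0:1] = ablated_feature
print(ablated_input[0])  # tensor([0.1000]) -- precision preserved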