diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index 06190af6e..a34537a44 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -704,9 +704,11 @@ def _construct_ablated_input_across_tensors( tensor_mask.append(mask) assert baseline is not None, "baseline must be provided" - ablated_input[start_idx:end_idx] = input_tensor[start_idx:end_idx] * ( - 1 - mask + ablated_feature = input_tensor[start_idx:end_idx] * (1 - mask).to( + input_tensor.dtype ) + (baseline * mask.to(input_tensor.dtype)) + ablated_input = ablated_input.to(ablated_feature.dtype) + ablated_input[start_idx:end_idx] = ablated_feature current_masks.append(torch.stack(tensor_mask, dim=0)) ablated_inputs.append(ablated_input) diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py index 7ce411f74..b47af6264 100644 --- a/tests/attr/test_feature_ablation.py +++ b/tests/attr/test_feature_ablation.py @@ -220,6 +220,42 @@ def test_multi_input_ablation_with_mask(self) -> None: perturbations_per_eval=(1, 2, 3), ) + def test_multi_input_ablation_with_int_input_tensor_and_float_baseline( + self, + ) -> None: + def sum_forward(*inps: torch.Tensor) -> torch.Tensor: + flattened = [torch.flatten(inp, start_dim=1) for inp in inps] + return torch.cat(flattened, dim=1).sum(1) + + ablation_algo = FeatureAblation(sum_forward) + inp1 = torch.tensor([[0, 1], [3, 4]]) + inp2 = torch.tensor( + [ + [[0.1, 0.2], [0.3, 0.2]], + [[0.4, 0.5], [0.3, 0.2]], + ] + ) + inp3 = torch.tensor([[0], [1]]) + + expected = ( + torch.tensor([[-0.2, 0.8], [2.8, 3.8]]), + torch.tensor( + [ + [[-3.0, -2.9], [-2.8, -2.9]], + [[-2.7, -2.6], [-2.8, -2.9]], + ] + ), + torch.tensor([[-0.4], [0.6]]), + ) + self._ablation_test_assert( + ablation_algo, + (inp1, inp2, inp3), + expected, + target=None, + baselines=(0.2, 3.1, 0.4), + test_enable_cross_tensor_attribution=[False, True], + ) + def test_multi_input_ablation_with_mask_weighted(self) -> None: ablation_algo = FeatureAblation(BasicModel_MultiLayer_MultiInput()) ablation_algo.use_weights = True