Commit d5eac9e

sarahtranfb authored and facebook-github-bot committed
Clean up unused function in Occlusion (#1650)
Summary: No longer needed now that enable_cross_tensor_attribution=False is no longer supported in the FeatureAblation parent class.

Differential Revision: D83109001
1 parent 45e2051 commit d5eac9e
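
For context, this is an internal cleanup: the public Occlusion API is unchanged. A typical call still looks like the sketch below (a minimal, self-contained example; the toy model, shapes, and parameter values are illustrative, not taken from this commit):

    import torch
    from captum.attr import Occlusion

    # Toy model standing in for any forward function that returns per-class scores.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3, padding=1),
        torch.nn.AdaptiveAvgPool2d(1),
        torch.nn.Flatten(),
    )

    occlusion = Occlusion(model)
    inputs = torch.randn(1, 3, 32, 32)

    # Slide a 3x3 window spanning all 3 channels, stride 2 spatially, zero baseline.
    attributions = occlusion.attribute(
        inputs,
        sliding_window_shapes=(3, 3, 3),
        strides=(3, 2, 2),
        baselines=0,
        target=0,
    )
    print(attributions.shape)  # torch.Size([1, 3, 32, 32])
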

File tree

1 file changed: +0 −66 lines


captum/attr/_core/occlusion.py

Lines changed: 0 additions & 66 deletions
@@ -275,57 +275,6 @@ def attribute_future(self) -> None:
         """
         raise NotImplementedError("attribute_future is not implemented for Occlusion")
 
-    def _construct_ablated_input(
-        self,
-        expanded_input: Tensor,
-        input_mask: Union[None, Tensor, Tuple[Tensor, ...]],
-        baseline: Union[None, float, Tensor],
-        start_feature: int,
-        end_feature: int,
-        **kwargs: Any,
-    ) -> Tuple[Tensor, Tensor]:
-        r"""
-        Ablates given expanded_input tensor with given feature mask, feature range,
-        and baselines, and any additional arguments.
-        expanded_input shape is (num_features, num_examples, ...)
-        with remaining dimensions corresponding to remaining original tensor
-        dimensions and num_features = end_feature - start_feature.
-
-        input_mask is None for occlusion, and the mask is constructed
-        using sliding_window_tensors, strides, and shift counts, which are provided in
-        kwargs. baseline is expected to
-        be broadcastable to match expanded_input.
-
-        This method returns the ablated input tensor, which has the same
-        dimensionality as expanded_input as well as the corresponding mask with
-        either the same dimensionality as expanded_input or second dimension
-        being 1. This mask contains 1s in locations which have been ablated (and
-        thus counted towards ablations for that feature) and 0s otherwise.
-        """
-        input_mask = torch.stack(
-            [
-                self._occlusion_mask(
-                    expanded_input,
-                    j,
-                    kwargs["sliding_window_tensors"],
-                    kwargs["strides"],
-                    kwargs["shift_counts"],
-                    is_expanded_input=True,
-                )
-                for j in range(start_feature, end_feature)
-            ],
-            dim=0,
-        ).long()
-        assert baseline is not None, "baseline should not be None"
-        ablated_tensor = (
-            expanded_input
-            * (
-                torch.ones(1, dtype=torch.long, device=expanded_input.device)
-                - input_mask
-            ).to(expanded_input.dtype)
-        ) + (baseline * input_mask.to(expanded_input.dtype))
-        return ablated_tensor, input_mask
-
     def _occlusion_mask(
         self,
         input: Tensor,
@@ -380,21 +329,6 @@ def _occlusion_mask(
         )
         return padded_tensor.reshape((1,) + tuple(padded_tensor.shape))
 
-    def _get_feature_range_and_mask(
-        self, input: Tensor, input_mask: Optional[Tensor], **kwargs: Any
-    ) -> Tuple[int, int, Union[None, Tensor, Tuple[Tensor, ...]]]:
-        feature_max = int(np.prod(kwargs["shift_counts"]))
-        return 0, feature_max, None
-
-    def _get_feature_counts(
-        self,
-        inputs: TensorOrTupleOfTensorsGeneric,
-        feature_mask: Tuple[Tensor, ...],
-        **kwargs: Any,
-    ) -> Tuple[int, ...]:
-        """return the numbers of possible input features"""
-        return tuple(np.prod(counts).astype(int) for counts in kwargs["shift_counts"])
-
     def _get_feature_idx_to_tensor_idx(
         self, formatted_feature_mask: Tuple[Tensor, ...], **kwargs: Any
     ) -> Dict[int, List[int]]:

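The docstring of the deleted _construct_ablated_input describes the core blending step: positions covered by the occlusion mask are replaced by the baseline, and all other positions keep their input values. A minimal standalone sketch of that step (the function name and shapes here are illustrative, not Captum internals):

    import torch
    from torch import Tensor

    def ablate_with_mask(expanded_input: Tensor, input_mask: Tensor, baseline: float) -> Tensor:
        # Mask entries of 1 select the baseline; entries of 0 keep the input,
        # i.e. input * (1 - mask) + baseline * mask.
        keep = (1 - input_mask).to(expanded_input.dtype)
        return expanded_input * keep + baseline * input_mask.to(expanded_input.dtype)

    # Occlude the middle element of a 1-D input with a zero baseline.
    x = torch.tensor([[1.0, 2.0, 3.0]])
    mask = torch.tensor([[0, 1, 0]])
    print(ablate_with_mask(x, mask, baseline=0.0))  # tensor([[1., 0., 3.]])
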
0 commit comments
