@@ -275,57 +275,6 @@ def attribute_future(self) -> None:
275
275
"""
276
276
raise NotImplementedError ("attribute_future is not implemented for Occlusion" )
277
277
278
- def _construct_ablated_input (
279
- self ,
280
- expanded_input : Tensor ,
281
- input_mask : Union [None , Tensor , Tuple [Tensor , ...]],
282
- baseline : Union [None , float , Tensor ],
283
- start_feature : int ,
284
- end_feature : int ,
285
- ** kwargs : Any ,
286
- ) -> Tuple [Tensor , Tensor ]:
287
- r"""
288
- Ablates given expanded_input tensor with given feature mask, feature range,
289
- and baselines, and any additional arguments.
290
- expanded_input shape is (num_features, num_examples, ...)
291
- with remaining dimensions corresponding to remaining original tensor
292
- dimensions and num_features = end_feature - start_feature.
293
-
294
- input_mask is None for occlusion, and the mask is constructed
295
- using sliding_window_tensors, strides, and shift counts, which are provided in
296
- kwargs. baseline is expected to
297
- be broadcastable to match expanded_input.
298
-
299
- This method returns the ablated input tensor, which has the same
300
- dimensionality as expanded_input as well as the corresponding mask with
301
- either the same dimensionality as expanded_input or second dimension
302
- being 1. This mask contains 1s in locations which have been ablated (and
303
- thus counted towards ablations for that feature) and 0s otherwise.
304
- """
305
- input_mask = torch .stack (
306
- [
307
- self ._occlusion_mask (
308
- expanded_input ,
309
- j ,
310
- kwargs ["sliding_window_tensors" ],
311
- kwargs ["strides" ],
312
- kwargs ["shift_counts" ],
313
- is_expanded_input = True ,
314
- )
315
- for j in range (start_feature , end_feature )
316
- ],
317
- dim = 0 ,
318
- ).long ()
319
- assert baseline is not None , "baseline should not be None"
320
- ablated_tensor = (
321
- expanded_input
322
- * (
323
- torch .ones (1 , dtype = torch .long , device = expanded_input .device )
324
- - input_mask
325
- ).to (expanded_input .dtype )
326
- ) + (baseline * input_mask .to (expanded_input .dtype ))
327
- return ablated_tensor , input_mask
328
-
329
278
def _occlusion_mask (
330
279
self ,
331
280
input : Tensor ,
@@ -380,21 +329,6 @@ def _occlusion_mask(
380
329
)
381
330
return padded_tensor .reshape ((1 ,) + tuple (padded_tensor .shape ))
382
331
383
- def _get_feature_range_and_mask (
384
- self , input : Tensor , input_mask : Optional [Tensor ], ** kwargs : Any
385
- ) -> Tuple [int , int , Union [None , Tensor , Tuple [Tensor , ...]]]:
386
- feature_max = int (np .prod (kwargs ["shift_counts" ]))
387
- return 0 , feature_max , None
388
-
389
- def _get_feature_counts (
390
- self ,
391
- inputs : TensorOrTupleOfTensorsGeneric ,
392
- feature_mask : Tuple [Tensor , ...],
393
- ** kwargs : Any ,
394
- ) -> Tuple [int , ...]:
395
- """return the numbers of possible input features"""
396
- return tuple (np .prod (counts ).astype (int ) for counts in kwargs ["shift_counts" ])
397
-
398
332
def _get_feature_idx_to_tensor_idx (
399
333
self , formatted_feature_mask : Tuple [Tensor , ...], ** kwargs : Any
400
334
) -> Dict [int , List [int ]]:
0 commit comments