Skip to content

Commit d71aefc

Browse files
jjunchofacebook-github-bot
authored andcommitted
Optimization in SVS when formatting total_attrib (#1648)
Summary: Pull Request resolved: #1648 This diff focuses on optimizing the Shapley Value Sampling by formatting a default dictionary to list all of the feature indices with a curr_mask that is non-zero, and only updates the total_attrib for values which should be updated rather than iterating through the entire feature space which will mostly be unaffected. Reviewed By: vivekmig Differential Revision: D81800703 fbshipit-source-id: 53b1a5f0589b4bf6392139d4772e5ec68887e4ae
1 parent 38230a7 commit d71aefc

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

captum/attr/_core/shapley_value.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
List,
1616
Optional,
1717
Sequence,
18+
Set,
1819
Tuple,
1920
Union,
2021
)
@@ -382,6 +383,7 @@ def attribute(
382383
current_add_args,
383384
current_target,
384385
current_masks,
386+
current_feat_list,
385387
) in self._perturbation_generator(
386388
inputs_tuple,
387389
additional_forward_args,
@@ -423,8 +425,7 @@ def attribute(
423425
all_eval[num_examples:] - all_eval[:-num_examples]
424426
).to(inputs_tuple[0].device)
425427
prev_results = all_eval[-num_examples:]
426-
427-
for j in range(len(total_attrib)):
428+
for j in current_feat_list:
428429
# format eval_diff to shape
429430
# (n_perturb, *output_shape, 1,.. 1)
430431
# where n_perturb may not be perturb_per_eval
@@ -545,6 +546,7 @@ def attribute_future(
545546
current_add_args,
546547
current_target,
547548
current_masks,
549+
_,
548550
) in self._perturbation_generator(
549551
inputs_tuple,
550552
additional_forward_args,
@@ -827,7 +829,9 @@ def _perturbation_generator(
827829
input_masks: Tuple[Tensor, ...],
828830
feature_permutation: Sequence[int],
829831
perturbations_per_eval: int,
830-
) -> Iterable[Tuple[Tuple[Tensor, ...], object, TargetType, Tuple[Tensor, ...]]]:
832+
) -> Iterable[
833+
Tuple[Tuple[Tensor, ...], object, TargetType, Tuple[Tensor, ...], Set[int]]
834+
]:
831835
"""
832836
This method is a generator which yields each perturbation to be evaluated
833837
including inputs, additional_forward_args, targets, and mask.
@@ -837,6 +841,7 @@ def _perturbation_generator(
837841
current_tensors = baselines
838842
current_tensors_list = []
839843
current_mask_list = []
844+
current_feat_list = set()
840845

841846
# Compute repeated additional args and targets
842847
additional_args_repeated = (
@@ -869,6 +874,7 @@ def _perturbation_generator(
869874
device=inputs[0].device,
870875
)
871876
)
877+
current_feat_list.update(feat_tensor_index_map[feature_permutation[i]])
872878

873879
if len(current_tensors_list) == perturbations_per_eval:
874880
if len(current_tensors_list) > 1:
@@ -888,9 +894,11 @@ def _perturbation_generator(
888894
additional_args_repeated,
889895
target_repeated,
890896
combined_masks,
897+
current_feat_list,
891898
)
892899
current_tensors_list = []
893900
current_mask_list = []
901+
current_feat_list = set()
894902

895903
# Create batch with remaining evaluations, may not be a complete batch
896904
# (= perturbations_per_eval)
@@ -916,6 +924,7 @@ def _perturbation_generator(
916924
additional_args_repeated,
917925
target_repeated,
918926
combined_masks,
927+
current_feat_list,
919928
)
920929

921930
def _get_n_evaluations(

0 commit comments

Comments
 (0)