@@ -549,9 +549,20 @@ class NeuronGradientMethodBatched(NeuronGradientMethod):
549549 :py :class :`NeuronGradientMethod`
550550 """
551551
552- def compute_attribution_map(self,
553- attribute_to_neuron_input=False,
554- batch_size=1024):
552+ def compute_attribution_map(
553+ self,
554+ attribute_to_neuron_input: bool = False,
555+ batch_size: int = 1024
556+ ) -> dict:
557+ """ Compute attribution map using mini - batches .
558+
559+ Args :
560+ attribute_to_neuron_input : If True , attribute to neuron input
561+ batch_size : Size of mini - batches for processing
562+
563+ Returns :
564+ Dictionary containing attribution maps
565+ """
555566 input_data_batches = torch.split(self.input_data, batch_size)
556567
557568 attribution_map = []
@@ -571,9 +582,7 @@ def compute_attribution_map(self,
571582
572583 attribution_map = np.vstack(attribution_map)
573584 return self._reduce_attribution_map(
574- {'neuron-gradient': attribution_map,
575- #'neuron-gradient-invsvd': self._inverse_svd(attribution_map)
576- })
585+ {'neuron-gradient': attribution_map})
577586
578587
579588@dataclasses.dataclass
@@ -601,7 +610,7 @@ def __post_init__(self):
601610 """
602611 super().__post_init__()
603612 self.captum_model = NeuronFeatureAblation(forward_func=self.model,
604- layer=self.model)
613+ layer=self.model)
605614
606615 def compute_attribution_map(self,
607616 baselines=None,
@@ -638,20 +647,19 @@ def compute_attribution_map(self,
638647
639648@dataclasses.dataclass
640649@register("feature-ablation-batched")
641- class FeatureAblationMethodBAtched (FeatureAblationMethod):
650+ class FeatureAblationMethodBatched (FeatureAblationMethod):
642651 """ As :py :class :`FeatureAblationMethod` , but using mini - batches .
643652
644653 See also :
645654 :py :class :`FeatureAblationMethod`
646655 """
647656
648657 def compute_attribution_map(self,
649- baselines=None,
650- feature_mask=None,
651- perturbations_per_eval=1,
652- attribute_to_neuron_input=False,
653- batch_size=1024):
654-
658+ baselines=None,
659+ feature_mask=None,
660+ perturbations_per_eval=1,
661+ attribute_to_neuron_input=False,
662+ batch_size=1024):
655663 input_data_batches = torch.split(self.input_data, batch_size)
656664 attribution_map = []
657665 for input_data_batch in input_data_batches:
0 commit comments