Skip to content

Commit 5f26a24

Browse files
committed
format fix
1 parent ace0a68 commit 5f26a24

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

cebra/attribution/attribution_models.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -549,9 +549,20 @@ class NeuronGradientMethodBatched(NeuronGradientMethod):
549549
:py:class:`NeuronGradientMethod`
550550
"""
551551
552-
def compute_attribution_map(self,
553-
attribute_to_neuron_input=False,
554-
batch_size=1024):
552+
def compute_attribution_map(
553+
self,
554+
attribute_to_neuron_input: bool = False,
555+
batch_size: int = 1024
556+
) -> dict:
557+
"""Compute attribution map using mini-batches.
558+
559+
Args:
560+
attribute_to_neuron_input: If True, attribute to neuron input
561+
batch_size: Size of mini-batches for processing
562+
563+
Returns:
564+
Dictionary containing attribution maps
565+
"""
555566
input_data_batches = torch.split(self.input_data, batch_size)
556567
557568
attribution_map = []
@@ -571,9 +582,7 @@ def compute_attribution_map(self,
571582
572583
attribution_map = np.vstack(attribution_map)
573584
return self._reduce_attribution_map(
574-
{'neuron-gradient': attribution_map,
575-
#'neuron-gradient-invsvd': self._inverse_svd(attribution_map)
576-
})
585+
{'neuron-gradient': attribution_map})
577586
578587
579588
@dataclasses.dataclass
@@ -601,7 +610,7 @@ def __post_init__(self):
601610
"""
602611
super().__post_init__()
603612
self.captum_model = NeuronFeatureAblation(forward_func=self.model,
604-
layer=self.model)
613+
layer=self.model)
605614
606615
def compute_attribution_map(self,
607616
baselines=None,
@@ -638,20 +647,19 @@ def compute_attribution_map(self,
638647
639648
@dataclasses.dataclass
640649
@register("feature-ablation-batched")
641-
class FeatureAblationMethodBAtched(FeatureAblationMethod):
650+
class FeatureAblationMethodBatched(FeatureAblationMethod):
642651
"""As :py:class:`FeatureAblationMethod`, but using mini-batches.
643652

644653
See also:
645654
:py:class:`FeatureAblationMethod`
646655
"""
647656
648657
def compute_attribution_map(self,
649-
baselines=None,
650-
feature_mask=None,
651-
perturbations_per_eval=1,
652-
attribute_to_neuron_input=False,
653-
batch_size=1024):
654-
658+
baselines=None,
659+
feature_mask=None,
660+
perturbations_per_eval=1,
661+
attribute_to_neuron_input=False,
662+
batch_size=1024):
655663
input_data_batches = torch.split(self.input_data, batch_size)
656664
attribution_map = []
657665
for input_data_batch in input_data_batches:

0 commit comments

Comments
 (0)