Skip to content

Commit d5ce08c

Browse files
committed
doctrings xcebra2
1 parent 5560756 commit d5ce08c

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

cebra/attribution/attribution_models.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -432,13 +432,39 @@ def compute_attribution_map(self):
432432
@dataclasses.dataclass
433433
@register("jacobian-based-batched")
434434
class JFMethodBasedBatched(JFMethodBased):
435-
"""Compute an attribution map based on the Jacobian using mini-batches.
435+
"""Batched version of the Jacobian-based attribution method.
436+
437+
This class extends JFMethodBased to compute attribution maps using mini-batches,
438+
which is useful for handling large datasets that don't fit in memory. It processes
439+
the input data in smaller chunks and combines the results.
440+
441+
Args:
442+
model: The trained CEBRA model to analyze
443+
input_data: Input data tensor to compute attributions for
444+
output_dimension: Output dimension to analyze. If ``None``, uses model's output dimension
445+
num_samples: Number of samples to use for attribution. If ``None``, uses full dataset
446+
seed: Random seed which is used to subsample the data
436447

437448
See also:
438-
:py:class:`JFMethodBased`
449+
:py:class:`JFMethodBased`: The base class for Jacobian-based attribution
439450
"""
440451
441452
def compute_attribution_map(self, batch_size=1024):
453+
"""Compute the attribution map using batched Jacobian method.
454+
455+
This method processes the input data in mini-batches to handle large datasets
456+
that don't fit in memory. It computes the Jacobian for each batch and combines
457+
the results.
458+
459+
Args:
460+
batch_size: Size of each mini-batch. Default is 1024.
461+
462+
Returns:
463+
dict: Dictionary containing attribution maps and their variants
464+
465+
Raises:
466+
ValueError: If batch_size is larger than the number of samples
467+
"""
442468
if batch_size > self.input_data.shape[0]:
443469
raise ValueError(
444470
f"Batch size ({batch_size}) is bigger than data ({self.input_data.shape[0]})"
@@ -457,7 +483,6 @@ def compute_attribution_map(self, batch_size=1024):
457483
}).items():
458484
result[key] = value
459485
for method in ['lsq', 'svd']:
460-
461486
result[f"{key}-inv-{method}"], result[
462487
f'time_inversion_{method}'] = self._inverse(value,
463488
method=method)
@@ -545,9 +570,9 @@ def compute_attribution_map(self,
545570
attribution_map.append(attribution_map_batch)
546571
547572
attribution_map = np.vstack(attribution_map)
548-
return self._reduce_attribution_map({
549-
'neuron-gradient': attribution_map,
550-
#'neuron-gradient-invsvd': self._inverse_svd(attribution_map)
573+
return self._reduce_attribution_map(
574+
{'neuron-gradient': attribution_map,
575+
#'neuron-gradient-invsvd': self._inverse_svd(attribution_map)
551576
})
552577
553578

0 commit comments

Comments
 (0)