@@ -432,13 +432,39 @@ def compute_attribution_map(self):
432432@dataclasses.dataclass
433433@register("jacobian-based-batched")
434434class JFMethodBasedBatched(JFMethodBased):
435- """ Compute an attribution map based on the Jacobian using mini - batches .
435+ """ Batched version of the Jacobian - based attribution method .
436+
437+ This class extends JFMethodBased to compute attribution maps using mini - batches ,
438+ which is useful for handling large datasets that don 't fit in memory . It processes
439+ the input data in smaller chunks and combines the results .
440+
441+ Args :
442+ model : The trained CEBRA model to analyze
443+ input_data : Input data tensor to compute attributions for
444+ output_dimension : Output dimension to analyze . If `` None `` , uses model 's output dimension
445+ num_samples : Number of samples to use for attribution . If `` None `` , uses full dataset
446+ seed : Random seed which is used to subsample the data
436447
437448 See also:
438- :py:class:`JFMethodBased`
449+ :py:class:`JFMethodBased`: The base class for Jacobian-based attribution
439450 """
440451
441452 def compute_attribution_map(self, batch_size=1024):
453+ """ Compute the attribution map using batched Jacobian method .
454+
455+ This method processes the input data in mini - batches to handle large datasets
456+ that don 't fit in memory . It computes the Jacobian for each batch and combines
457+ the results .
458+
459+ Args :
460+ batch_size : Size of each mini - batch . Default is 1024.
461+
462+ Returns :
463+ dict : Dictionary containing attribution maps and their variants
464+
465+ Raises :
466+ ValueError : If batch_size is larger than the number of samples
467+ """
442468 if batch_size > self.input_data.shape[0]:
443469 raise ValueError(
444470 f"Batch size ({batch_size}) is bigger than data ({self.input_data.shape[0]})"
@@ -457,7 +483,6 @@ def compute_attribution_map(self, batch_size=1024):
457483 }).items():
458484 result[key] = value
459485 for method in ['lsq', 'svd']:
460-
461486 result[f"{key}-inv-{method}"], result[
462487 f'time_inversion_{method}'] = self._inverse(value,
463488 method=method)
@@ -545,9 +570,9 @@ def compute_attribution_map(self,
545570 attribution_map.append(attribution_map_batch)
546571
547572 attribution_map = np.vstack(attribution_map)
548- return self._reduce_attribution_map({
549- 'neuron-gradient': attribution_map,
550- #'neuron-gradient-invsvd': self._inverse_svd(attribution_map)
573+ return self._reduce_attribution_map(
574+ { 'neuron-gradient': attribution_map,
575+ #'neuron-gradient-invsvd': self._inverse_svd(attribution_map)
551576 })
552577
553578
0 commit comments