4141
4242@dataclasses .dataclass
4343class AttributionMap :
44+ """Base class for computing attribution maps for CEBRA models.
45+
46+ Args:
47+ model: The trained CEBRA model to analyze
48+ input_data: Input data tensor to compute attributions for
49+ output_dimension: Output dimension to analyze. If ``None``, uses model's output dimension
50+ num_samples: Number of samples to use for attribution. If ``None``, uses full dataset
51+ seed: Random seed which is used to subsample the data. Only relevant if ``num_samples`` is not ``None``.
52+ """
53+
4454 model : nn .Module
4555 input_data : torch .Tensor
4656 output_dimension : int = None
@@ -78,10 +88,40 @@ def __post_init__(self):
7888 self .input_data = input_data
7989
8090 def compute_attribution_map (self ):
91+ """Compute the attribution map for the model.
92+
93+ Returns:
94+ dict: Attribution maps and their variants
95+
96+ Raises:
97+ NotImplementedError: Must be implemented by subclasses
98+ """
8199 raise NotImplementedError
82100
83101 def compute_metrics (self , attribution_map , ground_truth_map ):
84- # Note: 0: nonconnected, 1: connected
102+ """Compute metrics comparing attribution map to ground truth.
103+
104+ This function computes various statistical metrics to compare the attribution values
105+ between connected and non-connected neurons based on a ground truth connectivity map.
106+ It separates the attribution values into two groups based on the binary ground truth,
107+ and calculates summary statistics and differences between these groups.
108+
109+ Args:
110+ attribution_map: Computed attribution values representing the strength of connections
111+ between neurons
112+ ground_truth_map: Binary ground truth connectivity map where True indicates a
113+ connected neuron and False indicates a non-connected neuron
114+
115+ Returns:
116+ dict: Dictionary containing the following metrics:
117+ - max/mean/min_nonconnected: Statistics for non-connected neurons
118+ - max/mean/min_connected: Statistics for connected neurons
119+ - gap_max: Difference between max connected and max non-connected values
120+ - gap_mean: Difference between mean connected and mean non-connected values
121+ - gap_min: Difference between min connected and min non-connected values
122+ - gap_minmax: Difference between min connected and max non-connected values
123+ - max/min_jacobian: Global max/min values across all neurons
124+ """
85125 assert np .issubdtype (ground_truth_map .dtype , bool )
86126 connected_neurons = attribution_map [np .where (ground_truth_map )]
87127 non_connected_neurons = attribution_map [np .where (~ ground_truth_map )]
@@ -115,6 +155,15 @@ def compute_metrics(self, attribution_map, ground_truth_map):
115155 return metrics
116156
117157 def compute_attribution_score (self , attribution_map , ground_truth_map ):
158+ """Compute ROC AUC score between attribution map and ground truth.
159+
160+ Args:
161+ attribution_map: Computed attribution values
162+ ground_truth_map: Binary ground truth connectivity map
163+
164+ Returns:
165+ float: ROC AUC score
166+ """
118167 assert attribution_map .shape == ground_truth_map .shape
119168 assert np .issubdtype (ground_truth_map .dtype , bool )
120169 fpr , tpr , _ = sklearn .metrics .roc_curve ( # noqa: codespell:ignore fpr, tpr
@@ -125,6 +174,15 @@ def compute_attribution_score(self, attribution_map, ground_truth_map):
125174 @staticmethod
126175 def _check_moores_penrose_conditions (
127176 matrix : np .ndarray , matrix_inverse : np .ndarray ) -> np .ndarray :
177+ """Check Moore-Penrose conditions for a single matrix pair.
178+
179+ Args:
180+ matrix: Input matrix
181+ matrix_inverse: Putative pseudoinverse matrix
182+
183+ Returns:
184+ np.ndarray: Boolean array indicating which conditions are satisfied
185+ """
128186 matrix_inverse = matrix_inverse .T
129187 condition_1 = np .allclose (matrix @ matrix_inverse @ matrix , matrix )
130188 condition_2 = np .allclose (matrix_inverse @ matrix @ matrix_inverse ,
@@ -139,14 +197,14 @@ def _check_moores_penrose_conditions(
139197 def check_moores_penrose_conditions (
140198 self , jacobian : np .ndarray ,
141199 jacobian_pseudoinverse : np .ndarray ) -> np .ndarray :
142- """
143- Checks the four conditions for the Moore-Penrose conditions for the
144- pseudo-inverse of a matrix.
200+ """Check Moore-Penrose conditions for Jacobian matrices.
201+
145202 Args:
146- jacobian: The Jacobian matrix of dhape (num samples, output_dim, num_neurons).
147- jacobian_pseudoinverse: The pseudo-inverse of the Jacobian matrix of shape (num samples, num_neurons, output_dim).
203+ jacobian: Jacobian matrices of shape (num samples, output_dim, num_neurons)
204+ jacobian_pseudoinverse: Pseudoinverse matrices of shape (num samples, num_neurons, output_dim)
205+
148206 Returns:
149- moores_penrose_conditions: A boolean array of shape (num samples, 4) where each row corresponds to a sample and each column to a condition.
207+ Boolean array of shape (num samples, 4) indicating satisfied conditions
150208 """
151209 # check the four conditions
152210 conditions = np .zeros ((jacobian .shape [0 ], 4 ))
@@ -157,6 +215,15 @@ def check_moores_penrose_conditions(
157215 return conditions
158216
159217 def _inverse (self , jacobian , method = "lsq" ):
218+ """Compute inverse/pseudoinverse of Jacobian matrices.
219+
220+ Args:
221+ jacobian: Input Jacobian matrices
222+ method: Inversion method ('lsq_cvxpy', 'lsq', or 'svd')
223+
224+ Returns:
225+ (Inverse matrices, computation time)
226+ """
160227 # NOTE(stes): Before we used "np.linalg.pinv" here, which
161228 # is numerically not stable for the Jacobian matrices we
162229 # need to compute.
@@ -179,10 +246,14 @@ def _inverse(self, jacobian, method="lsq"):
179246 @staticmethod
180247 def _inverse_lsq_cvxpy (matrix : np .ndarray ,
181248 solver : str = 'SCS' ) -> np .ndarray :
182- """
183- Solves the least squares problem
184- min ||A @ X - I||_2 = (A @ X - I, A @ X - I) = (A @ X)**2 - 2 * (A @ X, I) + (I, I) =
185- = (A @ X)**2 - 2 * (A @ X, I) + const -> min quadratic function of X
249+ """Compute least squares inverse using CVXPY.
250+
251+ Args:
252+ matrix: Input matrix
253+ solver: CVXPY solver to use
254+
255+ Returns:
256+ np.ndarray: Least squares inverse matrix
186257 """
187258
188259 matrix_param = cp .Parameter ((matrix .shape [0 ], matrix .shape [1 ]))
@@ -201,13 +272,37 @@ def _inverse_lsq_cvxpy(matrix: np.ndarray,
201272
202273 @staticmethod
203274 def _inverse_lsq_scipy (jacobian ):
275+ """Compute least squares inverse using scipy.linalg.lstsq.
276+
277+ Args:
278+ jacobian: Input Jacobian matrix
279+
280+ Returns:
281+ np.ndarray: Least squares inverse matrix
282+ """
204283 return scipy .linalg .lstsq (jacobian , np .eye (jacobian .shape [0 ]))[0 ]
205284
206285 @staticmethod
207286 def _inverse_svd (jacobian ):
287+ """Compute pseudoinverse using SVD.
288+
289+ Args:
290+ jacobian: Input Jacobian matrix
291+
292+ Returns:
293+ np.ndarray: Pseudoinverse matrix
294+ """
208295 return scipy .linalg .pinv (jacobian )
209296
210297 def _reduce_attribution_map (self , attribution_maps ):
298+ """Reduce attribution maps by averaging across dimensions.
299+
300+ Args:
301+ attribution_maps: Dictionary of attribution maps to reduce
302+
303+ Returns:
304+ dict: Reduced attribution maps
305+ """
211306
212307 def _reduce (full_jacobian ):
213308 if full_jacobian .ndim == 4 :
@@ -227,6 +322,7 @@ def _reduce(full_jacobian):
227322@dataclasses .dataclass
228323@register ("jacobian-based" )
229324class JFMethodBased (AttributionMap ):
325+ """Compute the attribution map using the Jacobian of the model encoder."""
230326
231327 def _compute_jacobian (self , input_data ):
232328 return cebra .attribution ._jacobian .compute_jacobian (
@@ -261,6 +357,11 @@ def compute_attribution_map(self):
261357@dataclasses .dataclass
262358@register ("jacobian-based-batched" )
263359class JFMethodBasedBatched (JFMethodBased ):
360+ """Compute an attribution map based on the Jacobian using mini-batches.
361+
362+ See also:
363+ :py:class:`JFMethodBased`
364+ """
264365
265366 def compute_attribution_map (self , batch_size = 1024 ):
266367 if batch_size > self .input_data .shape [0 ]:
@@ -285,14 +386,19 @@ def compute_attribution_map(self, batch_size=1024):
285386 result [f"{ key } -inv-{ method } " ], result [
286387 f'time_inversion_{ method } ' ] = self ._inverse (value ,
287388 method = method )
288- # result[f"{key}-inv-{method}-conditions"] = self.check_moores_penrose_conditions(value, result[f"{key}-inv-{method}"])
289389
290390 return result
291391
292392
293393@dataclasses .dataclass
294394@register ("neuron-gradient" )
295395class NeuronGradientMethod (AttributionMap ):
396+ """Compute the attribution map using the neuron gradient from Captum.
397+
398+ Note:
399+ This method is equivalent to Jacobian-based attributions, but
400+ uses a different backend implementation.
401+ """
296402
297403 def __post_init__ (self ):
298404 super ().__post_init__ ()
@@ -330,6 +436,11 @@ def compute_attribution_map(self, attribute_to_neuron_input=False):
330436@dataclasses .dataclass
331437@register ("neuron-gradient-batched" )
332438class NeuronGradientMethodBatched (NeuronGradientMethod ):
439+ """As :py:class:`NeuronGradientMethod`, but using mini-batches.
440+
441+ See also:
442+ :py:class:`NeuronGradientMethod`
443+ """
333444
334445 def compute_attribution_map (self ,
335446 attribute_to_neuron_input = False ,
@@ -361,6 +472,7 @@ def compute_attribution_map(self,
361472@dataclasses .dataclass
362473@register ("feature-ablation" )
363474class FeatureAblationMethod (AttributionMap ):
475+ """Compute the attribution map using the feature ablation method from Captum."""
364476
365477 def __post_init__ (self ):
366478 super ().__post_init__ ()
@@ -393,6 +505,11 @@ def compute_attribution_map(self,
393505@dataclasses .dataclass
394506@register ("feature-ablation-batched" )
395507class FeatureAblationMethodBAtched (FeatureAblationMethod ):
508+ """As :py:class:`FeatureAblationMethod`, but using mini-batches.
509+
510+ See also:
511+ :py:class:`FeatureAblationMethod`
512+ """
396513
397514 def compute_attribution_map (self ,
398515 baselines = None ,
@@ -428,6 +545,7 @@ def compute_attribution_map(self,
428545@dataclasses .dataclass
429546@register ("integrated-gradients" )
430547class IntegratedGradientsMethod (AttributionMap ):
548+ """Compute the attribution map using the integrated gradients method from Captum."""
431549
432550 def __post_init__ (self ):
433551 super ().__post_init__ ()
@@ -465,6 +583,11 @@ def compute_attribution_map(self,
465583@dataclasses .dataclass
466584@register ("integrated-gradients-batched" )
467585class IntegratedGradientsMethodBatched (IntegratedGradientsMethod ):
586+ """As :py:class:`IntegratedGradientsMethod`, but using mini-batches.
587+
588+ See also:
589+ :py:class:`IntegratedGradientsMethod`
590+ """
468591
469592 def compute_attribution_map (self ,
470593 n_steps = 50 ,
@@ -504,6 +627,7 @@ def compute_attribution_map(self,
504627@dataclasses .dataclass
505628@register ("neuron-gradient-shap" )
506629class NeuronGradientShapMethod (AttributionMap ):
630+ """Compute the attribution map using the neuron gradient SHAP method from Captum."""
507631
508632 def __post_init__ (self ):
509633 super ().__post_init__ ()
@@ -548,6 +672,11 @@ def compute_attribution_map(self,
548672@dataclasses .dataclass
549673@register ("neuron-gradient-shap-batched" )
550674class NeuronGradientShapMethodBatched (NeuronGradientShapMethod ):
675+ """As :py:class:`NeuronGradientShapMethod`, but using mini-batches.
676+
677+ See also:
678+ :py:class:`NeuronGradientShapMethod`
679+ """
551680
552681 def compute_attribution_map (self ,
553682 baselines : str ,
0 commit comments