
Commit df4f661

fix docstrings

1 parent 6f91018 commit df4f661

File tree: 4 files changed, +215 -12 lines


cebra/attribution/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -19,6 +19,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+"""Attribution methods for CEBRA.
+
+This module was added in v0.6.0 and contains attribution methods described and benchmarked
+in :cite:`schneider2025xcebra`:
+
+.. [schneider2025xcebra] Schneider, S., González Laiz, R., Filippova, A., Frey, M., & Mathis, M. W. (2025).
+   Time-series attribution maps with regularized contrastive learning.
+   The 28th International Conference on Artificial Intelligence and Statistics.
+   https://openreview.net/forum?id=aGrCXoTB4P
+"""
 import cebra.registry
 
 cebra.registry.add_helper_functions(__name__)
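Because the module calls `cebra.registry.add_helper_functions`, it should expose the standard CEBRA registry helpers such as `init` and `get_options`. A hedged sketch (assuming those helpers behave as in `cebra.models`) of discovering the attribution methods registered in the file below:

import cebra.attribution

# List registered attribution methods, e.g. "jacobian-based",
# "neuron-gradient", "integrated-gradients", ... (per the @register
# decorators in attribution_models.py).
print(cebra.attribution.get_options())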

cebra/attribution/attribution_models.py

Lines changed: 141 additions & 12 deletions
@@ -41,6 +41,16 @@
 
 @dataclasses.dataclass
 class AttributionMap:
+    """Base class for computing attribution maps for CEBRA models.
+
+    Args:
+        model: The trained CEBRA model to analyze
+        input_data: Input data tensor to compute attributions for
+        output_dimension: Output dimension to analyze. If ``None``, uses model's output dimension
+        num_samples: Number of samples to use for attribution. If ``None``, uses full dataset
+        seed: Random seed which is used to subsample the data. Only relevant if ``num_samples`` is not ``None``.
+    """
+
     model: nn.Module
     input_data: torch.Tensor
     output_dimension: int = None
@@ -78,10 +88,40 @@ def __post_init__(self):
         self.input_data = input_data
 
     def compute_attribution_map(self):
+        """Compute the attribution map for the model.
+
+        Returns:
+            dict: Attribution maps and their variants
+
+        Raises:
+            NotImplementedError: Must be implemented by subclasses
+        """
         raise NotImplementedError
 
     def compute_metrics(self, attribution_map, ground_truth_map):
-        # Note: 0: nonconnected, 1: connected
+        """Compute metrics comparing attribution map to ground truth.
+
+        This function computes various statistical metrics to compare the attribution values
+        between connected and non-connected neurons based on a ground truth connectivity map.
+        It separates the attribution values into two groups based on the binary ground truth,
+        and calculates summary statistics and differences between these groups.
+
+        Args:
+            attribution_map: Computed attribution values representing the strength of connections
+                between neurons
+            ground_truth_map: Binary ground truth connectivity map where True indicates a
+                connected neuron and False indicates a non-connected neuron
+
+        Returns:
+            dict: Dictionary containing the following metrics:
+                - max/mean/min_nonconnected: Statistics for non-connected neurons
+                - max/mean/min_connected: Statistics for connected neurons
+                - gap_max: Difference between max connected and max non-connected values
+                - gap_mean: Difference between mean connected and mean non-connected values
+                - gap_min: Difference between min connected and min non-connected values
+                - gap_minmax: Difference between min connected and max non-connected values
+                - max/min_jacobian: Global max/min values across all neurons
+        """
         assert np.issubdtype(ground_truth_map.dtype, bool)
         connected_neurons = attribution_map[np.where(ground_truth_map)]
         non_connected_neurons = attribution_map[np.where(~ground_truth_map)]
@@ -115,6 +155,15 @@ def compute_metrics(self, attribution_map, ground_truth_map):
         return metrics
 
     def compute_attribution_score(self, attribution_map, ground_truth_map):
+        """Compute ROC AUC score between attribution map and ground truth.
+
+        Args:
+            attribution_map: Computed attribution values
+            ground_truth_map: Binary ground truth connectivity map
+
+        Returns:
+            float: ROC AUC score
+        """
         assert attribution_map.shape == ground_truth_map.shape
         assert np.issubdtype(ground_truth_map.dtype, bool)
         fpr, tpr, _ = sklearn.metrics.roc_curve(  # noqa: codespell:ignore fpr, tpr
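The ROC-based score documented in this hunk boils down to standard scikit-learn calls; a self-contained toy sketch (values fabricated for illustration, not taken from the commit):

import numpy as np
import sklearn.metrics

# Toy attribution values and binary ground-truth connectivity; real maps
# come from compute_attribution_map().
attribution_map = np.array([0.9, 0.1, 0.7, 0.2])
ground_truth_map = np.array([True, False, True, False])

fpr, tpr, _ = sklearn.metrics.roc_curve(ground_truth_map.flatten(),
                                        attribution_map.flatten())
print(sklearn.metrics.auc(fpr, tpr))  # 1.0: this toy case is perfectly separable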
@@ -125,6 +174,15 @@ def compute_attribution_score(self, attribution_map, ground_truth_map):
     @staticmethod
     def _check_moores_penrose_conditions(
             matrix: np.ndarray, matrix_inverse: np.ndarray) -> np.ndarray:
+        """Check Moore-Penrose conditions for a single matrix pair.
+
+        Args:
+            matrix: Input matrix
+            matrix_inverse: Putative pseudoinverse matrix
+
+        Returns:
+            np.ndarray: Boolean array indicating which conditions are satisfied
+        """
         matrix_inverse = matrix_inverse.T
         condition_1 = np.allclose(matrix @ matrix_inverse @ matrix, matrix)
         condition_2 = np.allclose(matrix_inverse @ matrix @ matrix_inverse,
@@ -139,14 +197,14 @@ def _check_moores_penrose_conditions(
     def check_moores_penrose_conditions(
             self, jacobian: np.ndarray,
             jacobian_pseudoinverse: np.ndarray) -> np.ndarray:
-        """
-        Checks the four conditions for the Moore-Penrose conditions for the
-        pseudo-inverse of a matrix.
+        """Check Moore-Penrose conditions for Jacobian matrices.
+
         Args:
-            jacobian: The Jacobian matrix of dhape (num samples, output_dim, num_neurons).
-            jacobian_pseudoinverse: The pseudo-inverse of the Jacobian matrix of shape (num samples, num_neurons, output_dim).
+            jacobian: Jacobian matrices of shape (num samples, output_dim, num_neurons)
+            jacobian_pseudoinverse: Pseudoinverse matrices of shape (num samples, num_neurons, output_dim)
+
         Returns:
-            moores_penrose_conditions: A boolean array of shape (num samples, 4) where each row corresponds to a sample and each column to a condition.
+            Boolean array of shape (num samples, 4) indicating satisfied conditions
         """
         # check the four conditions
         conditions = np.zeros((jacobian.shape[0], 4))
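For reference, the four Moore-Penrose conditions checked here can be verified directly in NumPy; a minimal sketch, independent of CEBRA:

import numpy as np

A = np.random.randn(3, 5)   # one sample: (output_dim, num_neurons)
X = np.linalg.pinv(A)       # candidate pseudoinverse: (num_neurons, output_dim)

condition_1 = np.allclose(A @ X @ A, A)      # A X A = A
condition_2 = np.allclose(X @ A @ X, X)      # X A X = X
condition_3 = np.allclose((A @ X).T, A @ X)  # A X is symmetric
condition_4 = np.allclose((X @ A).T, X @ A)  # X A is symmetric
assert condition_1 and condition_2 and condition_3 and condition_4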
@@ -157,6 +215,15 @@ def check_moores_penrose_conditions(
         return conditions
 
     def _inverse(self, jacobian, method="lsq"):
+        """Compute inverse/pseudoinverse of Jacobian matrices.
+
+        Args:
+            jacobian: Input Jacobian matrices
+            method: Inversion method ('lsq_cvxpy', 'lsq', or 'svd')
+
+        Returns:
+            (Inverse matrices, computation time)
+        """
         # NOTE(stes): Before we used "np.linalg.pinv" here, which
         # is numerically not stable for the Jacobian matrices we
         # need to compute.
@@ -179,10 +246,14 @@ def _inverse(self, jacobian, method="lsq"):
     @staticmethod
     def _inverse_lsq_cvxpy(matrix: np.ndarray,
                            solver: str = 'SCS') -> np.ndarray:
-        """
-        Solves the least squares problem
-        min ||A @ X - I||_2 = (A @ X - I, A @ X - I) = (A @ X)**2 - 2 * (A @ X, I) + (I, I) =
-        = (A @ X)**2 - 2 * (A @ X, I) + const -> min quadratic function of X
+        """Compute least squares inverse using CVXPY.
+
+        Args:
+            matrix: Input matrix
+            solver: CVXPY solver to use
+
+        Returns:
+            np.ndarray: Least squares inverse matrix
         """
 
         matrix_param = cp.Parameter((matrix.shape[0], matrix.shape[1]))
@@ -201,13 +272,37 @@ def _inverse_lsq_cvxpy(matrix: np.ndarray,
 
     @staticmethod
     def _inverse_lsq_scipy(jacobian):
+        """Compute least squares inverse using scipy.linalg.lstsq.
+
+        Args:
+            jacobian: Input Jacobian matrix
+
+        Returns:
+            np.ndarray: Least squares inverse matrix
+        """
         return scipy.linalg.lstsq(jacobian, np.eye(jacobian.shape[0]))[0]
 
     @staticmethod
     def _inverse_svd(jacobian):
+        """Compute pseudoinverse using SVD.
+
+        Args:
+            jacobian: Input Jacobian matrix
+
+        Returns:
+            np.ndarray: Pseudoinverse matrix
+        """
         return scipy.linalg.pinv(jacobian)
 
     def _reduce_attribution_map(self, attribution_maps):
+        """Reduce attribution maps by averaging across dimensions.
+
+        Args:
+            attribution_maps: Dictionary of attribution maps to reduce
+
+        Returns:
+            dict: Reduced attribution maps
+        """
 
         def _reduce(full_jacobian):
             if full_jacobian.ndim == 4:
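A short sketch contrasting the two SciPy-based inversion routines documented above; illustrative only, and it assumes a well-conditioned square matrix:

import numpy as np
import scipy.linalg

jacobian = np.random.randn(8, 8)

# Least squares inverse: solve min_X ||J @ X - I||, as in _inverse_lsq_scipy.
x_lsq = scipy.linalg.lstsq(jacobian, np.eye(jacobian.shape[0]))[0]

# SVD-based pseudoinverse, as in _inverse_svd.
x_svd = scipy.linalg.pinv(jacobian)

# For a well-conditioned square matrix, both agree with the true inverse.
assert np.allclose(x_lsq, x_svd)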
@@ -227,6 +322,7 @@ def _reduce(full_jacobian):
 @dataclasses.dataclass
 @register("jacobian-based")
 class JFMethodBased(AttributionMap):
+    """Compute the attribution map using the Jacobian of the model encoder."""
 
     def _compute_jacobian(self, input_data):
         return cebra.attribution._jacobian.compute_jacobian(
@@ -261,6 +357,11 @@ def compute_attribution_map(self):
 @dataclasses.dataclass
 @register("jacobian-based-batched")
 class JFMethodBasedBatched(JFMethodBased):
+    """Compute an attribution map based on the Jacobian using mini-batches.
+
+    See also:
+        :py:class:`JFMethodBased`
+    """
 
     def compute_attribution_map(self, batch_size=1024):
         if batch_size > self.input_data.shape[0]:
@@ -285,14 +386,19 @@ def compute_attribution_map(self, batch_size=1024):
                 result[f"{key}-inv-{method}"], result[
                     f'time_inversion_{method}'] = self._inverse(value,
                                                                 method=method)
-                # result[f"{key}-inv-{method}-conditions"] = self.check_moores_penrose_conditions(value, result[f"{key}-inv-{method}"])
 
         return result
 
 
 @dataclasses.dataclass
 @register("neuron-gradient")
 class NeuronGradientMethod(AttributionMap):
+    """Compute the attribution map using the neuron gradient from Captum.
+
+    Note:
+        This method is equivalent to Jacobian-based attributions, but
+        uses a different backend implementation.
+    """
 
     def __post_init__(self):
         super().__post_init__()
@@ -330,6 +436,11 @@ def compute_attribution_map(self, attribute_to_neuron_input=False):
 @dataclasses.dataclass
 @register("neuron-gradient-batched")
 class NeuronGradientMethodBatched(NeuronGradientMethod):
+    """As :py:class:`NeuronGradientMethod`, but using mini-batches.
+
+    See also:
+        :py:class:`NeuronGradientMethod`
+    """
 
     def compute_attribution_map(self,
                                 attribute_to_neuron_input=False,
@@ -361,6 +472,7 @@ def compute_attribution_map(self,
 @dataclasses.dataclass
 @register("feature-ablation")
 class FeatureAblationMethod(AttributionMap):
+    """Compute the attribution map using the feature ablation method from Captum."""
 
     def __post_init__(self):
         super().__post_init__()
@@ -393,6 +505,11 @@ def compute_attribution_map(self,
 @dataclasses.dataclass
 @register("feature-ablation-batched")
 class FeatureAblationMethodBAtched(FeatureAblationMethod):
+    """As :py:class:`FeatureAblationMethod`, but using mini-batches.
+
+    See also:
+        :py:class:`FeatureAblationMethod`
+    """
 
     def compute_attribution_map(self,
                                 baselines=None,
@@ -428,6 +545,7 @@ def compute_attribution_map(self,
 @dataclasses.dataclass
 @register("integrated-gradients")
 class IntegratedGradientsMethod(AttributionMap):
+    """Compute the attribution map using the integrated gradients method from Captum."""
 
     def __post_init__(self):
         super().__post_init__()
@@ -465,6 +583,11 @@ def compute_attribution_map(self,
 @dataclasses.dataclass
 @register("integrated-gradients-batched")
 class IntegratedGradientsMethodBatched(IntegratedGradientsMethod):
+    """As :py:class:`IntegratedGradientsMethod`, but using mini-batches.
+
+    See also:
+        :py:class:`IntegratedGradientsMethod`
+    """
 
     def compute_attribution_map(self,
                                 n_steps=50,
@@ -504,6 +627,7 @@ def compute_attribution_map(self,
 @dataclasses.dataclass
 @register("neuron-gradient-shap")
 class NeuronGradientShapMethod(AttributionMap):
+    """Compute the attribution map using the neuron gradient SHAP method from Captum."""
 
     def __post_init__(self):
         super().__post_init__()
@@ -548,6 +672,11 @@ def compute_attribution_map(self,
 @dataclasses.dataclass
 @register("neuron-gradient-shap-batched")
 class NeuronGradientShapMethodBatched(NeuronGradientShapMethod):
+    """As :py:class:`NeuronGradientShapMethod`, but using mini-batches.
+
+    See also:
+        :py:class:`NeuronGradientShapMethod`
+    """
 
     def compute_attribution_map(self,
                                 baselines: str,
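Taken together, the newly documented API can be exercised roughly as follows. This is a sketch, not code from the commit: `trained_model` and `neural_data` are assumed placeholders (a trained CEBRA encoder and a samples-by-neurons tensor), and it assumes `cebra.attribution.init` is the registry helper generated by `cebra.registry.add_helper_functions` in `__init__.py` above:

import cebra.attribution

# Assumed to exist: a trained CEBRA model and a (num_samples, num_neurons)
# input tensor.
# trained_model = ...
# neural_data = ...

method = cebra.attribution.init("jacobian-based",
                                model=trained_model,
                                input_data=neural_data,
                                num_samples=10000,
                                seed=42)
result = method.compute_attribution_map()  # dict of maps and their variants
print(list(result.keys()))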

cebra/data/datasets.py

Lines changed: 1 addition & 0 deletions
@@ -303,6 +303,7 @@ def _iter_property(self, attr):
         return (getattr(data, attr) for data in self.iter_sessions())
 
 
+# TODO(stes): This should be a single session dataset?
 class DatasetxCEBRA(cebra_io.HasDevice):
 
     def __init__(

cebra/models/multi_criterions.py

Lines changed: 63 additions & 0 deletions
@@ -28,6 +28,69 @@
 
 
 class MultiCriterions(nn.Module):
+    """A module for handling multiple loss functions with different criteria.
+
+    This module allows combining multiple loss functions, each operating on specific
+    slices of the input data. It supports both supervised and contrastive learning modes.
+
+    Args:
+        losses: A list of dictionaries containing loss configurations. Each dictionary should have:
+            - 'indices': Tuple of (start, end) indices for the data slice
+            - 'supervised_loss': Dict with loss config for supervised mode
+            - 'contrastive_loss': Dict with loss config for contrastive mode
+            Loss configs should contain:
+            - 'name': Name of the loss function
+            - 'kwargs': Optional parameters for the loss function
+        mode: Either "supervised" or "contrastive" to specify the training mode
+
+    The loss functions can be from torch.nn or custom implementations from cebra.models.criterions.
+    Each criterion is applied to its corresponding slice of the input data during forward pass.
+
+    Example:
+        >>> import torch
+        >>> from cebra.data.datatypes import Batch
+        >>> # Define loss configurations for a hybrid model with both contrastive and supervised losses
+        >>> losses = [
+        ...     {
+        ...         'indices': (0, 10),  # First 10 dimensions
+        ...         'contrastive_loss': {
+        ...             'name': 'InfoNCE',  # Using CEBRA's InfoNCE loss
+        ...             'kwargs': {'temperature': 1.0}
+        ...         },
+        ...         'supervised_loss': {
+        ...             'name': 'nn.MSELoss',  # Using PyTorch's MSE loss
+        ...             'kwargs': {}
+        ...         }
+        ...     },
+        ...     {
+        ...         'indices': (10, 20),  # Next 10 dimensions
+        ...         'contrastive_loss': {
+        ...             'name': 'InfoNCE',  # Using CEBRA's InfoNCE loss
+        ...             'kwargs': {'temperature': 0.5}
+        ...         },
+        ...         'supervised_loss': {
+        ...             'name': 'nn.L1Loss',  # Using PyTorch's L1 loss
+        ...             'kwargs': {}
+        ...         }
+        ...     }
+        ... ]
+        >>> # Create sample predictions (2 batches of 32 samples each with 10 features)
+        >>> ref1 = torch.randn(32, 10)
+        >>> pos1 = torch.randn(32, 10)
+        >>> neg1 = torch.randn(32, 10)
+        >>> ref2 = torch.randn(32, 10)
+        >>> pos2 = torch.randn(32, 10)
+        >>> neg2 = torch.randn(32, 10)
+        >>> predictions = (
+        ...     Batch(reference=ref1, positive=pos1, negative=neg1),
+        ...     Batch(reference=ref2, positive=pos2, negative=neg2)
+        ... )
+        >>> # Create multi-criterion module in contrastive mode
+        >>> multi_loss = MultiCriterions(losses, mode="contrastive")
+        >>> # Forward pass with multiple predictions
+        >>> losses = multi_loss(predictions)  # Returns list of loss values
+        >>> assert len(losses) == 2  # One loss per criterion
+    """
 
     def __init__(self, losses, mode):
         super(MultiCriterions, self).__init__()